Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.
This commit is contained in:
@@ -67,17 +67,8 @@ class Attention(nn.Module):
|
||||
dim_head: int = 64,
|
||||
norm_eps: float = 1e-6,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
has_gate_logits: bool = False,
|
||||
):
|
||||
"""Initialize attention module.
|
||||
|
||||
Args:
|
||||
query_dim: Dimension of query input
|
||||
context_dim: Dimension of context (key/value) input. If None, same as query_dim
|
||||
heads: Number of attention heads
|
||||
dim_head: Dimension per head
|
||||
norm_eps: Epsilon for RMS normalization
|
||||
rope_type: Type of rotary position embedding
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.rope_type = rope_type
|
||||
@@ -99,6 +90,10 @@ class Attention(nn.Module):
|
||||
# Output projection
|
||||
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
|
||||
|
||||
# Per-head gating (LTX-2.3)
|
||||
if has_gate_logits:
|
||||
self.to_gate_logits = nn.Linear(query_dim, heads, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
@@ -119,6 +114,11 @@ class Attention(nn.Module):
|
||||
Returns:
|
||||
Attention output of shape (B, seq_len, query_dim)
|
||||
"""
|
||||
# Compute per-head gate early (from original input)
|
||||
gate = None
|
||||
if hasattr(self, "to_gate_logits"):
|
||||
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
|
||||
|
||||
# Compute Q, K, V
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
@@ -138,5 +138,12 @@ class Attention(nn.Module):
|
||||
# Compute attention
|
||||
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||
|
||||
# Apply per-head gating
|
||||
if gate is not None:
|
||||
b, seq_len, _ = out.shape
|
||||
out = mx.reshape(out, (b, seq_len, self.heads, self.dim_head))
|
||||
out = out * gate[..., None]
|
||||
out = mx.reshape(out, (b, seq_len, -1))
|
||||
|
||||
# Project output
|
||||
return self.to_out(out)
|
||||
|
||||
Reference in New Issue
Block a user