Add the LTX-2.3 model architecture with support for prompt-conditioned adaptive layer normalization (AdaLN). Introduce gating mechanisms in the attention modules and update the transformer configurations to accommodate the new parameters. Refactor video and audio processing to use adaptive normalization, improving model flexibility and performance. Update the weight loading and initialization logic to support dynamic block structures in the decoder.

This commit is contained in:
Prince Canuma
2026-03-10 16:47:36 +01:00
parent d028b239fb
commit 207c223354
8 changed files with 545 additions and 239 deletions

View File

@@ -67,17 +67,8 @@ class Attention(nn.Module):
dim_head: int = 64,
norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
has_gate_logits: bool = False,
):
"""Initialize attention module.
Args:
query_dim: Dimension of query input
context_dim: Dimension of context (key/value) input. If None, same as query_dim
heads: Number of attention heads
dim_head: Dimension per head
norm_eps: Epsilon for RMS normalization
rope_type: Type of rotary position embedding
"""
super().__init__()
self.rope_type = rope_type
@@ -99,6 +90,10 @@ class Attention(nn.Module):
# Output projection
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
# Per-head gating (LTX-2.3)
if has_gate_logits:
self.to_gate_logits = nn.Linear(query_dim, heads, bias=True)
def __call__(
self,
x: mx.array,
@@ -119,6 +114,11 @@ class Attention(nn.Module):
Returns:
Attention output of shape (B, seq_len, query_dim)
"""
# Compute per-head gate early (from original input)
gate = None
if hasattr(self, "to_gate_logits"):
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
# Compute Q, K, V
q = self.to_q(x)
context = x if context is None else context
@@ -138,5 +138,12 @@ class Attention(nn.Module):
# Compute attention
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
# Apply per-head gating
if gate is not None:
b, seq_len, _ = out.shape
out = mx.reshape(out, (b, seq_len, self.heads, self.dim_head))
out = out * gate[..., None]
out = mx.reshape(out, (b, seq_len, -1))
# Project output
return self.to_out(out)