Add the LTX-2.3 model architecture with support for prompt-conditioned adaptive layer normalization (AdaLN). Introduce gating mechanisms in the attention modules and update the transformer configurations to accommodate the new parameters. Refactor video and audio processing to use adaptive normalization, improving model flexibility and performance. Update the weight-loading and initialization logic to support dynamic block structures in the decoder.

This commit is contained in:
Prince Canuma
2026-03-10 16:47:36 +01:00
parent d028b239fb
commit 207c223354
8 changed files with 545 additions and 239 deletions

View File

@@ -20,20 +20,25 @@ class Modality:
context_mask: Optional[mx.array] = None
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
# Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3)
sigma: Optional[mx.array] = None
@dataclass(frozen=True)
class TransformerArgs:
x: mx.array
context: mx.array
context_mask: Optional[mx.array]
timesteps: mx.array
embedded_timestep: mx.array
positional_embeddings: Tuple[mx.array, mx.array]
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
cross_scale_shift_timestep: Optional[mx.array]
cross_gate_timestep: Optional[mx.array]
x: mx.array
context: mx.array
context_mask: Optional[mx.array]
timesteps: mx.array
embedded_timestep: mx.array
positional_embeddings: Tuple[mx.array, mx.array]
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
cross_scale_shift_timestep: Optional[mx.array]
cross_gate_timestep: Optional[mx.array]
enabled: bool
# LTX-2.3: prompt-conditioned timestep embeddings for cross-attention
prompt_timesteps: Optional[mx.array] = None
prompt_embedded_timestep: Optional[mx.array] = None
class BasicAVTransformerBlock(nn.Module):
@@ -50,20 +55,13 @@ class BasicAVTransformerBlock(nn.Module):
audio: Optional[TransformerConfig] = None,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
norm_eps: float = 1e-6,
has_prompt_adaln: bool = False,
):
"""Initialize transformer block.
Args:
idx: Block index
video: Video modality configuration
audio: Audio modality configuration
rope_type: Type of rotary position embedding
norm_eps: Epsilon for normalization
"""
super().__init__()
self.idx = idx
self.norm_eps = norm_eps
self.has_prompt_adaln = has_prompt_adaln
# Video components
if video is not None:
@@ -74,6 +72,7 @@ class BasicAVTransformerBlock(nn.Module):
context_dim=None, # Self-attention
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.attn2 = Attention(
query_dim=video.dim,
@@ -82,10 +81,15 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=video.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.ff = FeedForward(video.dim, dim_out=video.dim)
# 6 scale-shift parameters: 3 for attention, 3 for MLP
self.scale_shift_table = mx.zeros((6, video.dim))
# 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2
num_ada_params = 9 if has_prompt_adaln else 6
self.scale_shift_table = mx.zeros((num_ada_params, video.dim))
if has_prompt_adaln:
self.prompt_scale_shift_table = mx.zeros((2, video.dim))
# Audio components
if audio is not None:
@@ -96,6 +100,7 @@ class BasicAVTransformerBlock(nn.Module):
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.audio_attn2 = Attention(
query_dim=audio.dim,
@@ -104,9 +109,14 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
self.audio_scale_shift_table = mx.zeros((6, audio.dim))
num_audio_ada_params = 9 if has_prompt_adaln else 6
self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim))
if has_prompt_adaln:
self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim))
# Cross-modal attention (when both video and audio are enabled)
if audio is not None and video is not None:
@@ -118,6 +128,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
# Video-to-Audio: Q from audio, K/V from video
self.video_to_audio_attn = Attention(
@@ -127,6 +138,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
# Scale-shift tables for cross-attention
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
@@ -254,11 +266,23 @@ class BasicAVTransformerBlock(nn.Module):
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
# Cross-attention with text context
vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps),
context=video.context,
mask=video.context_mask,
)
if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
vshift_q, vscale_q, vgate_q = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
)
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
)
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
else:
vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps),
context=video.context,
mask=video.context_mask,
)
# Process audio self-attention and cross-attention with text
if run_ax:
@@ -271,11 +295,23 @@ class BasicAVTransformerBlock(nn.Module):
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
# Cross-attention with text context
ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps),
context=audio.context,
mask=audio.context_mask,
)
if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
ashift_q, ascale_q, agate_q = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
)
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
)
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
else:
ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps),
context=audio.context,
mask=audio.context_mask,
)
# Audio-Video cross-modal attention
if run_a2v or run_v2a:
@@ -341,7 +377,7 @@ class BasicAVTransformerBlock(nn.Module):
# Process video feed-forward
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
)
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
vx = vx + self.ff(vx_scaled) * vgate_mlp
@@ -349,7 +385,7 @@ class BasicAVTransformerBlock(nn.Module):
# Process audio feed-forward
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
)
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
ax = ax + self.audio_ff(ax_scaled) * agate_mlp