Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.
This commit is contained in:
@@ -20,20 +20,25 @@ class Modality:
|
||||
context_mask: Optional[mx.array] = None
|
||||
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
|
||||
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
|
||||
# Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3)
|
||||
sigma: Optional[mx.array] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TransformerArgs:
|
||||
x: mx.array
|
||||
context: mx.array
|
||||
context_mask: Optional[mx.array]
|
||||
timesteps: mx.array
|
||||
embedded_timestep: mx.array
|
||||
positional_embeddings: Tuple[mx.array, mx.array]
|
||||
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
|
||||
cross_scale_shift_timestep: Optional[mx.array]
|
||||
cross_gate_timestep: Optional[mx.array]
|
||||
x: mx.array
|
||||
context: mx.array
|
||||
context_mask: Optional[mx.array]
|
||||
timesteps: mx.array
|
||||
embedded_timestep: mx.array
|
||||
positional_embeddings: Tuple[mx.array, mx.array]
|
||||
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
|
||||
cross_scale_shift_timestep: Optional[mx.array]
|
||||
cross_gate_timestep: Optional[mx.array]
|
||||
enabled: bool
|
||||
# LTX-2.3: prompt-conditioned timestep embeddings for cross-attention
|
||||
prompt_timesteps: Optional[mx.array] = None
|
||||
prompt_embedded_timestep: Optional[mx.array] = None
|
||||
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
@@ -50,20 +55,13 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
audio: Optional[TransformerConfig] = None,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
norm_eps: float = 1e-6,
|
||||
has_prompt_adaln: bool = False,
|
||||
):
|
||||
"""Initialize transformer block.
|
||||
|
||||
Args:
|
||||
idx: Block index
|
||||
video: Video modality configuration
|
||||
audio: Audio modality configuration
|
||||
rope_type: Type of rotary position embedding
|
||||
norm_eps: Epsilon for normalization
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.idx = idx
|
||||
self.norm_eps = norm_eps
|
||||
self.has_prompt_adaln = has_prompt_adaln
|
||||
|
||||
# Video components
|
||||
if video is not None:
|
||||
@@ -74,6 +72,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
context_dim=None, # Self-attention
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.attn2 = Attention(
|
||||
query_dim=video.dim,
|
||||
@@ -82,10 +81,15 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=video.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.ff = FeedForward(video.dim, dim_out=video.dim)
|
||||
# 6 scale-shift parameters: 3 for attention, 3 for MLP
|
||||
self.scale_shift_table = mx.zeros((6, video.dim))
|
||||
# 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2
|
||||
num_ada_params = 9 if has_prompt_adaln else 6
|
||||
self.scale_shift_table = mx.zeros((num_ada_params, video.dim))
|
||||
|
||||
if has_prompt_adaln:
|
||||
self.prompt_scale_shift_table = mx.zeros((2, video.dim))
|
||||
|
||||
# Audio components
|
||||
if audio is not None:
|
||||
@@ -96,6 +100,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
context_dim=None,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.audio_attn2 = Attention(
|
||||
query_dim=audio.dim,
|
||||
@@ -104,9 +109,14 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
|
||||
self.audio_scale_shift_table = mx.zeros((6, audio.dim))
|
||||
num_audio_ada_params = 9 if has_prompt_adaln else 6
|
||||
self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim))
|
||||
|
||||
if has_prompt_adaln:
|
||||
self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim))
|
||||
|
||||
# Cross-modal attention (when both video and audio are enabled)
|
||||
if audio is not None and video is not None:
|
||||
@@ -118,6 +128,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
# Video-to-Audio: Q from audio, K/V from video
|
||||
self.video_to_audio_attn = Attention(
|
||||
@@ -127,6 +138,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
# Scale-shift tables for cross-attention
|
||||
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
|
||||
@@ -254,11 +266,23 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
vx = vx + self.attn2(
|
||||
rms_norm(vx, eps=self.norm_eps),
|
||||
context=video.context,
|
||||
mask=video.context_mask,
|
||||
)
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||
vshift_q, vscale_q, vgate_q = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
|
||||
)
|
||||
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
|
||||
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
|
||||
)
|
||||
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
|
||||
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
|
||||
else:
|
||||
vx = vx + self.attn2(
|
||||
rms_norm(vx, eps=self.norm_eps),
|
||||
context=video.context,
|
||||
mask=video.context_mask,
|
||||
)
|
||||
|
||||
# Process audio self-attention and cross-attention with text
|
||||
if run_ax:
|
||||
@@ -271,11 +295,23 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
ax = ax + self.audio_attn2(
|
||||
rms_norm(ax, eps=self.norm_eps),
|
||||
context=audio.context,
|
||||
mask=audio.context_mask,
|
||||
)
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||
ashift_q, ascale_q, agate_q = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
|
||||
)
|
||||
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
|
||||
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
|
||||
)
|
||||
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
|
||||
else:
|
||||
ax = ax + self.audio_attn2(
|
||||
rms_norm(ax, eps=self.norm_eps),
|
||||
context=audio.context,
|
||||
mask=audio.context_mask,
|
||||
)
|
||||
|
||||
# Audio-Video cross-modal attention
|
||||
if run_a2v or run_v2a:
|
||||
@@ -341,7 +377,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
# Process video feed-forward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
|
||||
)
|
||||
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
||||
@@ -349,7 +385,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
# Process audio feed-forward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
|
||||
)
|
||||
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
Reference in New Issue
Block a user