Add the LTX-2.3 model architecture with support for prompt-conditioned adaptive layer normalization (AdaLN). Introduce gating mechanisms in the attention modules and update the transformer configurations to accommodate the new parameters. Refactor video and audio processing to use adaptive normalization, improving model flexibility and performance. Update the weight loading and initialization logic to support dynamic block structures in the decoder.

This commit is contained in:
Prince Canuma
2026-03-10 16:47:36 +01:00
parent d028b239fb
commit 207c223354
8 changed files with 545 additions and 239 deletions

View File

@@ -344,6 +344,7 @@ def denoise_distilled(
context=text_embeddings, context=text_embeddings,
context_mask=None, context_mask=None,
enabled=True, enabled=True,
sigma=mx.full((b,), sigma, dtype=dtype),
) )
audio_modality = None audio_modality = None
@@ -359,6 +360,7 @@ def denoise_distilled(
context=audio_embeddings, context=audio_embeddings,
context_mask=None, context_mask=None,
enabled=True, enabled=True,
sigma=mx.full((ab,), sigma, dtype=dtype),
) )
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
@@ -493,6 +495,8 @@ def denoise_dev(
else: else:
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
sigma_array = mx.full((b,), sigma, dtype=dtype)
# Positive conditioning pass # Positive conditioning pass
video_modality_pos = Modality( video_modality_pos = Modality(
latent=latents_flat, latent=latents_flat,
@@ -502,6 +506,7 @@ def denoise_dev(
context_mask=None, context_mask=None,
enabled=True, enabled=True,
positional_embeddings=precomputed_rope, positional_embeddings=precomputed_rope,
sigma=sigma_array,
) )
velocity_pos, _ = transformer(video=video_modality_pos, audio=None) velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
@@ -523,6 +528,7 @@ def denoise_dev(
context_mask=None, context_mask=None,
enabled=True, enabled=True,
positional_embeddings=precomputed_rope, positional_embeddings=precomputed_rope,
sigma=sigma_array,
) )
velocity_neg, _ = transformer(video=video_modality_neg, audio=None) velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
@@ -957,10 +963,18 @@ def generate_video(
mx.random.seed(seed) mx.random.seed(seed)
# Read transformer config to detect model version
import json
transformer_config_path = model_path / "transformer" / "config.json"
has_prompt_adaln = False
if transformer_config_path.exists():
with open(transformer_config_path) as f:
has_prompt_adaln = json.load(f).get("has_prompt_adaln", False)
# Load text encoder # Load text encoder
with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"): with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"):
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder() text_encoder = LTX2TextEncoder(has_prompt_adaln=has_prompt_adaln)
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters()) mx.eval(text_encoder.parameters())
console.print("[green]✓[/] Text encoder loaded") console.print("[green]✓[/] Text encoder loaded")
@@ -1084,7 +1098,10 @@ def generate_video(
# Upsample latents # Upsample latents
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
mx.eval(upsampler.parameters()) mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))

View File

@@ -67,17 +67,8 @@ class Attention(nn.Module):
dim_head: int = 64, dim_head: int = 64,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
has_gate_logits: bool = False,
): ):
"""Initialize attention module.
Args:
query_dim: Dimension of query input
context_dim: Dimension of context (key/value) input. If None, same as query_dim
heads: Number of attention heads
dim_head: Dimension per head
norm_eps: Epsilon for RMS normalization
rope_type: Type of rotary position embedding
"""
super().__init__() super().__init__()
self.rope_type = rope_type self.rope_type = rope_type
@@ -99,6 +90,10 @@ class Attention(nn.Module):
# Output projection # Output projection
self.to_out = nn.Linear(inner_dim, query_dim, bias=True) self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
# Per-head gating (LTX-2.3)
if has_gate_logits:
self.to_gate_logits = nn.Linear(query_dim, heads, bias=True)
def __call__( def __call__(
self, self,
x: mx.array, x: mx.array,
@@ -119,6 +114,11 @@ class Attention(nn.Module):
Returns: Returns:
Attention output of shape (B, seq_len, query_dim) Attention output of shape (B, seq_len, query_dim)
""" """
# Compute per-head gate early (from original input)
gate = None
if hasattr(self, "to_gate_logits"):
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
# Compute Q, K, V # Compute Q, K, V
q = self.to_q(x) q = self.to_q(x)
context = x if context is None else context context = x if context is None else context
@@ -138,5 +138,12 @@ class Attention(nn.Module):
# Compute attention # Compute attention
out = scaled_dot_product_attention(q, k, v, self.heads, mask) out = scaled_dot_product_attention(q, k, v, self.heads, mask)
# Apply per-head gating
if gate is not None:
b, seq_len, _ = out.shape
out = mx.reshape(out, (b, seq_len, self.heads, self.dim_head))
out = out * gate[..., None]
out = mx.reshape(out, (b, seq_len, -1))
# Project output # Project output
return self.to_out(out) return self.to_out(out)

View File

@@ -131,6 +131,12 @@ class LTXModelConfig(BaseModelConfig):
# Attention type # Attention type
attention_type: AttentionType = AttentionType.DEFAULT attention_type: AttentionType = AttentionType.DEFAULT
# LTX-2.3: prompt-conditioned adaptive layer norm
# Controls: gate_logits in attention, 9-param scale_shift_table,
# prompt_adaln_single, per-block prompt_scale_shift_table,
# removal of caption_projection
has_prompt_adaln: bool = False
# VAE config # VAE config
vae_config: Optional[VideoVAEConfig] = None vae_config: Optional[VideoVAEConfig] = None

View File

@@ -26,7 +26,7 @@ class TransformerArgsPreprocessor:
self, self,
patchify_proj: nn.Linear, patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle, adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection, caption_projection: Optional[PixArtAlphaTextProjection],
inner_dim: int, inner_dim: int,
max_pos: List[int], max_pos: List[int],
num_attention_heads: int, num_attention_heads: int,
@@ -35,10 +35,12 @@ class TransformerArgsPreprocessor:
positional_embedding_theta: float, positional_embedding_theta: float,
rope_type: LTXRopeType, rope_type: LTXRopeType,
double_precision_rope: bool = False, double_precision_rope: bool = False,
prompt_adaln: Optional[AdaLayerNormSingle] = None,
): ):
self.patchify_proj = patchify_proj self.patchify_proj = patchify_proj
self.adaln = adaln self.adaln = adaln
self.caption_projection = caption_projection self.caption_projection = caption_projection
self.prompt_adaln = prompt_adaln
self.inner_dim = inner_dim self.inner_dim = inner_dim
self.max_pos = max_pos self.max_pos = max_pos
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
@@ -64,6 +66,19 @@ class TransformerArgsPreprocessor:
return timestep_emb, embedded_timestep return timestep_emb, embedded_timestep
def _prepare_timestep_with_adaln(
    self,
    adaln: AdaLayerNormSingle,
    timestep: mx.array,
    batch_size: int,
    hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]:
    """Run a timestep tensor through an explicitly supplied AdaLN module.

    The timestep is scaled by ``self.timestep_scale_multiplier``, flattened
    to one dimension and handed to ``adaln``. Both tensors returned by
    ``adaln`` are then reshaped to ``(batch_size, -1, last_dim)``.

    Args:
        adaln: Module called as ``adaln(flat_timestep, hidden_dtype=...)``;
            it must return a pair of arrays.
        timestep: Timestep values; any shape, flattened before the call.
        batch_size: Leading dimension of both returned tensors.
        hidden_dtype: Forwarded unchanged to ``adaln``.

    Returns:
        ``(timestep_emb, embedded_timestep)``, each reshaped to
        ``(batch_size, -1, last_dim)``.
    """
    flat_scaled = (timestep * self.timestep_scale_multiplier).reshape(-1)
    modulation, embedded = adaln(flat_scaled, hidden_dtype=hidden_dtype)

    # Restore the batch axis while keeping each tensor's own feature width.
    modulation = mx.reshape(modulation, (batch_size, -1, modulation.shape[-1]))
    embedded = mx.reshape(embedded, (batch_size, -1, embedded.shape[-1]))
    return modulation, embedded
def _prepare_context( def _prepare_context(
self, self,
context: mx.array, context: mx.array,
@@ -72,8 +87,7 @@ class TransformerArgsPreprocessor:
) -> Tuple[mx.array, Optional[mx.array]]: ) -> Tuple[mx.array, Optional[mx.array]]:
batch_size = x.shape[0] batch_size = x.shape[0]
# Context is already processed through embeddings connector in text encoder if self.caption_projection is not None:
# Here we just apply the caption projection
context = self.caption_projection(context) context = self.caption_projection(context)
context = mx.reshape(context, (batch_size, -1, x.shape[-1])) context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
return context, attention_mask return context, attention_mask
@@ -134,6 +148,14 @@ class TransformerArgsPreprocessor:
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
) )
# Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps
prompt_timestep = None
prompt_embedded_timestep = None
if self.prompt_adaln is not None and modality.sigma is not None:
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
)
return TransformerArgs( return TransformerArgs(
x=x, x=x,
context=context, context=context,
@@ -145,6 +167,8 @@ class TransformerArgsPreprocessor:
cross_scale_shift_timestep=None, cross_scale_shift_timestep=None,
cross_gate_timestep=None, cross_gate_timestep=None,
enabled=modality.enabled, enabled=modality.enabled,
prompt_timesteps=prompt_timestep,
prompt_embedded_timestep=prompt_embedded_timestep,
) )
@@ -154,7 +178,7 @@ class MultiModalTransformerArgsPreprocessor:
self, self,
patchify_proj: nn.Linear, patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle, adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection, caption_projection: Optional[PixArtAlphaTextProjection],
cross_scale_shift_adaln: AdaLayerNormSingle, cross_scale_shift_adaln: AdaLayerNormSingle,
cross_gate_adaln: AdaLayerNormSingle, cross_gate_adaln: AdaLayerNormSingle,
inner_dim: int, inner_dim: int,
@@ -168,6 +192,7 @@ class MultiModalTransformerArgsPreprocessor:
rope_type: LTXRopeType, rope_type: LTXRopeType,
av_ca_timestep_scale_multiplier: int, av_ca_timestep_scale_multiplier: int,
double_precision_rope: bool = False, double_precision_rope: bool = False,
prompt_adaln: Optional[AdaLayerNormSingle] = None,
): ):
self.simple_preprocessor = TransformerArgsPreprocessor( self.simple_preprocessor = TransformerArgsPreprocessor(
patchify_proj=patchify_proj, patchify_proj=patchify_proj,
@@ -181,6 +206,7 @@ class MultiModalTransformerArgsPreprocessor:
positional_embedding_theta=positional_embedding_theta, positional_embedding_theta=positional_embedding_theta,
rope_type=rope_type, rope_type=rope_type,
double_precision_rope=double_precision_rope, double_precision_rope=double_precision_rope,
prompt_adaln=prompt_adaln,
) )
self.cross_scale_shift_adaln = cross_scale_shift_adaln self.cross_scale_shift_adaln = cross_scale_shift_adaln
self.cross_gate_adaln = cross_gate_adaln self.cross_gate_adaln = cross_gate_adaln
@@ -280,7 +306,13 @@ class LTXModel(nn.Module):
def _init_video(self, config: LTXModelConfig) -> None: def _init_video(self, config: LTXModelConfig) -> None:
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True) self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
if config.has_prompt_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
else:
self.caption_projection = PixArtAlphaTextProjection( self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels, in_features=config.caption_channels,
hidden_size=self.inner_dim, hidden_size=self.inner_dim,
@@ -292,9 +324,13 @@ class LTXModel(nn.Module):
def _init_audio(self, config: LTXModelConfig) -> None: def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
if config.has_prompt_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
else:
self.audio_caption_projection = PixArtAlphaTextProjection( self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels, in_features=config.audio_caption_channels,
hidden_size=self.audio_inner_dim, hidden_size=self.audio_inner_dim,
@@ -331,7 +367,7 @@ class LTXModel(nn.Module):
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj, patchify_proj=self.patchify_proj,
adaln=self.adaln_single, adaln=self.adaln_single,
caption_projection=self.caption_projection, caption_projection=getattr(self, "caption_projection", None),
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
inner_dim=self.inner_dim, inner_dim=self.inner_dim,
@@ -345,11 +381,12 @@ class LTXModel(nn.Module):
rope_type=config.rope_type, rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier, av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope, double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "prompt_adaln_single", None),
) )
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj, patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single, adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection, caption_projection=getattr(self, "audio_caption_projection", None),
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
inner_dim=self.audio_inner_dim, inner_dim=self.audio_inner_dim,
@@ -363,12 +400,13 @@ class LTXModel(nn.Module):
rope_type=config.rope_type, rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier, av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope, double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
) )
elif config.model_type.is_video_enabled(): elif config.model_type.is_video_enabled():
self.video_args_preprocessor = TransformerArgsPreprocessor( self.video_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.patchify_proj, patchify_proj=self.patchify_proj,
adaln=self.adaln_single, adaln=self.adaln_single,
caption_projection=self.caption_projection, caption_projection=getattr(self, "caption_projection", None),
inner_dim=self.inner_dim, inner_dim=self.inner_dim,
max_pos=config.positional_embedding_max_pos, max_pos=config.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
@@ -377,12 +415,13 @@ class LTXModel(nn.Module):
positional_embedding_theta=config.positional_embedding_theta, positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type, rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope, double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "prompt_adaln_single", None),
) )
elif config.model_type.is_audio_enabled(): elif config.model_type.is_audio_enabled():
self.audio_args_preprocessor = TransformerArgsPreprocessor( self.audio_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj, patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single, adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection, caption_projection=getattr(self, "audio_caption_projection", None),
inner_dim=self.audio_inner_dim, inner_dim=self.audio_inner_dim,
max_pos=config.audio_positional_embedding_max_pos, max_pos=config.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads, num_attention_heads=self.audio_num_attention_heads,
@@ -391,13 +430,13 @@ class LTXModel(nn.Module):
positional_embedding_theta=config.positional_embedding_theta, positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type, rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope, double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
) )
def _init_transformer_blocks(self, config: LTXModelConfig) -> None: def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
video_config = config.get_video_config() video_config = config.get_video_config()
audio_config = config.get_audio_config() audio_config = config.get_audio_config()
self.transformer_blocks = { self.transformer_blocks = {
idx: BasicAVTransformerBlock( idx: BasicAVTransformerBlock(
idx=idx, idx=idx,
@@ -405,6 +444,7 @@ class LTXModel(nn.Module):
audio=audio_config, audio=audio_config,
rope_type=config.rope_type, rope_type=config.rope_type,
norm_eps=config.norm_eps, norm_eps=config.norm_eps,
has_prompt_adaln=config.has_prompt_adaln,
) )
for idx in range(config.num_layers) for idx in range(config.num_layers)
} }

View File

@@ -208,6 +208,7 @@ class ConnectorAttention(nn.Module):
dim: int = 3840, dim: int = 3840,
num_heads: int = 30, num_heads: int = 30,
head_dim: int = 128, head_dim: int = 128,
has_gate_logits: bool = False,
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
@@ -218,13 +219,14 @@ class ConnectorAttention(nn.Module):
self.to_q = nn.Linear(dim, inner_dim, bias=True) self.to_q = nn.Linear(dim, inner_dim, bias=True)
self.to_k = nn.Linear(dim, inner_dim, bias=True) self.to_k = nn.Linear(dim, inner_dim, bias=True)
self.to_v = nn.Linear(dim, inner_dim, bias=True) self.to_v = nn.Linear(dim, inner_dim, bias=True)
# Direct attribute for MLX parameter tracking (not a list)
self.to_out = nn.Linear(inner_dim, dim, bias=True) self.to_out = nn.Linear(inner_dim, dim, bias=True)
# Standard RMSNorm (not Gemma-style) on full inner_dim
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6) self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6) self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6)
if has_gate_logits:
self.to_gate_logits = nn.Linear(dim, num_heads, bias=True)
def __call__( def __call__(
self, self,
x: mx.array, x: mx.array,
@@ -233,12 +235,17 @@ class ConnectorAttention(nn.Module):
) -> mx.array: ) -> mx.array:
batch_size, seq_len, _ = x.shape batch_size, seq_len, _ = x.shape
# Compute per-head gate early (from original input)
gate = None
if hasattr(self, "to_gate_logits"):
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
# Project to Q, K, V # Project to Q, K, V
q = self.to_q(x) # (B, seq, inner_dim) q = self.to_q(x)
k = self.to_k(x) k = self.to_k(x)
v = self.to_v(x) v = self.to_v(x)
# QK normalization on full inner_dim BEFORE reshape (matches PyTorch) # QK normalization on full inner_dim BEFORE reshape
q = self.q_norm(q) q = self.q_norm(q)
k = self.k_norm(k) k = self.k_norm(k)
@@ -248,15 +255,18 @@ class ConnectorAttention(nn.Module):
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
if pe is not None: if pe is not None:
# pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2)
# Apply SPLIT RoPE: operates on first half of head dimensions
q = self._apply_split_rope(q, pe[0], pe[1]) q = self._apply_split_rope(q, pe[0], pe[1])
k = self._apply_split_rope(k, pe[0], pe[1]) k = self._apply_split_rope(k, pe[0], pe[1])
# No mask needed for connector - after register replacement, all positions are valid
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None) out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
# Apply per-head gating
if gate is not None:
out = mx.reshape(out, (batch_size, seq_len, self.num_heads, self.head_dim))
out = out * gate[..., None]
out = mx.reshape(out, (batch_size, seq_len, -1))
return self.to_out(out) return self.to_out(out)
def _apply_split_rope( def _apply_split_rope(
@@ -326,9 +336,9 @@ class ConnectorFeedForward(nn.Module):
class ConnectorTransformerBlock(nn.Module): class ConnectorTransformerBlock(nn.Module):
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128): def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False):
super().__init__() super().__init__()
self.attn1 = ConnectorAttention(dim, num_heads, head_dim) self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
self.ff = ConnectorFeedForward(dim) self.ff = ConnectorFeedForward(dim)
def __call__( def __call__(
@@ -367,6 +377,7 @@ class Embeddings1DConnector(nn.Module):
num_learnable_registers: int = 128, num_learnable_registers: int = 128,
positional_embedding_theta: float = 10000.0, positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list = None, positional_embedding_max_pos: list = None,
has_gate_logits: bool = False,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@@ -376,9 +387,8 @@ class Embeddings1DConnector(nn.Module):
self.positional_embedding_theta = positional_embedding_theta self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos or [4096] self.positional_embedding_max_pos = positional_embedding_max_pos or [4096]
# Use dict with int keys for MLX to track parameters (lists are not tracked)
self.transformer_1d_blocks = { self.transformer_1d_blocks = {
i: ConnectorTransformerBlock(dim, num_heads, head_dim) i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
for i in range(num_layers) for i in range(num_layers)
} }
@@ -572,16 +582,100 @@ def norm_and_concat_hidden_states(
return normed return normed
def norm_and_concat_per_token_rms(
    encoded_text: mx.array,
    attention_mask: mx.array,
) -> mx.array:
    """Apply per-token RMS normalization and flatten the layer axis (LTX-2.3).

    Each token of each layer is divided by the root-mean-square of its
    hidden vector (epsilon ``1e-6``). The statistic is computed in float32
    and the result is cast back to the input dtype.

    Args:
        encoded_text: Stacked hidden states of shape ``(B, T, D, L)``.
        attention_mask: Mask of shape ``(B, T)``; non-zero marks a valid token.

    Returns:
        Array of shape ``(B, T, D * L)`` in the input dtype, with every
        masked-out token position set to zero.
    """
    batch, tokens, hidden, layer_count = encoded_text.shape
    out_dtype = encoded_text.dtype

    # Mean of squares over the hidden axis (axis 2) -> (B, T, 1, L).
    as_f32 = encoded_text.astype(mx.float32)
    mean_sq = mx.mean(as_f32 ** 2, axis=2, keepdims=True)
    scaled = (as_f32 * mx.rsqrt(mean_sq + 1e-6)).astype(out_dtype)

    # Collapse the (D, L) axes into one feature axis.
    flat = mx.reshape(scaled, (batch, tokens, hidden * layer_count))

    # Broadcast the (B, T) mask over the feature axis and blank the padding.
    keep = attention_mask[:, :, None].astype(mx.bool_)
    return mx.where(keep, flat, mx.zeros_like(flat))
def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
    """Scale ``x`` by ``sqrt(target_dim / source_dim)``.

    Args:
        x: Value to rescale.
        target_dim: Numerator of the ratio under the square root.
        source_dim: Denominator of the ratio under the square root.

    Returns:
        ``x`` multiplied by the scalar factor.
    """
    factor = math.sqrt(target_dim / source_dim)
    return x * factor
class GemmaFeaturesExtractor(nn.Module):
    """V1 feature extractor (LTX-2): a single linear projection.

    The module itself only applies ``aggregate_embed``. The V1 input
    normalization (``8 * (x - mean) / range``) is not performed here and
    must already have been applied to the features passed in.
    """

    def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
        """Create the projection from ``input_dim`` to ``output_dim``.

        Args:
            input_dim: Width of the flattened input features.
            output_dim: Width of the projected output.
            bias: Whether the linear layer carries a bias term.
        """
        super().__init__()
        projection = nn.Linear(input_dim, output_dim, bias=bias)
        self.aggregate_embed = projection

    def __call__(self, x: mx.array) -> mx.array:
        """Project ``x`` through ``aggregate_embed`` and return the result."""
        projected = self.aggregate_embed(x)
        return projected
class GemmaFeaturesExtractorV2(nn.Module):
    """V2 feature extractor (LTX-2.3): per-token RMSNorm, rescale, project.

    Holds two independent linear projections, one for video and one for
    audio, that consume the same flattened per-layer features.
    """

    def __init__(
        self,
        flat_dim: int,
        embedding_dim: int,
        video_output_dim: int,
        audio_output_dim: int,
        bias: bool = True,
    ):
        """Build the video and audio projections.

        Args:
            flat_dim: Width of the flattened (hidden * layers) features.
            embedding_dim: Source dimension used in the rescale factor
                ``sqrt(target_dim / embedding_dim)``.
            video_output_dim: Output width of the video projection.
            audio_output_dim: Output width of the audio projection.
            bias: Whether both projections carry a bias term.
        """
        super().__init__()
        # Kept only for the rescale step in __call__.
        self.embedding_dim = embedding_dim
        self.video_aggregate_embed = nn.Linear(flat_dim, video_output_dim, bias=bias)
        self.audio_aggregate_embed = nn.Linear(flat_dim, audio_output_dim, bias=bias)

    def __call__(
        self,
        hidden_states: List[mx.array],
        attention_mask: mx.array,
        mode: str = "video",
    ) -> mx.array:
        """Normalize, rescale and project the stacked hidden states.

        Args:
            hidden_states: Per-layer hidden states; stacked on a new last axis.
            attention_mask: Binary mask of shape ``(B, T)``.
            mode: ``"video"`` selects the video projection; any other value
                selects the audio projection.

        Returns:
            The projected features from the selected linear layer.
        """
        # (B, T, D, L): one slice per layer on the trailing axis.
        stacked = mx.stack(hidden_states, axis=-1)

        features = norm_and_concat_per_token_rms(stacked, attention_mask)
        features = features.astype(stacked.dtype)

        if mode == "video":
            projection = self.video_aggregate_embed
        else:
            projection = self.audio_aggregate_embed

        # The rescale target is the output width of the chosen projection.
        out_dim = projection.weight.shape[0]
        return projection(_rescale_norm(features, out_dim, self.embedding_dim))
@@ -603,38 +697,53 @@ class LTX2TextEncoder(nn.Module):
hidden_dim: int = 3840, hidden_dim: int = 3840,
audio_dim: int = 2048, audio_dim: int = 2048,
num_layers: int = 49, # 48 transformer layers + 1 embedding num_layers: int = 49, # 48 transformer layers + 1 embedding
has_prompt_adaln: bool = False,
): ):
super().__init__() super().__init__()
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.audio_dim = audio_dim self.audio_dim = audio_dim
self.num_layers = num_layers self.num_layers = num_layers
self.has_prompt_adaln = has_prompt_adaln
self.language_model = None self.language_model = None
# Feature extractor: 3840*49 -> 3840 feature_input_dim = hidden_dim * num_layers
self.feature_extractor = GemmaFeaturesExtractor(
input_dim=hidden_dim * num_layers, if has_prompt_adaln:
output_dim=hidden_dim, # LTX-2.3: V2 feature extractor with per-token RMSNorm + rescale
video_output_dim = 4096
audio_output_dim = 2048
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
video_output_dim=video_output_dim,
audio_output_dim=audio_output_dim,
bias=True,
) )
# Video embeddings connector: 2-layer transformer # Deeper connectors with matching dims and gate_logits
self.video_embeddings_connector = Embeddings1DConnector( self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, dim=video_output_dim, num_heads=32, head_dim=128,
num_heads=30, num_layers=8, num_learnable_registers=128,
head_dim=128, positional_embedding_max_pos=[4096], has_gate_logits=True,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[4096], # Match PyTorch
) )
# Audio embeddings connector: separate 2-layer transformer (same architecture as video)
# Both connectors process the feature extractor output independently
self.audio_embeddings_connector = Embeddings1DConnector( self.audio_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, dim=audio_output_dim, num_heads=32, head_dim=64,
num_heads=30, num_layers=8, num_learnable_registers=128,
head_dim=128, positional_embedding_max_pos=[4096], has_gate_logits=True,
num_layers=2, )
num_learnable_registers=128, else:
positional_embedding_max_pos=[4096], # Match PyTorch # LTX-2: shared feature extractor, 3840-dim connectors
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
positional_embedding_max_pos=[4096],
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
positional_embedding_max_pos=[4096],
) )
self.processor = None self.processor = None
@@ -669,81 +778,9 @@ class LTX2TextEncoder(nn.Module):
transformer_weights = mx.load(str(transformer_files[0])) transformer_weights = mx.load(str(transformer_files[0]))
if transformer_weights: if transformer_weights:
# Load feature extractor (aggregate_embed) self._load_feature_extractors(transformer_weights, is_reformatted)
# Reformatted key: "aggregate_embed.weight" self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
# Monolithic key: "text_embedding_projection.aggregate_embed.weight" self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
if agg_key in transformer_weights:
self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key]
# Load video_embeddings_connector weights
connector_weights = {}
if is_reformatted:
# Reformatted: keys are already sanitized with "video_embeddings_connector." prefix
for key, value in transformer_weights.items():
if key.startswith("video_embeddings_connector."):
new_key = key.replace("video_embeddings_connector.", "")
connector_weights[new_key] = value
else:
# Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
connector_weights[new_key] = value
if connector_weights:
# Map weight names to our structure (only needed for monolithic/raw PyTorch keys)
mapped_weights = {}
for key, value in connector_weights.items():
new_key = key
if not is_reformatted:
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
# Map ff.net.2 -> ff.proj_out (output Linear)
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
# Map to_out.0 -> to_out (Sequential -> direct)
new_key = new_key.replace(".to_out.0.", ".to_out.")
mapped_weights[new_key] = value
self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False
)
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
# Load audio_embeddings_connector weights (same structure as video connector)
audio_connector_weights = {}
if is_reformatted:
for key, value in transformer_weights.items():
if key.startswith("audio_embeddings_connector."):
new_key = key.replace("audio_embeddings_connector.", "")
audio_connector_weights[new_key] = value
else:
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
audio_connector_weights[new_key] = value
if audio_connector_weights:
# Map weight names to our structure (same as video connector)
mapped_audio_weights = {}
for key, value in audio_connector_weights.items():
new_key = key
if not is_reformatted:
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".to_out.0.", ".to_out.")
mapped_audio_weights[new_key] = value
self.audio_embeddings_connector.load_weights(
list(mapped_audio_weights.items()), strict=False
)
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in audio_connector_weights:
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
else: else:
print("WARNING: No transformer weights found for text projection connectors. " print("WARNING: No transformer weights found for text projection connectors. "
"Text conditioning will use uninitialized weights!") "Text conditioning will use uninitialized weights!")
@@ -763,6 +800,63 @@ class LTX2TextEncoder(nn.Module):
print("Text encoder loaded successfully") print("Text encoder loaded successfully")
def _load_feature_extractors(self, weights: dict, is_reformatted: bool):
    """Assign aggregate-embed tensors from ``weights`` to the active extractor.

    Which extractor is filled depends on ``self.has_prompt_adaln``:

    * set (LTX-2.3): the ``video_aggregate_embed`` and
      ``audio_aggregate_embed`` submodules of ``self.feature_extractor_v2``
      receive ``<name>.weight`` and, when also present, ``<name>.bias``.
    * unset (LTX-2): ``self.feature_extractor.aggregate_embed.weight`` is
      assigned from a single key whose spelling depends on
      ``is_reformatted``.

    Keys missing from ``weights`` are skipped without error.

    Args:
        weights: Mapping from checkpoint key to tensor.
        is_reformatted: Selects the short LTX-2 key (``aggregate_embed.weight``)
            over the ``text_embedding_projection.``-prefixed one. It is not
            consulted on the LTX-2.3 path.
    """
    if not self.has_prompt_adaln:
        # LTX-2: one shared aggregate_embed projection.
        if is_reformatted:
            single_key = "aggregate_embed.weight"
        else:
            single_key = "text_embedding_projection.aggregate_embed.weight"
        if single_key in weights:
            self.feature_extractor.aggregate_embed.weight = weights[single_key]
        return

    # LTX-2.3: the V2 extractor keeps separate video and audio projections;
    # attribute name and key prefix are the same string for both.
    for embed_name in ("video_aggregate_embed", "audio_aggregate_embed"):
        weight_key = f"{embed_name}.weight"
        if weight_key not in weights:
            # Without a weight the bias is ignored as well.
            continue
        target = getattr(self.feature_extractor_v2, embed_name)
        target.weight = weights[weight_key]
        bias_key = f"{embed_name}.bias"
        if bias_key in weights:
            target.bias = weights[bias_key]
def _load_connector(self, name: str, weights: dict, is_reformatted: bool):
    """Load the weights belonging to one embeddings connector.

    The connector is looked up as ``getattr(self, name)``. Entries of
    ``weights`` whose key starts with the connector's prefix are selected,
    the prefix is stripped, and the result is passed to the connector's
    ``load_weights(..., strict=False)``. If nothing matches, the method
    returns without touching the connector.

    Args:
        name: Attribute name of the connector on ``self``; also the key
            prefix (optionally preceded by ``model.diffusion_model.``).
        weights: Mapping from checkpoint key to tensor.
        is_reformatted: When true, keys are expected as ``<name>.<rest>`` and
            used verbatim. Otherwise keys are expected as
            ``model.diffusion_model.<name>.<rest>`` and the feed-forward /
            output-projection path segments are renamed.
    """
    connector = getattr(self, name)

    key_prefix = f"{name}." if is_reformatted else f"model.diffusion_model.{name}."
    own_weights = {
        key[len(key_prefix):]: tensor
        for key, tensor in weights.items()
        if key.startswith(key_prefix)
    }
    if not own_weights:
        return

    # Path-segment renames, applied in order; only needed for monolithic keys.
    renames = ()
    if not is_reformatted:
        renames = (
            (".ff.net.0.proj.", ".ff.proj_in."),
            (".ff.net.2.", ".ff.proj_out."),
            (".to_out.0.", ".to_out."),
        )

    remapped = {}
    for key, tensor in own_weights.items():
        for old, new in renames:
            key = key.replace(old, new)
        remapped[key] = tensor

    connector.load_weights(list(remapped.items()), strict=False)

    # learnable_registers is assigned explicitly after the non-strict load
    # (per the original note: it is a plain mx.array, not a tracked parameter).
    if "learnable_registers" in own_weights:
        connector.learnable_registers = own_weights["learnable_registers"]
def encode( def encode(
self, self,
prompt: str, prompt: str,
@@ -795,18 +889,37 @@ class LTX2TextEncoder(nn.Module):
attention_mask = mx.array(inputs["attention_mask"]) attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
if self.has_prompt_adaln:
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
if return_audio_embeddings:
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
return video_embeddings, audio_embeddings
else:
return video_embeddings, attention_mask
else:
# LTX-2: V1 feature extraction (8 * (x - mean) / range)
concat_hidden = norm_and_concat_hidden_states( concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left" all_hidden_states, attention_mask, padding_side="left"
) )
features = self.feature_extractor(concat_hidden)
additive_mask = (attention_mask - 1).astype(features.dtype) video_features = self.feature_extractor(concat_hidden)
additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
video_embeddings, _ = self.video_embeddings_connector(features, additive_mask) video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
if return_audio_embeddings: if return_audio_embeddings:
# Process features through audio connector independently (same input as video) audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask)
return video_embeddings, audio_embeddings return video_embeddings, audio_embeddings
else: else:
return video_embeddings, attention_mask return video_embeddings, attention_mask

View File

@@ -20,6 +20,8 @@ class Modality:
context_mask: Optional[mx.array] = None context_mask: Optional[mx.array] = None
# Optional precomputed positional embeddings (RoPE) to avoid recomputation # Optional precomputed positional embeddings (RoPE) to avoid recomputation
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
# Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3)
sigma: Optional[mx.array] = None
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -34,6 +36,9 @@ class TransformerArgs:
cross_scale_shift_timestep: Optional[mx.array] cross_scale_shift_timestep: Optional[mx.array]
cross_gate_timestep: Optional[mx.array] cross_gate_timestep: Optional[mx.array]
enabled: bool enabled: bool
# LTX-2.3: prompt-conditioned timestep embeddings for cross-attention
prompt_timesteps: Optional[mx.array] = None
prompt_embedded_timestep: Optional[mx.array] = None
class BasicAVTransformerBlock(nn.Module): class BasicAVTransformerBlock(nn.Module):
@@ -50,20 +55,13 @@ class BasicAVTransformerBlock(nn.Module):
audio: Optional[TransformerConfig] = None, audio: Optional[TransformerConfig] = None,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
has_prompt_adaln: bool = False,
): ):
"""Initialize transformer block.
Args:
idx: Block index
video: Video modality configuration
audio: Audio modality configuration
rope_type: Type of rotary position embedding
norm_eps: Epsilon for normalization
"""
super().__init__() super().__init__()
self.idx = idx self.idx = idx
self.norm_eps = norm_eps self.norm_eps = norm_eps
self.has_prompt_adaln = has_prompt_adaln
# Video components # Video components
if video is not None: if video is not None:
@@ -74,6 +72,7 @@ class BasicAVTransformerBlock(nn.Module):
context_dim=None, # Self-attention context_dim=None, # Self-attention
rope_type=rope_type, rope_type=rope_type,
norm_eps=norm_eps, norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
) )
self.attn2 = Attention( self.attn2 = Attention(
query_dim=video.dim, query_dim=video.dim,
@@ -82,10 +81,15 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=video.d_head, dim_head=video.d_head,
rope_type=rope_type, rope_type=rope_type,
norm_eps=norm_eps, norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
) )
self.ff = FeedForward(video.dim, dim_out=video.dim) self.ff = FeedForward(video.dim, dim_out=video.dim)
# 6 scale-shift parameters: 3 for attention, 3 for MLP # 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2
self.scale_shift_table = mx.zeros((6, video.dim)) num_ada_params = 9 if has_prompt_adaln else 6
self.scale_shift_table = mx.zeros((num_ada_params, video.dim))
if has_prompt_adaln:
self.prompt_scale_shift_table = mx.zeros((2, video.dim))
# Audio components # Audio components
if audio is not None: if audio is not None:
@@ -96,6 +100,7 @@ class BasicAVTransformerBlock(nn.Module):
context_dim=None, context_dim=None,
rope_type=rope_type, rope_type=rope_type,
norm_eps=norm_eps, norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
) )
self.audio_attn2 = Attention( self.audio_attn2 = Attention(
query_dim=audio.dim, query_dim=audio.dim,
@@ -104,9 +109,14 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head, dim_head=audio.d_head,
rope_type=rope_type, rope_type=rope_type,
norm_eps=norm_eps, norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
) )
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
self.audio_scale_shift_table = mx.zeros((6, audio.dim)) num_audio_ada_params = 9 if has_prompt_adaln else 6
self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim))
if has_prompt_adaln:
self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim))
# Cross-modal attention (when both video and audio are enabled) # Cross-modal attention (when both video and audio are enabled)
if audio is not None and video is not None: if audio is not None and video is not None:
@@ -118,6 +128,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head, dim_head=audio.d_head,
rope_type=rope_type, rope_type=rope_type,
norm_eps=norm_eps, norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
) )
# Video-to-Audio: Q from audio, K/V from video # Video-to-Audio: Q from audio, K/V from video
self.video_to_audio_attn = Attention( self.video_to_audio_attn = Attention(
@@ -127,6 +138,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head, dim_head=audio.d_head,
rope_type=rope_type, rope_type=rope_type,
norm_eps=norm_eps, norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
) )
# Scale-shift tables for cross-attention # Scale-shift tables for cross-attention
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim)) self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
@@ -254,6 +266,18 @@ class BasicAVTransformerBlock(nn.Module):
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
# Cross-attention with text context # Cross-attention with text context
if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
vshift_q, vscale_q, vgate_q = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
)
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
)
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
else:
vx = vx + self.attn2( vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps), rms_norm(vx, eps=self.norm_eps),
context=video.context, context=video.context,
@@ -271,6 +295,18 @@ class BasicAVTransformerBlock(nn.Module):
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
# Cross-attention with text context # Cross-attention with text context
if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
ashift_q, ascale_q, agate_q = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
)
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
)
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
else:
ax = ax + self.audio_attn2( ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps), rms_norm(ax, eps=self.norm_eps),
context=audio.context, context=audio.context,
@@ -341,7 +377,7 @@ class BasicAVTransformerBlock(nn.Module):
# Process video feed-forward # Process video feed-forward
if run_vx: if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None) self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
) )
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
vx = vx + self.ff(vx_scaled) * vgate_mlp vx = vx + self.ff(vx_scaled) * vgate_mlp
@@ -349,7 +385,7 @@ class BasicAVTransformerBlock(nn.Module):
# Process audio feed-forward # Process audio feed-forward
if run_ax: if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None) self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
) )
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
ax = ax + self.audio_ff(ax_scaled) * agate_mlp ax = ax + self.audio_ff(ax_scaled) * agate_mlp

View File

@@ -301,7 +301,6 @@ def upsample_latents(
latent_std: mx.array, latent_std: mx.array,
debug: bool = False, debug: bool = False,
) -> mx.array: ) -> mx.array:
# Un-normalize: latent * std + mean # Un-normalize: latent * std + mean
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1) latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
latent_std = latent_std.reshape(1, -1, 1, 1, 1) latent_std = latent_std.reshape(1, -1, 1, 1, 1)
@@ -350,19 +349,18 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
for key, value in raw_weights.items(): for key, value in raw_weights.items():
new_key = key new_key = key
# LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
if key.startswith("upsampler.0."):
new_key = key.replace("upsampler.0.", "upsampler.conv.")
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I) # Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 5: if "weight" in new_key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1)) value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I) # Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 4: if "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1)) value = mx.transpose(value, (0, 2, 3, 1))
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
if key.startswith("upsampler."):
new_key = key # Keep as is for SpatialRationalResampler
sanitized[new_key] = value sanitized[new_key] = value
# Load weights # Load weights

View File

@@ -250,6 +250,18 @@ class LTX2VideoDecoder(nn.Module):
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4) - conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
""" """
# Default decoder layout, used when no explicit ``decoder_blocks`` list is given.
# One tuple per up-block, in construction order:
#   ("res", channels, num_layers)            -> ResBlockGroup
#   ("d2s", in_channels, reduction, stride)  -> DepthToSpaceUpsample
# ``stride`` is a (D, H, W) tuple; ``reduction`` is forwarded as
# ``out_channels_reduction_factor``.
DEFAULT_BLOCKS = [
    ("res", 1024, 5),
    ("d2s", 1024, 2, (2, 2, 2)),
    ("res", 512, 5),
    ("d2s", 512, 2, (2, 2, 2)),
    ("res", 256, 5),
    ("d2s", 256, 2, (2, 2, 2)),
    ("res", 128, 5),
]
def __init__( def __init__(
self, self,
in_channels: int = 128, in_channels: int = 128,
@@ -258,6 +270,7 @@ class LTX2VideoDecoder(nn.Module):
num_layers_per_block: int = 5, num_layers_per_block: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = True, timestep_conditioning: bool = True,
decoder_blocks: list = None,
): ):
super().__init__() super().__init__()
@@ -272,13 +285,17 @@ class LTX2VideoDecoder(nn.Module):
# Per-channel statistics for denormalization (loaded from weights) # Per-channel statistics for denormalization (loaded from weights)
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Initial conv: 128 -> 1024 blocks = decoder_blocks or self.DEFAULT_BLOCKS
first_ch = blocks[0][1]
last_ch = blocks[-1][1]
# Initial conv: in_channels -> first block channels
class ConvInWrapper(nn.Module): class ConvInWrapper(nn.Module):
def __init__(self_inner): def __init__(self_inner):
super().__init__() super().__init__()
self_inner.conv = CausalConv3d( self_inner.conv = CausalConv3d(
in_channels=in_channels, in_channels=in_channels,
out_channels=1024, out_channels=first_ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
@@ -288,45 +305,32 @@ class LTX2VideoDecoder(nn.Module):
return self_inner.conv(x, causal=causal) return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper() self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample # Build up blocks from config
# Use dict with int keys for MLX to track parameters properly self.up_blocks = {}
self.up_blocks = { for idx, block_def in enumerate(blocks):
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning), block_type = block_def[0]
1: DepthToSpaceUpsample( ch = block_def[1]
if block_type == "res":
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
self.up_blocks[idx] = DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=1024, in_channels=ch,
stride=(2, 2, 2), stride=stride,
residual=True, residual=True,
out_channels_reduction_factor=2, out_channels_reduction_factor=reduction,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
), )
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
3: DepthToSpaceUpsample(
dims=3,
in_channels=512,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
5: DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
}
final_out_channels = out_channels * patch_size * patch_size final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module): class ConvOutWrapper(nn.Module):
def __init__(self_inner): def __init__(self_inner):
super().__init__() super().__init__()
self_inner.conv = CausalConv3d( self_inner.conv = CausalConv3d(
in_channels=128, in_channels=last_ch,
out_channels=final_out_channels, out_channels=final_out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
@@ -342,9 +346,9 @@ class LTX2VideoDecoder(nn.Module):
if timestep_conditioning: if timestep_conditioning:
self.timestep_scale_multiplier = mx.array(1000.0) self.timestep_scale_multiplier = mx.array(1000.0)
self.last_time_embedder = PixArtAlphaTimestepEmbedder( self.last_time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=128 * 2 # 256, matches (2, 128) table embedding_dim=last_ch * 2
) )
self.last_scale_shift_table = mx.zeros((2, 128)) self.last_scale_shift_table = mx.zeros((2, last_ch))
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping # Build decoder weights dict with key remapping
@@ -418,11 +422,96 @@ class LTX2VideoDecoder(nn.Module):
weights.update(mx.load(str(wf))) weights.update(mx.load(str(wf)))
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False)) # Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
model = cls(
timestep_conditioning=config_dict.get("timestep_conditioning", False),
decoder_blocks=decoder_blocks,
)
weights = model.sanitize(weights) weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict) model.load_weights(list(weights.items()), strict=strict)
return model return model
@staticmethod
def _infer_blocks(weights: dict) -> list:
    """Reconstruct the ``up_blocks`` layout from checkpoint key names and shapes.

    Keys containing ``up_blocks.<i>.`` are grouped by ``<i>``. A group with
    ``res_blocks.<j>`` keys is a residual group; a group with only ``conv.``
    keys is a depth-to-space upsampler. A group whose probe weight cannot be
    found is left out of the result.

    Args:
        weights: Mapping from checkpoint key to tensor; only ``.shape`` and
            ``.ndim`` of the tensors are read.

    Returns:
        A list of ``("res", channels, num_layers)`` and
        ``("d2s", in_channels, reduction, stride)`` tuples in ascending block
        index order, or ``None`` when no block could be identified.
    """
    absent = object()  # sentinel for "no matching key"
    stride_for_multiplier = {8: (2, 2, 2), 4: (1, 2, 2), 2: (2, 1, 1)}

    def index_after(key, token):
        """Return the integer path component that follows ``token``, else None."""
        if token not in key:
            return None
        component = key.split(token)[1].split(".")[0]
        return int(component) if component.isdigit() else None

    def first_weight(fragment):
        """Return the first tensor whose key contains ``fragment`` and "weight"."""
        return next(
            (
                tensor
                for key, tensor in weights.items()
                if fragment in key and "weight" in key
            ),
            absent,
        )

    positions = {index_after(key, "up_blocks.") for key in weights}
    positions.discard(None)
    if not positions:
        return None

    # Pass 1: classify each block index and read its channel counts.
    described = []
    for pos in sorted(positions):
        res_token = f"up_blocks.{pos}.res_blocks."
        conv_token = f"up_blocks.{pos}.conv."

        res_ids = {index_after(key, res_token) for key in weights}
        res_ids.discard(None)

        if res_ids:
            # Residual group: channels come from the first res block's conv1.
            probe = first_weight(f"up_blocks.{pos}.res_blocks.0.conv1")
            if probe is not absent:
                described.append(("res", probe.shape[0], max(res_ids) + 1))
        elif any(conv_token in key for key in weights):
            # Depth-to-space block: keep the raw conv output width for pass 2.
            kernel = first_weight(conv_token)
            if kernel is not absent:
                # 5-D kernels are read with input channels on the last axis,
                # anything else on axis 1.
                # NOTE(review): assumes 5-D kernels are already channels-last
                # at this point — confirm for raw checkpoints.
                if kernel.ndim == 5:
                    in_ch = kernel.shape[-1]
                else:
                    in_ch = kernel.shape[1]
                described.append(("d2s", in_ch, kernel.shape[0]))

    # Pass 2: derive each upsampler's reduction and stride from the channel
    # count of the next residual group.
    resolved = []
    for place, entry in enumerate(described):
        if entry[0] == "res":
            resolved.append(entry)
            continue

        _, in_ch, conv_out_ch = entry
        following_ch = next(
            (later[1] for later in described[place + 1:] if later[0] == "res"),
            None,
        )
        if following_ch is None:
            following_ch = in_ch // 2  # trailing upsampler: assume halving

        if following_ch > 0:
            reduction = in_ch // following_ch
            # conv_out = next_channels * prod(stride)
            multiplier = conv_out_ch // following_ch
        else:
            reduction, multiplier = 2, 8

        # Unrecognised multipliers fall back to a full (2, 2, 2) stride.
        stride = stride_for_multiplier.get(multiplier, (2, 2, 2))
        resolved.append(("d2s", in_ch, reduction, stride))

    return resolved or None
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: