From 207c223354ec24fe22fe398acb618d0c06d967f2 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 10 Mar 2026 16:47:36 +0100 Subject: [PATCH] Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder. --- mlx_video/generate.py | 21 +- mlx_video/models/ltx/attention.py | 27 +- mlx_video/models/ltx/config.py | 6 + mlx_video/models/ltx/ltx.py | 82 +++-- mlx_video/models/ltx/text_encoder.py | 363 ++++++++++++++-------- mlx_video/models/ltx/transformer.py | 102 ++++-- mlx_video/models/ltx/upsampler.py | 18 +- mlx_video/models/ltx/video_vae/decoder.py | 165 +++++++--- 8 files changed, 545 insertions(+), 239 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 4121738..21790a7 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -344,6 +344,7 @@ def denoise_distilled( context=text_embeddings, context_mask=None, enabled=True, + sigma=mx.full((b,), sigma, dtype=dtype), ) audio_modality = None @@ -359,6 +360,7 @@ def denoise_distilled( context=audio_embeddings, context_mask=None, enabled=True, + sigma=mx.full((ab,), sigma, dtype=dtype), ) velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) @@ -493,6 +495,8 @@ def denoise_dev( else: timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) + sigma_array = mx.full((b,), sigma, dtype=dtype) + # Positive conditioning pass video_modality_pos = Modality( latent=latents_flat, @@ -502,6 +506,7 @@ def denoise_dev( context_mask=None, enabled=True, positional_embeddings=precomputed_rope, + sigma=sigma_array, ) velocity_pos, _ = transformer(video=video_modality_pos, audio=None) @@ -523,6 +528,7 @@ def denoise_dev( context_mask=None, enabled=True, positional_embeddings=precomputed_rope, + sigma=sigma_array, ) velocity_neg, _ = transformer(video=video_modality_neg, audio=None) @@ -957,10 +963,18 @@ def generate_video( mx.random.seed(seed) + # Read transformer config to detect model version + import json + transformer_config_path = model_path / "transformer" / "config.json" + has_prompt_adaln = False + if transformer_config_path.exists(): + with open(transformer_config_path) as f: + has_prompt_adaln = json.load(f).get("has_prompt_adaln", False) + # Load text encoder with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"): from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder() + text_encoder = LTX2TextEncoder(has_prompt_adaln=has_prompt_adaln) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) mx.eval(text_encoder.parameters()) console.print("[green]✓[/] Text encoder loaded") @@ -1084,7 +1098,10 @@ def generate_video( # Upsample latents with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): - upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) + upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) + if not upscaler_files: + raise FileNotFoundError(f"No spatial upscaler found in {model_path}") + upsampler = load_upsampler(str(upscaler_files[0])) mx.eval(upsampler.parameters()) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) diff --git a/mlx_video/models/ltx/attention.py b/mlx_video/models/ltx/attention.py index 4024c91..ebc0a24 100644 --- a/mlx_video/models/ltx/attention.py +++ b/mlx_video/models/ltx/attention.py @@ -67,17 +67,8 @@ class Attention(nn.Module): dim_head: int = 64, norm_eps: float = 1e-6, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + has_gate_logits: bool = False, ): - """Initialize attention module. - - Args: - query_dim: Dimension of query input - context_dim: Dimension of context (key/value) input. If None, same as query_dim - heads: Number of attention heads - dim_head: Dimension per head - norm_eps: Epsilon for RMS normalization - rope_type: Type of rotary position embedding - """ super().__init__() self.rope_type = rope_type @@ -99,6 +90,10 @@ class Attention(nn.Module): # Output projection self.to_out = nn.Linear(inner_dim, query_dim, bias=True) + # Per-head gating (LTX-2.3) + if has_gate_logits: + self.to_gate_logits = nn.Linear(query_dim, heads, bias=True) + def __call__( self, x: mx.array, @@ -119,6 +114,11 @@ class Attention(nn.Module): Returns: Attention output of shape (B, seq_len, query_dim) """ + # Compute per-head gate early (from original input) + gate = None + if hasattr(self, "to_gate_logits"): + gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads) + # Compute Q, K, V q = self.to_q(x) context = x if context is None else context @@ -138,5 +138,12 @@ class Attention(nn.Module): # Compute attention out = scaled_dot_product_attention(q, k, v, self.heads, mask) + # Apply per-head gating + if gate is not None: + b, seq_len, _ = out.shape + out = mx.reshape(out, (b, seq_len, self.heads, self.dim_head)) + out = out * gate[..., None] + out = mx.reshape(out, (b, seq_len, -1)) + # Project output return self.to_out(out) diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index c63fcd7..40bb9ef 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -131,6 +131,12 @@ class LTXModelConfig(BaseModelConfig): # Attention type attention_type: AttentionType = AttentionType.DEFAULT + # LTX-2.3: prompt-conditioned adaptive layer norm + # Controls: gate_logits in attention, 9-param scale_shift_table, + # prompt_adaln_single, per-block prompt_scale_shift_table, + # removal of caption_projection + has_prompt_adaln: bool = False + # VAE config vae_config: Optional[VideoVAEConfig] = None diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index 5551b0a..6a63d7b 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -26,7 +26,7 @@ class TransformerArgsPreprocessor: self, patchify_proj: nn.Linear, adaln: AdaLayerNormSingle, - caption_projection: PixArtAlphaTextProjection, + caption_projection: Optional[PixArtAlphaTextProjection], inner_dim: int, max_pos: List[int], num_attention_heads: int, @@ -35,10 +35,12 @@ class TransformerArgsPreprocessor: positional_embedding_theta: float, rope_type: LTXRopeType, double_precision_rope: bool = False, + prompt_adaln: Optional[AdaLayerNormSingle] = None, ): self.patchify_proj = patchify_proj self.adaln = adaln self.caption_projection = caption_projection + self.prompt_adaln = prompt_adaln self.inner_dim = inner_dim self.max_pos = max_pos self.num_attention_heads = num_attention_heads @@ -64,6 +66,19 @@ class TransformerArgsPreprocessor: return timestep_emb, embedded_timestep + def _prepare_timestep_with_adaln( + self, + adaln: AdaLayerNormSingle, + timestep: mx.array, + batch_size: int, + hidden_dtype: mx.Dtype = None, + ) -> Tuple[mx.array, mx.array]: + timestep = timestep * self.timestep_scale_multiplier + timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) + timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) + embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])) + return timestep_emb, embedded_timestep + def _prepare_context( self, context: mx.array, @@ -72,9 +87,8 @@ class TransformerArgsPreprocessor: ) -> Tuple[mx.array, Optional[mx.array]]: batch_size = x.shape[0] - # Context is already processed through embeddings connector in text encoder - # Here we just apply the caption projection - context = self.caption_projection(context) + if self.caption_projection is not None: + context = self.caption_projection(context) context = mx.reshape(context, (batch_size, -1, x.shape[-1])) return context, attention_mask @@ -134,6 +148,14 @@ class TransformerArgsPreprocessor: num_attention_heads=self.num_attention_heads, ) + # Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps + prompt_timestep = None + prompt_embedded_timestep = None + if self.prompt_adaln is not None and modality.sigma is not None: + prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln( + self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype, + ) + return TransformerArgs( x=x, context=context, @@ -145,6 +167,8 @@ class TransformerArgsPreprocessor: cross_scale_shift_timestep=None, cross_gate_timestep=None, enabled=modality.enabled, + prompt_timesteps=prompt_timestep, + prompt_embedded_timestep=prompt_embedded_timestep, ) @@ -154,7 +178,7 @@ class MultiModalTransformerArgsPreprocessor: self, patchify_proj: nn.Linear, adaln: AdaLayerNormSingle, - caption_projection: PixArtAlphaTextProjection, + caption_projection: Optional[PixArtAlphaTextProjection], cross_scale_shift_adaln: AdaLayerNormSingle, cross_gate_adaln: AdaLayerNormSingle, inner_dim: int, @@ -168,6 +192,7 @@ class MultiModalTransformerArgsPreprocessor: rope_type: LTXRopeType, av_ca_timestep_scale_multiplier: int, double_precision_rope: bool = False, + prompt_adaln: Optional[AdaLayerNormSingle] = None, ): self.simple_preprocessor = TransformerArgsPreprocessor( patchify_proj=patchify_proj, @@ -181,6 +206,7 @@ class MultiModalTransformerArgsPreprocessor: positional_embedding_theta=positional_embedding_theta, rope_type=rope_type, double_precision_rope=double_precision_rope, + prompt_adaln=prompt_adaln, ) self.cross_scale_shift_adaln = cross_scale_shift_adaln self.cross_gate_adaln = cross_gate_adaln @@ -280,11 +306,17 @@ class LTXModel(nn.Module): def _init_video(self, config: LTXModelConfig) -> None: self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True) - self.adaln_single = AdaLayerNormSingle(self.inner_dim) - self.caption_projection = PixArtAlphaTextProjection( - in_features=config.caption_channels, - hidden_size=self.inner_dim, - ) + + adaln_coefficient = 9 if config.has_prompt_adaln else 6 + self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient) + + if config.has_prompt_adaln: + self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=config.caption_channels, + hidden_size=self.inner_dim, + ) self.scale_shift_table = mx.zeros((2, self.inner_dim)) self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False) @@ -292,13 +324,17 @@ class LTXModel(nn.Module): def _init_audio(self, config: LTXModelConfig) -> None: self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) - self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim) - # Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=config.audio_caption_channels, - hidden_size=self.audio_inner_dim, - ) + audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6 + self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient) + + if config.has_prompt_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=config.audio_caption_channels, + hidden_size=self.audio_inner_dim, + ) # Output components self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim)) @@ -331,7 +367,7 @@ class LTXModel(nn.Module): self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( patchify_proj=self.patchify_proj, adaln=self.adaln_single, - caption_projection=self.caption_projection, + caption_projection=getattr(self, "caption_projection", None), cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, inner_dim=self.inner_dim, @@ -345,11 +381,12 @@ class LTXModel(nn.Module): rope_type=config.rope_type, av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "prompt_adaln_single", None), ) self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( patchify_proj=self.audio_patchify_proj, adaln=self.audio_adaln_single, - caption_projection=self.audio_caption_projection, + caption_projection=getattr(self, "audio_caption_projection", None), cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, inner_dim=self.audio_inner_dim, @@ -363,12 +400,13 @@ class LTXModel(nn.Module): rope_type=config.rope_type, av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), ) elif config.model_type.is_video_enabled(): self.video_args_preprocessor = TransformerArgsPreprocessor( patchify_proj=self.patchify_proj, adaln=self.adaln_single, - caption_projection=self.caption_projection, + caption_projection=getattr(self, "caption_projection", None), inner_dim=self.inner_dim, max_pos=config.positional_embedding_max_pos, num_attention_heads=self.num_attention_heads, @@ -377,12 +415,13 @@ class LTXModel(nn.Module): positional_embedding_theta=config.positional_embedding_theta, rope_type=config.rope_type, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "prompt_adaln_single", None), ) elif config.model_type.is_audio_enabled(): self.audio_args_preprocessor = TransformerArgsPreprocessor( patchify_proj=self.audio_patchify_proj, adaln=self.audio_adaln_single, - caption_projection=self.audio_caption_projection, + caption_projection=getattr(self, "audio_caption_projection", None), inner_dim=self.audio_inner_dim, max_pos=config.audio_positional_embedding_max_pos, num_attention_heads=self.audio_num_attention_heads, @@ -391,13 +430,13 @@ class LTXModel(nn.Module): positional_embedding_theta=config.positional_embedding_theta, rope_type=config.rope_type, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), ) def _init_transformer_blocks(self, config: LTXModelConfig) -> None: video_config = config.get_video_config() audio_config = config.get_audio_config() - self.transformer_blocks = { idx: BasicAVTransformerBlock( idx=idx, @@ -405,6 +444,7 @@ class LTXModel(nn.Module): audio=audio_config, rope_type=config.rope_type, norm_eps=config.norm_eps, + has_prompt_adaln=config.has_prompt_adaln, ) for idx in range(config.num_layers) } diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index f6824f8..1c16524 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -208,6 +208,7 @@ class ConnectorAttention(nn.Module): dim: int = 3840, num_heads: int = 30, head_dim: int = 128, + has_gate_logits: bool = False, ): super().__init__() self.num_heads = num_heads @@ -218,13 +219,14 @@ class ConnectorAttention(nn.Module): self.to_q = nn.Linear(dim, inner_dim, bias=True) self.to_k = nn.Linear(dim, inner_dim, bias=True) self.to_v = nn.Linear(dim, inner_dim, bias=True) - # Direct attribute for MLX parameter tracking (not a list) self.to_out = nn.Linear(inner_dim, dim, bias=True) - # Standard RMSNorm (not Gemma-style) on full inner_dim self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6) self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6) + if has_gate_logits: + self.to_gate_logits = nn.Linear(dim, num_heads, bias=True) + def __call__( self, x: mx.array, @@ -233,12 +235,17 @@ class ConnectorAttention(nn.Module): ) -> mx.array: batch_size, seq_len, _ = x.shape + # Compute per-head gate early (from original input) + gate = None + if hasattr(self, "to_gate_logits"): + gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads) + # Project to Q, K, V - q = self.to_q(x) # (B, seq, inner_dim) + q = self.to_q(x) k = self.to_k(x) v = self.to_v(x) - # QK normalization on full inner_dim BEFORE reshape (matches PyTorch) + # QK normalization on full inner_dim BEFORE reshape q = self.q_norm(q) k = self.k_norm(k) @@ -248,15 +255,18 @@ class ConnectorAttention(nn.Module): v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) if pe is not None: - # pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2) - # Apply SPLIT RoPE: operates on first half of head dimensions q = self._apply_split_rope(q, pe[0], pe[1]) k = self._apply_split_rope(k, pe[0], pe[1]) - # No mask needed for connector - after register replacement, all positions are valid out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None) out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) + # Apply per-head gating + if gate is not None: + out = mx.reshape(out, (batch_size, seq_len, self.num_heads, self.head_dim)) + out = out * gate[..., None] + out = mx.reshape(out, (batch_size, seq_len, -1)) + return self.to_out(out) def _apply_split_rope( @@ -326,9 +336,9 @@ class ConnectorFeedForward(nn.Module): class ConnectorTransformerBlock(nn.Module): - def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128): + def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False): super().__init__() - self.attn1 = ConnectorAttention(dim, num_heads, head_dim) + self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) self.ff = ConnectorFeedForward(dim) def __call__( @@ -367,6 +377,7 @@ class Embeddings1DConnector(nn.Module): num_learnable_registers: int = 128, positional_embedding_theta: float = 10000.0, positional_embedding_max_pos: list = None, + has_gate_logits: bool = False, ): super().__init__() self.dim = dim @@ -376,9 +387,8 @@ class Embeddings1DConnector(nn.Module): self.positional_embedding_theta = positional_embedding_theta self.positional_embedding_max_pos = positional_embedding_max_pos or [4096] - # Use dict with int keys for MLX to track parameters (lists are not tracked) self.transformer_1d_blocks = { - i: ConnectorTransformerBlock(dim, num_heads, head_dim) + i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) for i in range(num_layers) } @@ -572,16 +582,100 @@ def norm_and_concat_hidden_states( return normed -class GemmaFeaturesExtractor(nn.Module): +def norm_and_concat_per_token_rms( + encoded_text: mx.array, + attention_mask: mx.array, +) -> mx.array: + """Per-token RMSNorm normalization for V2 feature extraction (LTX-2.3). - def __init__(self, input_dim: int = 188160, output_dim: int = 3840): + Args: + encoded_text: (B, T, D, L) stacked hidden states + attention_mask: (B, T) binary mask (1=valid, 0=padding) + + Returns: + (B, T, D*L) normalized tensor with padding zeroed out. + """ + b, t, d, num_layers = encoded_text.shape + dtype = encoded_text.dtype + + # Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D + variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L) + normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6) + normed = normed.astype(dtype) + + # Flatten layers: (B, T, D*L) + normed = mx.reshape(normed, (b, t, d * num_layers)) + + # Zero out padded positions + mask_3d = attention_mask[:, :, None].astype(mx.bool_) # (B, T, 1) + normed = mx.where(mask_3d, normed, mx.zeros_like(normed)) + + return normed + + +def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array: + """Rescale normalization: x * sqrt(target_dim / source_dim).""" + return x * math.sqrt(target_dim / source_dim) + + +class GemmaFeaturesExtractor(nn.Module): + """V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization.""" + + def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False): super().__init__() - self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=False) + self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias) def __call__(self, x: mx.array) -> mx.array: return self.aggregate_embed(x) +class GemmaFeaturesExtractorV2(nn.Module): + """V2 feature extractor (LTX-2.3): per-token RMSNorm + rescale normalization.""" + + def __init__( + self, + flat_dim: int, + embedding_dim: int, + video_output_dim: int, + audio_output_dim: int, + bias: bool = True, + ): + super().__init__() + self.embedding_dim = embedding_dim # Gemma hidden_dim (3840), used for rescale + self.video_aggregate_embed = nn.Linear(flat_dim, video_output_dim, bias=bias) + self.audio_aggregate_embed = nn.Linear(flat_dim, audio_output_dim, bias=bias) + + def __call__( + self, + hidden_states: List[mx.array], + attention_mask: mx.array, + mode: str = "video", + ) -> mx.array: + """Extract features with per-token RMSNorm + rescale. + + Args: + hidden_states: List of hidden states from all Gemma layers + attention_mask: Binary attention mask (B, T) + mode: "video" or "audio" to select which aggregate embed to use + + Returns: + Projected features + """ + # Stack hidden states: (B, T, D, L) + encoded = mx.stack(hidden_states, axis=-1) + + # Per-token RMSNorm + flatten + normed = norm_and_concat_per_token_rms(encoded, attention_mask) + normed = normed.astype(encoded.dtype) + + if mode == "video": + target_dim = self.video_aggregate_embed.weight.shape[0] + return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) + else: + target_dim = self.audio_aggregate_embed.weight.shape[0] + return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) + + @@ -603,39 +697,54 @@ class LTX2TextEncoder(nn.Module): hidden_dim: int = 3840, audio_dim: int = 2048, num_layers: int = 49, # 48 transformer layers + 1 embedding + has_prompt_adaln: bool = False, ): super().__init__() self.hidden_dim = hidden_dim self.audio_dim = audio_dim self.num_layers = num_layers + self.has_prompt_adaln = has_prompt_adaln self.language_model = None - # Feature extractor: 3840*49 -> 3840 - self.feature_extractor = GemmaFeaturesExtractor( - input_dim=hidden_dim * num_layers, - output_dim=hidden_dim, - ) + feature_input_dim = hidden_dim * num_layers - # Video embeddings connector: 2-layer transformer - self.video_embeddings_connector = Embeddings1DConnector( - dim=hidden_dim, - num_heads=30, - head_dim=128, - num_layers=2, - num_learnable_registers=128, - positional_embedding_max_pos=[4096], # Match PyTorch - ) + if has_prompt_adaln: + # LTX-2.3: V2 feature extractor with per-token RMSNorm + rescale + video_output_dim = 4096 + audio_output_dim = 2048 + self.feature_extractor_v2 = GemmaFeaturesExtractorV2( + flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated) + embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale) + video_output_dim=video_output_dim, + audio_output_dim=audio_output_dim, + bias=True, + ) - # Audio embeddings connector: separate 2-layer transformer (same architecture as video) - # Both connectors process the feature extractor output independently - self.audio_embeddings_connector = Embeddings1DConnector( - dim=hidden_dim, - num_heads=30, - head_dim=128, - num_layers=2, - num_learnable_registers=128, - positional_embedding_max_pos=[4096], # Match PyTorch - ) + # Deeper connectors with matching dims and gate_logits + self.video_embeddings_connector = Embeddings1DConnector( + dim=video_output_dim, num_heads=32, head_dim=128, + num_layers=8, num_learnable_registers=128, + positional_embedding_max_pos=[4096], has_gate_logits=True, + ) + self.audio_embeddings_connector = Embeddings1DConnector( + dim=audio_output_dim, num_heads=32, head_dim=64, + num_layers=8, num_learnable_registers=128, + positional_embedding_max_pos=[4096], has_gate_logits=True, + ) + else: + # LTX-2: shared feature extractor, 3840-dim connectors + self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim) + + self.video_embeddings_connector = Embeddings1DConnector( + dim=hidden_dim, num_heads=30, head_dim=128, + num_layers=2, num_learnable_registers=128, + positional_embedding_max_pos=[4096], + ) + self.audio_embeddings_connector = Embeddings1DConnector( + dim=hidden_dim, num_heads=30, head_dim=128, + num_layers=2, num_learnable_registers=128, + positional_embedding_max_pos=[4096], + ) self.processor = None @@ -669,81 +778,9 @@ class LTX2TextEncoder(nn.Module): transformer_weights = mx.load(str(transformer_files[0])) if transformer_weights: - # Load feature extractor (aggregate_embed) - # Reformatted key: "aggregate_embed.weight" - # Monolithic key: "text_embedding_projection.aggregate_embed.weight" - agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" - if agg_key in transformer_weights: - self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key] - - # Load video_embeddings_connector weights - connector_weights = {} - if is_reformatted: - # Reformatted: keys are already sanitized with "video_embeddings_connector." prefix - for key, value in transformer_weights.items(): - if key.startswith("video_embeddings_connector."): - new_key = key.replace("video_embeddings_connector.", "") - connector_weights[new_key] = value - else: - # Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix - for key, value in transformer_weights.items(): - if key.startswith("model.diffusion_model.video_embeddings_connector."): - new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "") - connector_weights[new_key] = value - - if connector_weights: - # Map weight names to our structure (only needed for monolithic/raw PyTorch keys) - mapped_weights = {} - for key, value in connector_weights.items(): - new_key = key - if not is_reformatted: - # Map ff.net.0.proj -> ff.proj_in (GEGLU projection) - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - # Map ff.net.2 -> ff.proj_out (output Linear) - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - # Map to_out.0 -> to_out (Sequential -> direct) - new_key = new_key.replace(".to_out.0.", ".to_out.") - mapped_weights[new_key] = value - - self.video_embeddings_connector.load_weights( - list(mapped_weights.items()), strict=False - ) - - # Manually load learnable_registers (it's a plain mx.array, not a parameter) - if "learnable_registers" in connector_weights: - self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"] - - # Load audio_embeddings_connector weights (same structure as video connector) - audio_connector_weights = {} - if is_reformatted: - for key, value in transformer_weights.items(): - if key.startswith("audio_embeddings_connector."): - new_key = key.replace("audio_embeddings_connector.", "") - audio_connector_weights[new_key] = value - else: - for key, value in transformer_weights.items(): - if key.startswith("model.diffusion_model.audio_embeddings_connector."): - new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "") - audio_connector_weights[new_key] = value - - if audio_connector_weights: - # Map weight names to our structure (same as video connector) - mapped_audio_weights = {} - for key, value in audio_connector_weights.items(): - new_key = key - if not is_reformatted: - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - new_key = new_key.replace(".to_out.0.", ".to_out.") - mapped_audio_weights[new_key] = value - - self.audio_embeddings_connector.load_weights( - list(mapped_audio_weights.items()), strict=False - ) - - # Manually load learnable_registers (it's a plain mx.array, not a parameter) - if "learnable_registers" in audio_connector_weights: - self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"] + self._load_feature_extractors(transformer_weights, is_reformatted) + self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted) + self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted) else: print("WARNING: No transformer weights found for text projection connectors. " "Text conditioning will use uninitialized weights!") @@ -763,6 +800,63 @@ class LTX2TextEncoder(nn.Module): print("Text encoder loaded successfully") + def _load_feature_extractors(self, weights: dict, is_reformatted: bool): + """Load feature extractor weights for both LTX-2 and LTX-2.3.""" + if self.has_prompt_adaln: + # LTX-2.3: V2 feature extractor with separate video/audio aggregate embeds + for attr, prefix in [ + ("video_aggregate_embed", "video_aggregate_embed"), + ("audio_aggregate_embed", "audio_aggregate_embed"), + ]: + w_key = f"{prefix}.weight" + b_key = f"{prefix}.bias" + if w_key in weights: + submodule = getattr(self.feature_extractor_v2, attr) + submodule.weight = weights[w_key] + if b_key in weights: + submodule.bias = weights[b_key] + else: + # LTX-2: single aggregate_embed + agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" + if agg_key in weights: + self.feature_extractor.aggregate_embed.weight = weights[agg_key] + + def _load_connector(self, name: str, weights: dict, is_reformatted: bool): + """Load a connector's weights (video or audio).""" + connector = getattr(self, name) + + # Extract connector-specific weights + connector_weights = {} + if is_reformatted: + prefix = f"{name}." + for key, value in weights.items(): + if key.startswith(prefix): + connector_weights[key[len(prefix):]] = value + else: + mono_prefix = f"model.diffusion_model.{name}." + for key, value in weights.items(): + if key.startswith(mono_prefix): + connector_weights[key[len(mono_prefix):]] = value + + if not connector_weights: + return + + # Sanitize key names (only needed for monolithic/raw PyTorch keys) + mapped = {} + for key, value in connector_weights.items(): + new_key = key + if not is_reformatted: + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + new_key = new_key.replace(".to_out.0.", ".to_out.") + mapped[new_key] = value + + connector.load_weights(list(mapped.items()), strict=False) + + # Manually load learnable_registers (plain mx.array, not tracked as parameter) + if "learnable_registers" in connector_weights: + connector.learnable_registers = connector_weights["learnable_registers"] + def encode( self, prompt: str, @@ -795,21 +889,40 @@ class LTX2TextEncoder(nn.Module): attention_mask = mx.array(inputs["attention_mask"]) _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) - concat_hidden = norm_and_concat_hidden_states( - all_hidden_states, attention_mask, padding_side="left" - ) - features = self.feature_extractor(concat_hidden) - additive_mask = (attention_mask - 1).astype(features.dtype) - additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 - video_embeddings, _ = self.video_embeddings_connector(features, additive_mask) + if self.has_prompt_adaln: + # LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale) + video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video") + additive_mask = (attention_mask - 1).astype(video_features.dtype) + additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 - if return_audio_embeddings: - # Process features through audio connector independently (same input as video) - audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask) - return video_embeddings, audio_embeddings + video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) + + if return_audio_embeddings: + audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio") + audio_mask = (attention_mask - 1).astype(audio_features.dtype) + audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask) + return video_embeddings, audio_embeddings + else: + return video_embeddings, attention_mask else: - return video_embeddings, attention_mask + # LTX-2: V1 feature extraction (8 * (x - mean) / range) + concat_hidden = norm_and_concat_hidden_states( + all_hidden_states, attention_mask, padding_side="left" + ) + + video_features = self.feature_extractor(concat_hidden) + additive_mask = (attention_mask - 1).astype(video_features.dtype) + additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + + video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) + + if return_audio_embeddings: + audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask) + return video_embeddings, audio_embeddings + else: + return video_embeddings, attention_mask def __call__( self, diff --git a/mlx_video/models/ltx/transformer.py b/mlx_video/models/ltx/transformer.py index 5a60989..4b311e6 100644 --- a/mlx_video/models/ltx/transformer.py +++ b/mlx_video/models/ltx/transformer.py @@ -20,20 +20,25 @@ class Modality: context_mask: Optional[mx.array] = None # Optional precomputed positional embeddings (RoPE) to avoid recomputation positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None + # Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3) + sigma: Optional[mx.array] = None @dataclass(frozen=True) class TransformerArgs: - x: mx.array - context: mx.array - context_mask: Optional[mx.array] - timesteps: mx.array - embedded_timestep: mx.array - positional_embeddings: Tuple[mx.array, mx.array] - cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]] - cross_scale_shift_timestep: Optional[mx.array] - cross_gate_timestep: Optional[mx.array] + x: mx.array + context: mx.array + context_mask: Optional[mx.array] + timesteps: mx.array + embedded_timestep: mx.array + positional_embeddings: Tuple[mx.array, mx.array] + cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]] + cross_scale_shift_timestep: Optional[mx.array] + cross_gate_timestep: Optional[mx.array] enabled: bool + # LTX-2.3: prompt-conditioned timestep embeddings for cross-attention + prompt_timesteps: Optional[mx.array] = None + prompt_embedded_timestep: Optional[mx.array] = None class BasicAVTransformerBlock(nn.Module): @@ -50,20 +55,13 @@ class BasicAVTransformerBlock(nn.Module): audio: Optional[TransformerConfig] = None, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, norm_eps: float = 1e-6, + has_prompt_adaln: bool = False, ): - """Initialize transformer block. - - Args: - idx: Block index - video: Video modality configuration - audio: Audio modality configuration - rope_type: Type of rotary position embedding - norm_eps: Epsilon for normalization - """ super().__init__() self.idx = idx self.norm_eps = norm_eps + self.has_prompt_adaln = has_prompt_adaln # Video components if video is not None: @@ -74,6 +72,7 @@ class BasicAVTransformerBlock(nn.Module): context_dim=None, # Self-attention rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.attn2 = Attention( query_dim=video.dim, @@ -82,10 +81,15 @@ class BasicAVTransformerBlock(nn.Module): dim_head=video.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.ff = FeedForward(video.dim, dim_out=video.dim) - # 6 scale-shift parameters: 3 for attention, 3 for MLP - self.scale_shift_table = mx.zeros((6, video.dim)) + # 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2 + num_ada_params = 9 if has_prompt_adaln else 6 + self.scale_shift_table = mx.zeros((num_ada_params, video.dim)) + + if has_prompt_adaln: + self.prompt_scale_shift_table = mx.zeros((2, video.dim)) # Audio components if audio is not None: @@ -96,6 +100,7 @@ class BasicAVTransformerBlock(nn.Module): context_dim=None, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.audio_attn2 = Attention( query_dim=audio.dim, @@ -104,9 +109,14 @@ class BasicAVTransformerBlock(nn.Module): dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) - self.audio_scale_shift_table = mx.zeros((6, audio.dim)) + num_audio_ada_params = 9 if has_prompt_adaln else 6 + self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim)) + + if has_prompt_adaln: + self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim)) # Cross-modal attention (when both video and audio are enabled) if audio is not None and video is not None: @@ -118,6 +128,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) # Video-to-Audio: Q from audio, K/V from video self.video_to_audio_attn = Attention( @@ -127,6 +138,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) # Scale-shift tables for cross-attention self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim)) @@ -254,11 +266,23 @@ class BasicAVTransformerBlock(nn.Module): vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa # Cross-attention with text context - vx = vx + self.attn2( - rms_norm(vx, eps=self.norm_eps), - context=video.context, - mask=video.context_mask, - ) + if self.has_prompt_adaln: + # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln + vshift_q, vscale_q, vgate_q = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9) + ) + vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values( + self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2) + ) + attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q + encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv + vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q + else: + vx = vx + self.attn2( + rms_norm(vx, eps=self.norm_eps), + context=video.context, + mask=video.context_mask, + ) # Process audio self-attention and cross-attention with text if run_ax: @@ -271,11 +295,23 @@ class BasicAVTransformerBlock(nn.Module): ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa # Cross-attention with text context - ax = ax + self.audio_attn2( - rms_norm(ax, eps=self.norm_eps), - context=audio.context, - mask=audio.context_mask, - ) + if self.has_prompt_adaln: + # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln + ashift_q, ascale_q, agate_q = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9) + ) + aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values( + self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2) + ) + attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q + encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv + ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q + else: + ax = ax + self.audio_attn2( + rms_norm(ax, eps=self.norm_eps), + context=audio.context, + mask=audio.context_mask, + ) # Audio-Video cross-modal attention if run_a2v or run_v2a: @@ -341,7 +377,7 @@ class BasicAVTransformerBlock(nn.Module): # Process video feed-forward if run_vx: vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( - self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None) + self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6) ) vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp vx = vx + self.ff(vx_scaled) * vgate_mlp @@ -349,7 +385,7 @@ class BasicAVTransformerBlock(nn.Module): # Process audio feed-forward if run_ax: ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( - self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None) + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6) ) ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp ax = ax + self.audio_ff(ax_scaled) * agate_mlp diff --git a/mlx_video/models/ltx/upsampler.py b/mlx_video/models/ltx/upsampler.py index 7f43536..1180664 100644 --- a/mlx_video/models/ltx/upsampler.py +++ b/mlx_video/models/ltx/upsampler.py @@ -301,15 +301,14 @@ def upsample_latents( latent_std: mx.array, debug: bool = False, ) -> mx.array: - # Un-normalize: latent * std + mean latent_mean = latent_mean.reshape(1, -1, 1, 1, 1) latent_std = latent_std.reshape(1, -1, 1, 1, 1) latent = latent * latent_std + latent_mean - + # Upsample latent = upsampler(latent, debug=debug) - + # Re-normalize: (latent - mean) / std latent = (latent - latent_mean) / latent_std @@ -350,19 +349,18 @@ def load_upsampler(weights_path: str) -> LatentUpsampler: for key, value in raw_weights.items(): new_key = key + # LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.* + if key.startswith("upsampler.0."): + new_key = key.replace("upsampler.0.", "upsampler.conv.") + # Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I) - if "conv" in key and "weight" in key and value.ndim == 5: + if "weight" in new_key and value.ndim == 5: value = mx.transpose(value, (0, 2, 3, 4, 1)) # Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I) - if "conv" in key and "weight" in key and value.ndim == 4: + if "weight" in new_key and value.ndim == 4: value = mx.transpose(value, (0, 2, 3, 1)) - # Map upsampler.conv to upsampler.conv (SpatialRationalResampler) - # Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel - if key.startswith("upsampler."): - new_key = key # Keep as is for SpatialRationalResampler - sanitized[new_key] = value # Load weights diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 5f45d8a..105082c 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -250,6 +250,18 @@ class LTX2VideoDecoder(nn.Module): - conv_out: 128 -> 48 (3 * 4^2 for patch_size=4) """ + # Block definitions: ("res", channels, num_layers) or ("d2s", in_channels, reduction, stride) + # stride is (D, H, W) tuple + DEFAULT_BLOCKS = [ + ("res", 1024, 5), + ("d2s", 1024, 2, (2, 2, 2)), + ("res", 512, 5), + ("d2s", 512, 2, (2, 2, 2)), + ("res", 256, 5), + ("d2s", 256, 2, (2, 2, 2)), + ("res", 128, 5), + ] + def __init__( self, in_channels: int = 128, @@ -258,6 +270,7 @@ class LTX2VideoDecoder(nn.Module): num_layers_per_block: int = 5, spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, timestep_conditioning: bool = True, + decoder_blocks: list = None, ): super().__init__() @@ -272,13 +285,17 @@ class LTX2VideoDecoder(nn.Module): # Per-channel statistics for denormalization (loaded from weights) self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) - # Initial conv: 128 -> 1024 + blocks = decoder_blocks or self.DEFAULT_BLOCKS + first_ch = blocks[0][1] + last_ch = blocks[-1][1] + + # Initial conv: in_channels -> first block channels class ConvInWrapper(nn.Module): def __init__(self_inner): super().__init__() self_inner.conv = CausalConv3d( in_channels=in_channels, - out_channels=1024, + out_channels=first_ch, kernel_size=3, stride=1, padding=1, @@ -288,45 +305,32 @@ class LTX2VideoDecoder(nn.Module): return self_inner.conv(x, causal=causal) self.conv_in = ConvInWrapper() - # Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample - # Use dict with int keys for MLX to track parameters properly - self.up_blocks = { - 0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - 1: DepthToSpaceUpsample( - dims=3, - in_channels=1024, - stride=(2, 2, 2), - residual=True, - out_channels_reduction_factor=2, - spatial_padding_mode=spatial_padding_mode, - ), - 2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - 3: DepthToSpaceUpsample( - dims=3, - in_channels=512, - stride=(2, 2, 2), - residual=True, - out_channels_reduction_factor=2, - spatial_padding_mode=spatial_padding_mode, - ), - 4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - 5: DepthToSpaceUpsample( - dims=3, - in_channels=256, - stride=(2, 2, 2), - residual=True, - out_channels_reduction_factor=2, - spatial_padding_mode=spatial_padding_mode, - ), - 6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - } + # Build up blocks from config + self.up_blocks = {} + for idx, block_def in enumerate(blocks): + block_type = block_def[0] + ch = block_def[1] + if block_type == "res": + num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block + self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning) + elif block_type == "d2s": + reduction = block_def[2] if len(block_def) > 2 else 2 + stride = block_def[3] if len(block_def) > 3 else (2, 2, 2) + self.up_blocks[idx] = DepthToSpaceUpsample( + dims=3, + in_channels=ch, + stride=stride, + residual=True, + out_channels_reduction_factor=reduction, + spatial_padding_mode=spatial_padding_mode, + ) final_out_channels = out_channels * patch_size * patch_size class ConvOutWrapper(nn.Module): def __init__(self_inner): super().__init__() self_inner.conv = CausalConv3d( - in_channels=128, + in_channels=last_ch, out_channels=final_out_channels, kernel_size=3, stride=1, @@ -342,9 +346,9 @@ class LTX2VideoDecoder(nn.Module): if timestep_conditioning: self.timestep_scale_multiplier = mx.array(1000.0) self.last_time_embedder = PixArtAlphaTimestepEmbedder( - embedding_dim=128 * 2 # 256, matches (2, 128) table + embedding_dim=last_ch * 2 ) - self.last_scale_shift_table = mx.zeros((2, 128)) + self.last_scale_shift_table = mx.zeros((2, last_ch)) def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: # Build decoder weights dict with key remapping @@ -418,11 +422,96 @@ class LTX2VideoDecoder(nn.Module): weights.update(mx.load(str(wf))) - model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False)) + # Infer block structure from weights + decoder_blocks = cls._infer_blocks(weights) + + model = cls( + timestep_conditioning=config_dict.get("timestep_conditioning", False), + decoder_blocks=decoder_blocks, + ) weights = model.sanitize(weights) model.load_weights(list(weights.items()), strict=strict) return model + @staticmethod + def _infer_blocks(weights: dict) -> list: + """Infer decoder block structure from weight keys.""" + block_indices = set() + for k in weights: + if "up_blocks." in k: + idx_str = k.split("up_blocks.")[1].split(".")[0] + if idx_str.isdigit(): + block_indices.add(int(idx_str)) + + if not block_indices: + return None + + # First pass: collect block info + raw_blocks = [] + for idx in sorted(block_indices): + has_conv = any(f"up_blocks.{idx}.conv." in k for k in weights) + res_indices = set() + for k in weights: + prefix = f"up_blocks.{idx}.res_blocks." + if prefix in k: + res_idx = k.split(prefix)[1].split(".")[0] + if res_idx.isdigit(): + res_indices.add(int(res_idx)) + + if has_conv and not res_indices: + # D2S block - get conv shape + for k, v in weights.items(): + if f"up_blocks.{idx}.conv." in k and "weight" in k: + in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1] + conv_out_ch = v.shape[0] + raw_blocks.append(("d2s", in_ch, conv_out_ch)) + break + elif res_indices: + num_res = max(res_indices) + 1 + for k, v in weights.items(): + if f"up_blocks.{idx}.res_blocks.0.conv1" in k and "weight" in k: + ch = v.shape[0] + raw_blocks.append(("res", ch, num_res)) + break + + # Second pass: determine d2s strides using the channel progression + # For each d2s block, the next res block tells us the expected output channels + blocks = [] + for i, block in enumerate(raw_blocks): + if block[0] == "res": + blocks.append(block) + elif block[0] == "d2s": + in_ch, conv_out_ch = block[1], block[2] + # Find next res block's channels + next_ch = None + for j in range(i + 1, len(raw_blocks)): + if raw_blocks[j][0] == "res": + next_ch = raw_blocks[j][1] + break + + if next_ch is None: + next_ch = in_ch // 2 # fallback + + # out_ch = in_ch // reduction + reduction = in_ch // next_ch if next_ch > 0 else 2 + + # conv_out = next_ch * multiplier → multiplier = conv_out / next_ch + multiplier = conv_out_ch // next_ch if next_ch > 0 else 8 + + # Determine stride from multiplier + if multiplier == 8: + stride = (2, 2, 2) + elif multiplier == 4: + stride = (1, 2, 2) + elif multiplier == 2: + stride = (2, 1, 1) + else: + stride = (2, 2, 2) + + blocks.append(("d2s", in_ch, reduction, stride)) + + return blocks if blocks else None + def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: