diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 4121738..21790a7 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -344,6 +344,7 @@ def denoise_distilled( context=text_embeddings, context_mask=None, enabled=True, + sigma=mx.full((b,), sigma, dtype=dtype), ) audio_modality = None @@ -359,6 +360,7 @@ def denoise_distilled( context=audio_embeddings, context_mask=None, enabled=True, + sigma=mx.full((ab,), sigma, dtype=dtype), ) velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) @@ -493,6 +495,8 @@ def denoise_dev( else: timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) + sigma_array = mx.full((b,), sigma, dtype=dtype) + # Positive conditioning pass video_modality_pos = Modality( latent=latents_flat, @@ -502,6 +506,7 @@ def denoise_dev( context_mask=None, enabled=True, positional_embeddings=precomputed_rope, + sigma=sigma_array, ) velocity_pos, _ = transformer(video=video_modality_pos, audio=None) @@ -523,6 +528,7 @@ def denoise_dev( context_mask=None, enabled=True, positional_embeddings=precomputed_rope, + sigma=sigma_array, ) velocity_neg, _ = transformer(video=video_modality_neg, audio=None) @@ -957,10 +963,18 @@ def generate_video( mx.random.seed(seed) + # Read transformer config to detect model version + import json + transformer_config_path = model_path / "transformer" / "config.json" + has_prompt_adaln = False + if transformer_config_path.exists(): + with open(transformer_config_path) as f: + has_prompt_adaln = json.load(f).get("has_prompt_adaln", False) + # Load text encoder with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"): from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder() + text_encoder = LTX2TextEncoder(has_prompt_adaln=has_prompt_adaln) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) mx.eval(text_encoder.parameters()) console.print("[green]✓[/] Text encoder loaded") @@ -1084,7 +1098,10 @@ def generate_video( # Upsample latents with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): - upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) + upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) + if not upscaler_files: + raise FileNotFoundError(f"No spatial upscaler found in {model_path}") + upsampler = load_upsampler(str(upscaler_files[0])) mx.eval(upsampler.parameters()) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) diff --git a/mlx_video/models/ltx/attention.py b/mlx_video/models/ltx/attention.py index 4024c91..ebc0a24 100644 --- a/mlx_video/models/ltx/attention.py +++ b/mlx_video/models/ltx/attention.py @@ -67,17 +67,8 @@ class Attention(nn.Module): dim_head: int = 64, norm_eps: float = 1e-6, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + has_gate_logits: bool = False, ): - """Initialize attention module. - - Args: - query_dim: Dimension of query input - context_dim: Dimension of context (key/value) input. If None, same as query_dim - heads: Number of attention heads - dim_head: Dimension per head - norm_eps: Epsilon for RMS normalization - rope_type: Type of rotary position embedding - """ super().__init__() self.rope_type = rope_type @@ -99,6 +90,10 @@ class Attention(nn.Module): # Output projection self.to_out = nn.Linear(inner_dim, query_dim, bias=True) + # Per-head gating (LTX-2.3) + if has_gate_logits: + self.to_gate_logits = nn.Linear(query_dim, heads, bias=True) + def __call__( self, x: mx.array, @@ -119,6 +114,11 @@ class Attention(nn.Module): Returns: Attention output of shape (B, seq_len, query_dim) """ + # Compute per-head gate early (from original input) + gate = None + if hasattr(self, "to_gate_logits"): + gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads) + # Compute Q, K, V q = self.to_q(x) context = x if context is None else context @@ -138,5 +138,12 @@ class Attention(nn.Module): # Compute attention out = scaled_dot_product_attention(q, k, v, self.heads, mask) + # Apply per-head gating + if gate is not None: + b, seq_len, _ = out.shape + out = mx.reshape(out, (b, seq_len, self.heads, self.dim_head)) + out = out * gate[..., None] + out = mx.reshape(out, (b, seq_len, -1)) + # Project output return self.to_out(out) diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index c63fcd7..40bb9ef 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -131,6 +131,12 @@ class LTXModelConfig(BaseModelConfig): # Attention type attention_type: AttentionType = AttentionType.DEFAULT + # LTX-2.3: prompt-conditioned adaptive layer norm + # Controls: gate_logits in attention, 9-param scale_shift_table, + # prompt_adaln_single, per-block prompt_scale_shift_table, + # removal of caption_projection + has_prompt_adaln: bool = False + # VAE config vae_config: Optional[VideoVAEConfig] = None diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index 5551b0a..6a63d7b 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -26,7 +26,7 @@ class TransformerArgsPreprocessor: self, patchify_proj: nn.Linear, adaln: AdaLayerNormSingle, - caption_projection: PixArtAlphaTextProjection, + caption_projection: Optional[PixArtAlphaTextProjection], inner_dim: int, max_pos: List[int], num_attention_heads: int, @@ -35,10 +35,12 @@ class TransformerArgsPreprocessor: positional_embedding_theta: float, rope_type: LTXRopeType, double_precision_rope: bool = False, + prompt_adaln: Optional[AdaLayerNormSingle] = None, ): self.patchify_proj = patchify_proj self.adaln = adaln self.caption_projection = caption_projection + self.prompt_adaln = prompt_adaln self.inner_dim = inner_dim self.max_pos = max_pos self.num_attention_heads = num_attention_heads @@ -64,6 +66,19 @@ class TransformerArgsPreprocessor: return timestep_emb, embedded_timestep + def _prepare_timestep_with_adaln( + self, + adaln: AdaLayerNormSingle, + timestep: mx.array, + batch_size: int, + hidden_dtype: mx.Dtype = None, + ) -> Tuple[mx.array, mx.array]: + timestep = timestep * self.timestep_scale_multiplier + timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) + timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) + embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])) + return timestep_emb, embedded_timestep + def _prepare_context( self, context: mx.array, @@ -72,9 +87,8 @@ class TransformerArgsPreprocessor: ) -> Tuple[mx.array, Optional[mx.array]]: batch_size = x.shape[0] - # Context is already processed through embeddings connector in text encoder - # Here we just apply the caption projection - context = self.caption_projection(context) + if self.caption_projection is not None: + context = self.caption_projection(context) context = mx.reshape(context, (batch_size, -1, x.shape[-1])) return context, attention_mask @@ -134,6 +148,14 @@ class TransformerArgsPreprocessor: num_attention_heads=self.num_attention_heads, ) + # Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps + prompt_timestep = None + prompt_embedded_timestep = None + if self.prompt_adaln is not None and modality.sigma is not None: + prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln( + self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype, + ) + return TransformerArgs( x=x, context=context, @@ -145,6 +167,8 @@ class TransformerArgsPreprocessor: cross_scale_shift_timestep=None, cross_gate_timestep=None, enabled=modality.enabled, + prompt_timesteps=prompt_timestep, + prompt_embedded_timestep=prompt_embedded_timestep, ) @@ -154,7 +178,7 @@ class MultiModalTransformerArgsPreprocessor: self, patchify_proj: nn.Linear, adaln: AdaLayerNormSingle, - caption_projection: PixArtAlphaTextProjection, + caption_projection: Optional[PixArtAlphaTextProjection], cross_scale_shift_adaln: AdaLayerNormSingle, cross_gate_adaln: AdaLayerNormSingle, inner_dim: int, @@ -168,6 +192,7 @@ class MultiModalTransformerArgsPreprocessor: rope_type: LTXRopeType, av_ca_timestep_scale_multiplier: int, double_precision_rope: bool = False, + prompt_adaln: Optional[AdaLayerNormSingle] = None, ): self.simple_preprocessor = TransformerArgsPreprocessor( patchify_proj=patchify_proj, @@ -181,6 +206,7 @@ class MultiModalTransformerArgsPreprocessor: positional_embedding_theta=positional_embedding_theta, rope_type=rope_type, double_precision_rope=double_precision_rope, + prompt_adaln=prompt_adaln, ) self.cross_scale_shift_adaln = cross_scale_shift_adaln self.cross_gate_adaln = cross_gate_adaln @@ -280,11 +306,17 @@ class LTXModel(nn.Module): def _init_video(self, config: LTXModelConfig) -> None: self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True) - self.adaln_single = AdaLayerNormSingle(self.inner_dim) - self.caption_projection = PixArtAlphaTextProjection( - in_features=config.caption_channels, - hidden_size=self.inner_dim, - ) + + adaln_coefficient = 9 if config.has_prompt_adaln else 6 + self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient) + + if config.has_prompt_adaln: + self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=config.caption_channels, + hidden_size=self.inner_dim, + ) self.scale_shift_table = mx.zeros((2, self.inner_dim)) self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False) @@ -292,13 +324,17 @@ class LTXModel(nn.Module): def _init_audio(self, config: LTXModelConfig) -> None: self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) - self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim) - # Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=config.audio_caption_channels, - hidden_size=self.audio_inner_dim, - ) + audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6 + self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient) + + if config.has_prompt_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=config.audio_caption_channels, + hidden_size=self.audio_inner_dim, + ) # Output components self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim)) @@ -331,7 +367,7 @@ class LTXModel(nn.Module): self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( patchify_proj=self.patchify_proj, adaln=self.adaln_single, - caption_projection=self.caption_projection, + caption_projection=getattr(self, "caption_projection", None), cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, inner_dim=self.inner_dim, @@ -345,11 +381,12 @@ class LTXModel(nn.Module): rope_type=config.rope_type, av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "prompt_adaln_single", None), ) self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( patchify_proj=self.audio_patchify_proj, adaln=self.audio_adaln_single, - caption_projection=self.audio_caption_projection, + caption_projection=getattr(self, "audio_caption_projection", None), cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, inner_dim=self.audio_inner_dim, @@ -363,12 +400,13 @@ class LTXModel(nn.Module): rope_type=config.rope_type, av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), ) elif config.model_type.is_video_enabled(): self.video_args_preprocessor = TransformerArgsPreprocessor( patchify_proj=self.patchify_proj, adaln=self.adaln_single, - caption_projection=self.caption_projection, + caption_projection=getattr(self, "caption_projection", None), inner_dim=self.inner_dim, max_pos=config.positional_embedding_max_pos, num_attention_heads=self.num_attention_heads, @@ -377,12 +415,13 @@ class LTXModel(nn.Module): positional_embedding_theta=config.positional_embedding_theta, rope_type=config.rope_type, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "prompt_adaln_single", None), ) elif config.model_type.is_audio_enabled(): self.audio_args_preprocessor = TransformerArgsPreprocessor( patchify_proj=self.audio_patchify_proj, adaln=self.audio_adaln_single, - caption_projection=self.audio_caption_projection, + caption_projection=getattr(self, "audio_caption_projection", None), inner_dim=self.audio_inner_dim, max_pos=config.audio_positional_embedding_max_pos, num_attention_heads=self.audio_num_attention_heads, @@ -391,13 +430,13 @@ class LTXModel(nn.Module): positional_embedding_theta=config.positional_embedding_theta, rope_type=config.rope_type, double_precision_rope=config.double_precision_rope, + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), ) def _init_transformer_blocks(self, config: LTXModelConfig) -> None: video_config = config.get_video_config() audio_config = config.get_audio_config() - self.transformer_blocks = { idx: BasicAVTransformerBlock( idx=idx, @@ -405,6 +444,7 @@ class LTXModel(nn.Module): audio=audio_config, rope_type=config.rope_type, norm_eps=config.norm_eps, + has_prompt_adaln=config.has_prompt_adaln, ) for idx in range(config.num_layers) } diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index f6824f8..1c16524 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -208,6 +208,7 @@ class ConnectorAttention(nn.Module): dim: int = 3840, num_heads: int = 30, head_dim: int = 128, + has_gate_logits: bool = False, ): super().__init__() self.num_heads = num_heads @@ -218,13 +219,14 @@ class ConnectorAttention(nn.Module): self.to_q = nn.Linear(dim, inner_dim, bias=True) self.to_k = nn.Linear(dim, inner_dim, bias=True) self.to_v = nn.Linear(dim, inner_dim, bias=True) - # Direct attribute for MLX parameter tracking (not a list) self.to_out = nn.Linear(inner_dim, dim, bias=True) - # Standard RMSNorm (not Gemma-style) on full inner_dim self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6) self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6) + if has_gate_logits: + self.to_gate_logits = nn.Linear(dim, num_heads, bias=True) + def __call__( self, x: mx.array, @@ -233,12 +235,17 @@ class ConnectorAttention(nn.Module): ) -> mx.array: batch_size, seq_len, _ = x.shape + # Compute per-head gate early (from original input) + gate = None + if hasattr(self, "to_gate_logits"): + gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads) + # Project to Q, K, V - q = self.to_q(x) # (B, seq, inner_dim) + q = self.to_q(x) k = self.to_k(x) v = self.to_v(x) - # QK normalization on full inner_dim BEFORE reshape (matches PyTorch) + # QK normalization on full inner_dim BEFORE reshape q = self.q_norm(q) k = self.k_norm(k) @@ -248,15 +255,18 @@ class ConnectorAttention(nn.Module): v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) if pe is not None: - # pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2) - # Apply SPLIT RoPE: operates on first half of head dimensions q = self._apply_split_rope(q, pe[0], pe[1]) k = self._apply_split_rope(k, pe[0], pe[1]) - # No mask needed for connector - after register replacement, all positions are valid out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None) out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) + # Apply per-head gating + if gate is not None: + out = mx.reshape(out, (batch_size, seq_len, self.num_heads, self.head_dim)) + out = out * gate[..., None] + out = mx.reshape(out, (batch_size, seq_len, -1)) + return self.to_out(out) def _apply_split_rope( @@ -326,9 +336,9 @@ class ConnectorFeedForward(nn.Module): class ConnectorTransformerBlock(nn.Module): - def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128): + def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False): super().__init__() - self.attn1 = ConnectorAttention(dim, num_heads, head_dim) + self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) self.ff = ConnectorFeedForward(dim) def __call__( @@ -367,6 +377,7 @@ class Embeddings1DConnector(nn.Module): num_learnable_registers: int = 128, positional_embedding_theta: float = 10000.0, positional_embedding_max_pos: list = None, + has_gate_logits: bool = False, ): super().__init__() self.dim = dim @@ -376,9 +387,8 @@ class Embeddings1DConnector(nn.Module): self.positional_embedding_theta = positional_embedding_theta self.positional_embedding_max_pos = positional_embedding_max_pos or [4096] - # Use dict with int keys for MLX to track parameters (lists are not tracked) self.transformer_1d_blocks = { - i: ConnectorTransformerBlock(dim, num_heads, head_dim) + i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits) for i in range(num_layers) } @@ -572,16 +582,100 @@ def norm_and_concat_hidden_states( return normed -class GemmaFeaturesExtractor(nn.Module): +def norm_and_concat_per_token_rms( + encoded_text: mx.array, + attention_mask: mx.array, +) -> mx.array: + """Per-token RMSNorm normalization for V2 feature extraction (LTX-2.3). - def __init__(self, input_dim: int = 188160, output_dim: int = 3840): + Args: + encoded_text: (B, T, D, L) stacked hidden states + attention_mask: (B, T) binary mask (1=valid, 0=padding) + + Returns: + (B, T, D*L) normalized tensor with padding zeroed out. + """ + b, t, d, num_layers = encoded_text.shape + dtype = encoded_text.dtype + + # Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D + variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L) + normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6) + normed = normed.astype(dtype) + + # Flatten layers: (B, T, D*L) + normed = mx.reshape(normed, (b, t, d * num_layers)) + + # Zero out padded positions + mask_3d = attention_mask[:, :, None].astype(mx.bool_) # (B, T, 1) + normed = mx.where(mask_3d, normed, mx.zeros_like(normed)) + + return normed + + +def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array: + """Rescale normalization: x * sqrt(target_dim / source_dim).""" + return x * math.sqrt(target_dim / source_dim) + + +class GemmaFeaturesExtractor(nn.Module): + """V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization.""" + + def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False): super().__init__() - self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=False) + self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias) def __call__(self, x: mx.array) -> mx.array: return self.aggregate_embed(x) +class GemmaFeaturesExtractorV2(nn.Module): + """V2 feature extractor (LTX-2.3): per-token RMSNorm + rescale normalization.""" + + def __init__( + self, + flat_dim: int, + embedding_dim: int, + video_output_dim: int, + audio_output_dim: int, + bias: bool = True, + ): + super().__init__() + self.embedding_dim = embedding_dim # Gemma hidden_dim (3840), used for rescale + self.video_aggregate_embed = nn.Linear(flat_dim, video_output_dim, bias=bias) + self.audio_aggregate_embed = nn.Linear(flat_dim, audio_output_dim, bias=bias) + + def __call__( + self, + hidden_states: List[mx.array], + attention_mask: mx.array, + mode: str = "video", + ) -> mx.array: + """Extract features with per-token RMSNorm + rescale. + + Args: + hidden_states: List of hidden states from all Gemma layers + attention_mask: Binary attention mask (B, T) + mode: "video" or "audio" to select which aggregate embed to use + + Returns: + Projected features + """ + # Stack hidden states: (B, T, D, L) + encoded = mx.stack(hidden_states, axis=-1) + + # Per-token RMSNorm + flatten + normed = norm_and_concat_per_token_rms(encoded, attention_mask) + normed = normed.astype(encoded.dtype) + + if mode == "video": + target_dim = self.video_aggregate_embed.weight.shape[0] + return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) + else: + target_dim = self.audio_aggregate_embed.weight.shape[0] + return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim)) + + @@ -603,39 +697,54 @@ class LTX2TextEncoder(nn.Module): hidden_dim: int = 3840, audio_dim: int = 2048, num_layers: int = 49, # 48 transformer layers + 1 embedding + has_prompt_adaln: bool = False, ): super().__init__() self.hidden_dim = hidden_dim self.audio_dim = audio_dim self.num_layers = num_layers + self.has_prompt_adaln = has_prompt_adaln self.language_model = None - # Feature extractor: 3840*49 -> 3840 - self.feature_extractor = GemmaFeaturesExtractor( - input_dim=hidden_dim * num_layers, - output_dim=hidden_dim, - ) + feature_input_dim = hidden_dim * num_layers - # Video embeddings connector: 2-layer transformer - self.video_embeddings_connector = Embeddings1DConnector( - dim=hidden_dim, - num_heads=30, - head_dim=128, - num_layers=2, - num_learnable_registers=128, - positional_embedding_max_pos=[4096], # Match PyTorch - ) + if has_prompt_adaln: + # LTX-2.3: V2 feature extractor with per-token RMSNorm + rescale + video_output_dim = 4096 + audio_output_dim = 2048 + self.feature_extractor_v2 = GemmaFeaturesExtractorV2( + flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated) + embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale) + video_output_dim=video_output_dim, + audio_output_dim=audio_output_dim, + bias=True, + ) - # Audio embeddings connector: separate 2-layer transformer (same architecture as video) - # Both connectors process the feature extractor output independently - self.audio_embeddings_connector = Embeddings1DConnector( - dim=hidden_dim, - num_heads=30, - head_dim=128, - num_layers=2, - num_learnable_registers=128, - positional_embedding_max_pos=[4096], # Match PyTorch - ) + # Deeper connectors with matching dims and gate_logits + self.video_embeddings_connector = Embeddings1DConnector( + dim=video_output_dim, num_heads=32, head_dim=128, + num_layers=8, num_learnable_registers=128, + positional_embedding_max_pos=[4096], has_gate_logits=True, + ) + self.audio_embeddings_connector = Embeddings1DConnector( + dim=audio_output_dim, num_heads=32, head_dim=64, + num_layers=8, num_learnable_registers=128, + positional_embedding_max_pos=[4096], has_gate_logits=True, + ) + else: + # LTX-2: shared feature extractor, 3840-dim connectors + self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim) + + self.video_embeddings_connector = Embeddings1DConnector( + dim=hidden_dim, num_heads=30, head_dim=128, + num_layers=2, num_learnable_registers=128, + positional_embedding_max_pos=[4096], + ) + self.audio_embeddings_connector = Embeddings1DConnector( + dim=hidden_dim, num_heads=30, head_dim=128, + num_layers=2, num_learnable_registers=128, + positional_embedding_max_pos=[4096], + ) self.processor = None @@ -669,81 +778,9 @@ class LTX2TextEncoder(nn.Module): transformer_weights = mx.load(str(transformer_files[0])) if transformer_weights: - # Load feature extractor (aggregate_embed) - # Reformatted key: "aggregate_embed.weight" - # Monolithic key: "text_embedding_projection.aggregate_embed.weight" - agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" - if agg_key in transformer_weights: - self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key] - - # Load video_embeddings_connector weights - connector_weights = {} - if is_reformatted: - # Reformatted: keys are already sanitized with "video_embeddings_connector." prefix - for key, value in transformer_weights.items(): - if key.startswith("video_embeddings_connector."): - new_key = key.replace("video_embeddings_connector.", "") - connector_weights[new_key] = value - else: - # Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix - for key, value in transformer_weights.items(): - if key.startswith("model.diffusion_model.video_embeddings_connector."): - new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "") - connector_weights[new_key] = value - - if connector_weights: - # Map weight names to our structure (only needed for monolithic/raw PyTorch keys) - mapped_weights = {} - for key, value in connector_weights.items(): - new_key = key - if not is_reformatted: - # Map ff.net.0.proj -> ff.proj_in (GEGLU projection) - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - # Map ff.net.2 -> ff.proj_out (output Linear) - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - # Map to_out.0 -> to_out (Sequential -> direct) - new_key = new_key.replace(".to_out.0.", ".to_out.") - mapped_weights[new_key] = value - - self.video_embeddings_connector.load_weights( - list(mapped_weights.items()), strict=False - ) - - # Manually load learnable_registers (it's a plain mx.array, not a parameter) - if "learnable_registers" in connector_weights: - self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"] - - # Load audio_embeddings_connector weights (same structure as video connector) - audio_connector_weights = {} - if is_reformatted: - for key, value in transformer_weights.items(): - if key.startswith("audio_embeddings_connector."): - new_key = key.replace("audio_embeddings_connector.", "") - audio_connector_weights[new_key] = value - else: - for key, value in transformer_weights.items(): - if key.startswith("model.diffusion_model.audio_embeddings_connector."): - new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "") - audio_connector_weights[new_key] = value - - if audio_connector_weights: - # Map weight names to our structure (same as video connector) - mapped_audio_weights = {} - for key, value in audio_connector_weights.items(): - new_key = key - if not is_reformatted: - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - new_key = new_key.replace(".to_out.0.", ".to_out.") - mapped_audio_weights[new_key] = value - - self.audio_embeddings_connector.load_weights( - list(mapped_audio_weights.items()), strict=False - ) - - # Manually load learnable_registers (it's a plain mx.array, not a parameter) - if "learnable_registers" in audio_connector_weights: - self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"] + self._load_feature_extractors(transformer_weights, is_reformatted) + self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted) + self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted) else: print("WARNING: No transformer weights found for text projection connectors. " "Text conditioning will use uninitialized weights!") @@ -763,6 +800,63 @@ class LTX2TextEncoder(nn.Module): print("Text encoder loaded successfully") + def _load_feature_extractors(self, weights: dict, is_reformatted: bool): + """Load feature extractor weights for both LTX-2 and LTX-2.3.""" + if self.has_prompt_adaln: + # LTX-2.3: V2 feature extractor with separate video/audio aggregate embeds + for attr, prefix in [ + ("video_aggregate_embed", "video_aggregate_embed"), + ("audio_aggregate_embed", "audio_aggregate_embed"), + ]: + w_key = f"{prefix}.weight" + b_key = f"{prefix}.bias" + if w_key in weights: + submodule = getattr(self.feature_extractor_v2, attr) + submodule.weight = weights[w_key] + if b_key in weights: + submodule.bias = weights[b_key] + else: + # LTX-2: single aggregate_embed + agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" + if agg_key in weights: + self.feature_extractor.aggregate_embed.weight = weights[agg_key] + + def _load_connector(self, name: str, weights: dict, is_reformatted: bool): + """Load a connector's weights (video or audio).""" + connector = getattr(self, name) + + # Extract connector-specific weights + connector_weights = {} + if is_reformatted: + prefix = f"{name}." + for key, value in weights.items(): + if key.startswith(prefix): + connector_weights[key[len(prefix):]] = value + else: + mono_prefix = f"model.diffusion_model.{name}." + for key, value in weights.items(): + if key.startswith(mono_prefix): + connector_weights[key[len(mono_prefix):]] = value + + if not connector_weights: + return + + # Sanitize key names (only needed for monolithic/raw PyTorch keys) + mapped = {} + for key, value in connector_weights.items(): + new_key = key + if not is_reformatted: + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + new_key = new_key.replace(".to_out.0.", ".to_out.") + mapped[new_key] = value + + connector.load_weights(list(mapped.items()), strict=False) + + # Manually load learnable_registers (plain mx.array, not tracked as parameter) + if "learnable_registers" in connector_weights: + connector.learnable_registers = connector_weights["learnable_registers"] + def encode( self, prompt: str, @@ -795,21 +889,40 @@ class LTX2TextEncoder(nn.Module): attention_mask = mx.array(inputs["attention_mask"]) _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) - concat_hidden = norm_and_concat_hidden_states( - all_hidden_states, attention_mask, padding_side="left" - ) - features = self.feature_extractor(concat_hidden) - additive_mask = (attention_mask - 1).astype(features.dtype) - additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 - video_embeddings, _ = self.video_embeddings_connector(features, additive_mask) + if self.has_prompt_adaln: + # LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale) + video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video") + additive_mask = (attention_mask - 1).astype(video_features.dtype) + additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 - if return_audio_embeddings: - # Process features through audio connector independently (same input as video) - audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask) - return video_embeddings, audio_embeddings + video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) + + if return_audio_embeddings: + audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio") + audio_mask = (attention_mask - 1).astype(audio_features.dtype) + audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask) + return video_embeddings, audio_embeddings + else: + return video_embeddings, attention_mask else: - return video_embeddings, attention_mask + # LTX-2: V1 feature extraction (8 * (x - mean) / range) + concat_hidden = norm_and_concat_hidden_states( + all_hidden_states, attention_mask, padding_side="left" + ) + + video_features = self.feature_extractor(concat_hidden) + additive_mask = (attention_mask - 1).astype(video_features.dtype) + additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 + + video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask) + + if return_audio_embeddings: + audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask) + return video_embeddings, audio_embeddings + else: + return video_embeddings, attention_mask def __call__( self, diff --git a/mlx_video/models/ltx/transformer.py b/mlx_video/models/ltx/transformer.py index 5a60989..4b311e6 100644 --- a/mlx_video/models/ltx/transformer.py +++ b/mlx_video/models/ltx/transformer.py @@ -20,20 +20,25 @@ class Modality: context_mask: Optional[mx.array] = None # Optional precomputed positional embeddings (RoPE) to avoid recomputation positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None + # Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3) + sigma: Optional[mx.array] = None @dataclass(frozen=True) class TransformerArgs: - x: mx.array - context: mx.array - context_mask: Optional[mx.array] - timesteps: mx.array - embedded_timestep: mx.array - positional_embeddings: Tuple[mx.array, mx.array] - cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]] - cross_scale_shift_timestep: Optional[mx.array] - cross_gate_timestep: Optional[mx.array] + x: mx.array + context: mx.array + context_mask: Optional[mx.array] + timesteps: mx.array + embedded_timestep: mx.array + positional_embeddings: Tuple[mx.array, mx.array] + cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]] + cross_scale_shift_timestep: Optional[mx.array] + cross_gate_timestep: Optional[mx.array] enabled: bool + # LTX-2.3: prompt-conditioned timestep embeddings for cross-attention + prompt_timesteps: Optional[mx.array] = None + prompt_embedded_timestep: Optional[mx.array] = None class BasicAVTransformerBlock(nn.Module): @@ -50,20 +55,13 @@ class BasicAVTransformerBlock(nn.Module): audio: Optional[TransformerConfig] = None, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, norm_eps: float = 1e-6, + has_prompt_adaln: bool = False, ): - """Initialize transformer block. - - Args: - idx: Block index - video: Video modality configuration - audio: Audio modality configuration - rope_type: Type of rotary position embedding - norm_eps: Epsilon for normalization - """ super().__init__() self.idx = idx self.norm_eps = norm_eps + self.has_prompt_adaln = has_prompt_adaln # Video components if video is not None: @@ -74,6 +72,7 @@ class BasicAVTransformerBlock(nn.Module): context_dim=None, # Self-attention rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.attn2 = Attention( query_dim=video.dim, @@ -82,10 +81,15 @@ class BasicAVTransformerBlock(nn.Module): dim_head=video.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.ff = FeedForward(video.dim, dim_out=video.dim) - # 6 scale-shift parameters: 3 for attention, 3 for MLP - self.scale_shift_table = mx.zeros((6, video.dim)) + # 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2 + num_ada_params = 9 if has_prompt_adaln else 6 + self.scale_shift_table = mx.zeros((num_ada_params, video.dim)) + + if has_prompt_adaln: + self.prompt_scale_shift_table = mx.zeros((2, video.dim)) # Audio components if audio is not None: @@ -96,6 +100,7 @@ class BasicAVTransformerBlock(nn.Module): context_dim=None, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.audio_attn2 = Attention( query_dim=audio.dim, @@ -104,9 +109,14 @@ class BasicAVTransformerBlock(nn.Module): dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) - self.audio_scale_shift_table = mx.zeros((6, audio.dim)) + num_audio_ada_params = 9 if has_prompt_adaln else 6 + self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim)) + + if has_prompt_adaln: + self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim)) # Cross-modal attention (when both video and audio are enabled) if audio is not None and video is not None: @@ -118,6 +128,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) # Video-to-Audio: Q from audio, K/V from video self.video_to_audio_attn = Attention( @@ -127,6 +138,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=audio.d_head, rope_type=rope_type, norm_eps=norm_eps, + has_gate_logits=has_prompt_adaln, ) # Scale-shift tables for cross-attention self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim)) @@ -254,11 +266,23 @@ class BasicAVTransformerBlock(nn.Module): vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa # Cross-attention with text context - vx = vx + self.attn2( - rms_norm(vx, eps=self.norm_eps), - context=video.context, - mask=video.context_mask, - ) + if self.has_prompt_adaln: + # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln + vshift_q, vscale_q, vgate_q = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9) + ) + vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values( + self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2) + ) + attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q + encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv + vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q + else: + vx = vx + self.attn2( + rms_norm(vx, eps=self.norm_eps), + context=video.context, + mask=video.context_mask, + ) # Process audio self-attention and cross-attention with text if run_ax: @@ -271,11 +295,23 @@ class BasicAVTransformerBlock(nn.Module): ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa # Cross-attention with text context - ax = ax + self.audio_attn2( - rms_norm(ax, eps=self.norm_eps), - context=audio.context, - mask=audio.context_mask, - ) + if self.has_prompt_adaln: + # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln + ashift_q, ascale_q, agate_q = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9) + ) + aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values( + self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2) + ) + attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q + encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv + ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q + else: + ax = ax + self.audio_attn2( + rms_norm(ax, eps=self.norm_eps), + context=audio.context, + mask=audio.context_mask, + ) # Audio-Video cross-modal attention if run_a2v or run_v2a: @@ -341,7 +377,7 @@ class BasicAVTransformerBlock(nn.Module): # Process video feed-forward if run_vx: vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( - self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None) + self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6) ) vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp vx = vx + self.ff(vx_scaled) * vgate_mlp @@ -349,7 +385,7 @@ class BasicAVTransformerBlock(nn.Module): # Process audio feed-forward if run_ax: ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( - self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None) + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6) ) ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp ax = ax + self.audio_ff(ax_scaled) * agate_mlp diff --git a/mlx_video/models/ltx/upsampler.py b/mlx_video/models/ltx/upsampler.py index 7f43536..1180664 100644 --- a/mlx_video/models/ltx/upsampler.py +++ b/mlx_video/models/ltx/upsampler.py @@ -301,15 +301,14 @@ def upsample_latents( latent_std: mx.array, debug: bool = False, ) -> mx.array: - # Un-normalize: latent * std + mean latent_mean = latent_mean.reshape(1, -1, 1, 1, 1) latent_std = latent_std.reshape(1, -1, 1, 1, 1) latent = latent * latent_std + latent_mean - + # Upsample latent = upsampler(latent, debug=debug) - + # Re-normalize: (latent - mean) / std latent = (latent - latent_mean) / latent_std @@ -350,19 +349,18 @@ def load_upsampler(weights_path: str) -> LatentUpsampler: for key, value in raw_weights.items(): new_key = key + # LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.* + if key.startswith("upsampler.0."): + new_key = key.replace("upsampler.0.", "upsampler.conv.") + # Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I) - if "conv" in key and "weight" in key and value.ndim == 5: + if "weight" in new_key and value.ndim == 5: value = mx.transpose(value, (0, 2, 3, 4, 1)) # Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I) - if "conv" in key and "weight" in key and value.ndim == 4: + if "weight" in new_key and value.ndim == 4: value = mx.transpose(value, (0, 2, 3, 1)) - # Map upsampler.conv to upsampler.conv (SpatialRationalResampler) - # Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel - if key.startswith("upsampler."): - new_key = key # Keep as is for SpatialRationalResampler - sanitized[new_key] = value # Load weights diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 5f45d8a..105082c 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -250,6 +250,18 @@ class LTX2VideoDecoder(nn.Module): - conv_out: 128 -> 48 (3 * 4^2 for patch_size=4) """ + # Block definitions: ("res", channels, num_layers) or ("d2s", in_channels, reduction, stride) + # stride is (D, H, W) tuple + DEFAULT_BLOCKS = [ + ("res", 1024, 5), + ("d2s", 1024, 2, (2, 2, 2)), + ("res", 512, 5), + ("d2s", 512, 2, (2, 2, 2)), + ("res", 256, 5), + ("d2s", 256, 2, (2, 2, 2)), + ("res", 128, 5), + ] + def __init__( self, in_channels: int = 128, @@ -258,6 +270,7 @@ class LTX2VideoDecoder(nn.Module): num_layers_per_block: int = 5, spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, timestep_conditioning: bool = True, + decoder_blocks: list = None, ): super().__init__() @@ -272,13 +285,17 @@ class LTX2VideoDecoder(nn.Module): # Per-channel statistics for denormalization (loaded from weights) self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) - # Initial conv: 128 -> 1024 + blocks = decoder_blocks or self.DEFAULT_BLOCKS + first_ch = blocks[0][1] + last_ch = blocks[-1][1] + + # Initial conv: in_channels -> first block channels class ConvInWrapper(nn.Module): def __init__(self_inner): super().__init__() self_inner.conv = CausalConv3d( in_channels=in_channels, - out_channels=1024, + out_channels=first_ch, kernel_size=3, stride=1, padding=1, @@ -288,45 +305,32 @@ class LTX2VideoDecoder(nn.Module): return self_inner.conv(x, causal=causal) self.conv_in = ConvInWrapper() - # Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample - # Use dict with int keys for MLX to track parameters properly - self.up_blocks = { - 0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - 1: DepthToSpaceUpsample( - dims=3, - in_channels=1024, - stride=(2, 2, 2), - residual=True, - out_channels_reduction_factor=2, - spatial_padding_mode=spatial_padding_mode, - ), - 2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - 3: DepthToSpaceUpsample( - dims=3, - in_channels=512, - stride=(2, 2, 2), - residual=True, - out_channels_reduction_factor=2, - spatial_padding_mode=spatial_padding_mode, - ), - 4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - 5: DepthToSpaceUpsample( - dims=3, - in_channels=256, - stride=(2, 2, 2), - residual=True, - out_channels_reduction_factor=2, - spatial_padding_mode=spatial_padding_mode, - ), - 6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - } + # Build up blocks from config + self.up_blocks = {} + for idx, block_def in enumerate(blocks): + block_type = block_def[0] + ch = block_def[1] + if block_type == "res": + num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block + self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning) + elif block_type == "d2s": + reduction = block_def[2] if len(block_def) > 2 else 2 + stride = block_def[3] if len(block_def) > 3 else (2, 2, 2) + self.up_blocks[idx] = DepthToSpaceUpsample( + dims=3, + in_channels=ch, + stride=stride, + residual=True, + out_channels_reduction_factor=reduction, + spatial_padding_mode=spatial_padding_mode, + ) final_out_channels = out_channels * patch_size * patch_size class ConvOutWrapper(nn.Module): def __init__(self_inner): super().__init__() self_inner.conv = CausalConv3d( - in_channels=128, + in_channels=last_ch, out_channels=final_out_channels, kernel_size=3, stride=1, @@ -342,9 +346,9 @@ class LTX2VideoDecoder(nn.Module): if timestep_conditioning: self.timestep_scale_multiplier = mx.array(1000.0) self.last_time_embedder = PixArtAlphaTimestepEmbedder( - embedding_dim=128 * 2 # 256, matches (2, 128) table + embedding_dim=last_ch * 2 ) - self.last_scale_shift_table = mx.zeros((2, 128)) + self.last_scale_shift_table = mx.zeros((2, last_ch)) def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: # Build decoder weights dict with key remapping @@ -418,11 +422,96 @@ class LTX2VideoDecoder(nn.Module): weights.update(mx.load(str(wf))) - model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False)) + # Infer block structure from weights + decoder_blocks = cls._infer_blocks(weights) + + model = cls( + timestep_conditioning=config_dict.get("timestep_conditioning", False), + decoder_blocks=decoder_blocks, + ) weights = model.sanitize(weights) model.load_weights(list(weights.items()), strict=strict) return model + @staticmethod + def _infer_blocks(weights: dict) -> list: + """Infer decoder block structure from weight keys.""" + block_indices = set() + for k in weights: + if "up_blocks." in k: + idx_str = k.split("up_blocks.")[1].split(".")[0] + if idx_str.isdigit(): + block_indices.add(int(idx_str)) + + if not block_indices: + return None + + # First pass: collect block info + raw_blocks = [] + for idx in sorted(block_indices): + has_conv = any(f"up_blocks.{idx}.conv." in k for k in weights) + res_indices = set() + for k in weights: + prefix = f"up_blocks.{idx}.res_blocks." + if prefix in k: + res_idx = k.split(prefix)[1].split(".")[0] + if res_idx.isdigit(): + res_indices.add(int(res_idx)) + + if has_conv and not res_indices: + # D2S block - get conv shape + for k, v in weights.items(): + if f"up_blocks.{idx}.conv." in k and "weight" in k: + in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1] + conv_out_ch = v.shape[0] + raw_blocks.append(("d2s", in_ch, conv_out_ch)) + break + elif res_indices: + num_res = max(res_indices) + 1 + for k, v in weights.items(): + if f"up_blocks.{idx}.res_blocks.0.conv1" in k and "weight" in k: + ch = v.shape[0] + raw_blocks.append(("res", ch, num_res)) + break + + # Second pass: determine d2s strides using the channel progression + # For each d2s block, the next res block tells us the expected output channels + blocks = [] + for i, block in enumerate(raw_blocks): + if block[0] == "res": + blocks.append(block) + elif block[0] == "d2s": + in_ch, conv_out_ch = block[1], block[2] + # Find next res block's channels + next_ch = None + for j in range(i + 1, len(raw_blocks)): + if raw_blocks[j][0] == "res": + next_ch = raw_blocks[j][1] + break + + if next_ch is None: + next_ch = in_ch // 2 # fallback + + # out_ch = in_ch // reduction + reduction = in_ch // next_ch if next_ch > 0 else 2 + + # conv_out = next_ch * multiplier → multiplier = conv_out / next_ch + multiplier = conv_out_ch // next_ch if next_ch > 0 else 8 + + # Determine stride from multiplier + if multiplier == 8: + stride = (2, 2, 2) + elif multiplier == 4: + stride = (1, 2, 2) + elif multiplier == 2: + stride = (2, 1, 1) + else: + stride = (2, 2, 2) + + blocks.append(("d2s", in_ch, reduction, stride)) + + return blocks if blocks else None + def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: