From f8f78aeab55e348bc40d8db07b717b7cb1c9db31 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 23 Jan 2026 17:45:50 +0100 Subject: [PATCH] Add LTXModel with a from_pretrained class method for loading model weights from a specified path. Update weight sanitization to handle positional embeddings and dtype consistency. Refactor timestep and context preparation methods to accept hidden_dtype, improving flexibility in model processing. --- mlx_video/models/ltx/ltx.py | 95 +++++++++++++++++++++++++++++-------- 1 file changed, 74 insertions(+), 21 deletions(-) diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index e89f140..b130665 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn - +from pathlib import Path from mlx_video.models.ltx.config import ( LTXModelConfig, LTXModelType, @@ -52,10 +52,11 @@ class TransformerArgsPreprocessor: self, timestep: mx.array, batch_size: int, + hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * self.timestep_scale_multiplier - timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1)) + timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) # Reshape to (batch, tokens, dim) timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) @@ -70,6 +71,9 @@ class TransformerArgsPreprocessor: attention_mask: Optional[mx.array] = None, ) -> Tuple[mx.array, Optional[mx.array]]: batch_size = x.shape[0] + + # Context is already processed through embeddings connector in text encoder + # Here we just apply the caption projection context = self.caption_projection(context) context = mx.reshape(context, (batch_size, -1, x.shape[-1])) return context, attention_mask @@ -114,16 +118,21 @@ class TransformerArgsPreprocessor: def prepare(self, modality: Modality) -> TransformerArgs: x = self.patchify_proj(modality.latent) - timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0]) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) - pe = self._prepare_positional_embeddings( - positions=modality.positions, - inner_dim=self.inner_dim, - max_pos=self.max_pos, - use_middle_indices_grid=self.use_middle_indices_grid, - num_attention_heads=self.num_attention_heads, - ) + + # Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation) + if modality.positional_embeddings is not None: + pe = modality.positional_embeddings + else: + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) return TransformerArgs( x=x, @@ -198,6 +207,7 @@ class MultiModalTransformerArgsPreprocessor: timestep=modality.timesteps, timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, batch_size=transformer_args.x.shape[0], + hidden_dtype=transformer_args.x.dtype, ) return replace( @@ -212,15 +222,16 @@ class MultiModalTransformerArgsPreprocessor: timestep: mx.array, timestep_scale_multiplier: int, batch_size: int, + hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier - scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1)) + scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) - gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor) + gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype) gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) return scale_shift_timestep, gate_timestep @@ -282,6 +293,8 @@ class LTXModel(nn.Module): def _init_audio(self, config: LTXModelConfig) -> None: self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim) + + # Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector self.audio_caption_projection = PixArtAlphaTextProjection( in_features=config.audio_caption_channels, hidden_size=self.audio_inner_dim, @@ -384,8 +397,9 @@ class LTXModel(nn.Module): video_config = config.get_video_config() audio_config = config.get_audio_config() - self.transformer_blocks = [ - BasicAVTransformerBlock( + + self.transformer_blocks = { + idx: BasicAVTransformerBlock( idx=idx, video=video_config, audio=audio_config, @@ -393,7 +407,7 @@ class LTXModel(nn.Module): norm_eps=config.norm_eps, ) for idx in range(config.num_layers) - ] + } def _process_transformer_blocks( self, @@ -401,7 +415,7 @@ class LTXModel(nn.Module): audio: Optional[TransformerArgs], ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: """Process through all transformer blocks.""" - for block in self.transformer_blocks: + for block in self.transformer_blocks.values(): video, audio = block(video=video, audio=audio) return video, audio @@ -483,19 +497,58 @@ class LTXModel(nn.Module): def sanitize(self, weights: dict) -> dict: sanitized = {} + + if "model.diffusion_model." not in weights: + return weights + for key, value in weights.items(): new_key = key + # Skip non-transformer weights (VAE, vocoder, audio_vae, connectors) + if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key: + continue + + # Remove 'model.diffusion_model.' prefix + new_key = new_key.replace("model.diffusion_model.", "") + + new_key = new_key.replace(".to_out.0.", ".to_out.") + + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") + new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") + + new_key = new_key.replace(".linear_1.", ".linear1.") + new_key = new_key.replace(".linear_2.", ".linear2.") - # Handle common remappings - # transformer_blocks.X -> transformer_blocks[X] - if "transformer_blocks." in new_key: - # Keep as-is for now, MLX handles this - pass sanitized[new_key] = value return sanitized + @classmethod + def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTXModel": + import json + + config_dict = {} + with open(model_path / "config.json", "r") as f: + config_dict = json.load(f) + config = LTXModelConfig(**config_dict) + model = cls(config) + + weights = {} + + for weight_file in model_path.glob("*.safetensors"): + weights.update(mx.load(str(weight_file))) + + + sanitized = model.sanitize(weights) + sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} + + model.load_weights(list(sanitized.items()), strict=strict) + mx.eval(model.parameters()) + model.eval() + return model + class X0Model(nn.Module):