From 2681f75d2f77204c890b6dede13f8b3abe096835 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Tue, 20 Jan 2026 12:56:29 +0100
Subject: [PATCH] Refactor LTXModel to include a from_pretrained class method
 for loading and sanitizing model weights. Update generate.py to utilize this
 method, streamlining the transformer loading process and improving code
 clarity.

---
 mlx_video/generate.py       | 12 ++---
 mlx_video/models/ltx/ltx.py | 85 ++++++++++++++++++++++++++++++++++---
 2 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/mlx_video/generate.py b/mlx_video/generate.py
index 1ac4508..8c99153 100644
--- a/mlx_video/generate.py
+++ b/mlx_video/generate.py
@@ -24,7 +24,7 @@ console = Console()
 from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
 from mlx_video.models.ltx.ltx import LTXModel
 from mlx_video.models.ltx.transformer import Modality
-from mlx_video.convert import sanitize_transformer_weights
+
 from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
 from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
 from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
@@ -891,9 +891,7 @@ def generate_video(
     # Load transformer
     transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
     with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
-        raw_weights = mx.load(str(model_path / weight_file))
-        sanitized = sanitize_transformer_weights(raw_weights)
-        sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
+
 
         model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
 
@@ -925,9 +923,9 @@ def generate_video(
         )
 
         config = LTXModelConfig(**config_kwargs)
-        transformer = LTXModel(config)
-        transformer.load_weights(list(sanitized.items()), strict=False)
-        mx.eval(transformer.parameters())
+
+        transformer = LTXModel.from_pretrained(model_path=model_path / weight_file, config=config, strict=True)
+
     console.print("[green]✓[/] Transformer loaded")
 
     # ==========================================================================
diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py
index c7c51a2..f083485 100644
--- a/mlx_video/models/ltx/ltx.py
+++ b/mlx_video/models/ltx/ltx.py
@@ -1,8 +1,9 @@
-from typing import List, Optional, Tuple
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
 from mlx_video.models.ltx.config import (
     LTXModelConfig,
     LTXModelType,
@@ -497,18 +498,88 @@ class LTXModel(nn.Module):
 
     def sanitize(self, weights: dict) -> dict:
+        """Rename raw checkpoint keys to this module's parameter names.
+
+        Keys that do not start with ``model.diffusion_model.`` are dropped,
+        as are keys containing ``audio_embeddings_connector`` or
+        ``video_embeddings_connector``. For the rest, the prefix is stripped
+        and the layer-name substitutions below are applied.
+
+        Args:
+            weights: Mapping from checkpoint key to array.
+
+        Returns:
+            A new dict with renamed keys; values are passed through as-is.
+        """
+        prefix = "model.diffusion_model."
+        connectors = ("audio_embeddings_connector", "video_embeddings_connector")
         sanitized = {}
+
         for key, value in weights.items():
-            new_key = key
+            # Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
+            if not key.startswith(prefix) or any(name in key for name in connectors):
+                continue
 
-            # Handle common remappings
-            # transformer_blocks.X -> transformer_blocks[X]
-            if "transformer_blocks." in new_key:
-                # Keep as-is for now, MLX handles this
-                pass
+            # Strip only the leading prefix; str.replace() would also delete
+            # any later occurrence of the same text inside the key.
+            new_key = key[len(prefix):]
+            new_key = new_key.replace(".to_out.0.", ".to_out.")
+            new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
+            new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
+            new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
+            new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
+            new_key = new_key.replace(".linear_1.", ".linear1.")
+            new_key = new_key.replace(".linear_2.", ".linear2.")
 
             sanitized[new_key] = value
 
         return sanitized
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: Union[str, Path, List[Union[str, Path]]],
+        config: LTXModelConfig,
+        strict: bool = True,
+    ) -> "LTXModel":
+        """Build a model from ``config`` and load its weights from disk.
+
+        Args:
+            model_path: A single weight file, or a list of weight files whose
+                contents are merged in order (on duplicate keys the later
+                file wins).
+            config: Configuration passed to the class constructor.
+            strict: Forwarded to ``load_weights``.
+
+        Returns:
+            The constructed model after ``load_weights`` and ``model.eval()``.
+
+        Raises:
+            ValueError: If ``model_path`` is an empty list.
+        """
+        # A bare str is itself iterable, so wrap it as well; otherwise the
+        # loop below would walk it character by character.
+        if isinstance(model_path, (str, Path)):
+            model_path = [model_path]
+        if not model_path:
+            raise ValueError("model_path must name at least one weight file")
+
+        model = cls(config)
+
+        weights = {}
+        for weight_file in model_path:
+            weights.update(mx.load(str(weight_file)))
+
+        sanitized = model.sanitize(weights)
+        # Cast float32 tensors to bfloat16; every other dtype is left alone.
+        sanitized = {
+            k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
+            for k, v in sanitized.items()
+        }
+
+        model.load_weights(list(sanitized.items()), strict=strict)
+        mx.eval(model.parameters())
+        model.eval()
+        return model
+
 
 class X0Model(nn.Module):