Refactor LTXModel to include a from_pretrained class method for loading and sanitizing model weights. Update generate.py to utilize this method, streamlining the transformer loading process and improving code clarity.

This commit is contained in:
Prince Canuma
2026-01-20 12:56:29 +01:00
parent bbb3de6aa7
commit 2681f75d2f
2 changed files with 42 additions and 13 deletions

View File

@@ -24,7 +24,7 @@ console = Console()
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
@@ -891,9 +891,7 @@ def generate_video(
# Load transformer
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
raw_weights = mx.load(str(model_path / weight_file))
sanitized = sanitize_transformer_weights(raw_weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
@@ -925,9 +923,9 @@ def generate_video(
)
config = LTXModelConfig(**config_kwargs)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
transformer = LTXModel.from_pretrained(model_path=model_path/weight_file, config=config, strict=True)
console.print("[green]✓[/] Transformer loaded")
# ==========================================================================

View File

@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx.config import (
LTXModelConfig,
LTXModelType,
@@ -497,19 +497,50 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict:
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue
# Remove 'model.diffusion_model.' prefix
new_key = new_key.replace("model.diffusion_model.", "")
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle common remappings
# transformer_blocks.X -> transformer_blocks[X]
if "transformer_blocks." in new_key:
# Keep as-is for now, MLX handles this
pass
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: [Path, List[Path]], config: LTXModelConfig, strict: bool = True) -> None:
model = cls(config)
weights = {}
if isinstance(model_path, Path):
model_path = [model_path]
for weight_file in model_path:
weights.update(mx.load(str(weight_file)))
sanitized = model.sanitize(weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
model.load_weights(list(sanitized.items()), strict=strict)
mx.eval(model.parameters())
model.eval()
return model
class X0Model(nn.Module):