Refactor LTXModel to add a `from_pretrained` class method that loads and sanitizes model weights. Update generate.py to use this method, which streamlines transformer loading and improves code clarity.

This commit is contained in:
Prince Canuma
2026-01-20 12:56:29 +01:00
parent bbb3de6aa7
commit 2681f75d2f
2 changed files with 42 additions and 13 deletions

View File

@@ -24,7 +24,7 @@ console = Console()
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
@@ -891,9 +891,7 @@ def generate_video(
# Load transformer
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
raw_weights = mx.load(str(model_path / weight_file))
sanitized = sanitize_transformer_weights(raw_weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
@@ -925,9 +923,9 @@ def generate_video(
)
config = LTXModelConfig(**config_kwargs)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
transformer = LTXModel.from_pretrained(model_path=model_path/weight_file, config=config, strict=True)
console.print("[green]✓[/] Transformer loaded")
# ==========================================================================