Refactor LTXModel to include a from_pretrained class method for loading and sanitizing model weights. Update generate.py to utilize this method, streamlining the transformer loading process and improving code clarity.
This commit is contained in:
@@ -24,7 +24,7 @@ console = Console()
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
from mlx_video.convert import sanitize_transformer_weights
|
||||
|
||||
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||
@@ -891,9 +891,7 @@ def generate_video(
|
||||
# Load transformer
|
||||
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
|
||||
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
|
||||
raw_weights = mx.load(str(model_path / weight_file))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
|
||||
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
|
||||
|
||||
@@ -925,9 +923,9 @@ def generate_video(
|
||||
)
|
||||
|
||||
config = LTXModelConfig(**config_kwargs)
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
|
||||
transformer = LTXModel.from_pretrained(model_path=model_path/weight_file, config=config, strict=True)
|
||||
|
||||
console.print("[green]✓[/] Transformer loaded")
|
||||
|
||||
# ==========================================================================
|
||||
|
||||
Reference in New Issue
Block a user