Refactor LTXModel to include a from_pretrained class method for loading and sanitizing model weights. Update generate.py to utilize this method, streamlining the transformer loading process and improving code clarity.
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from pathlib import Path
|
||||
from mlx_video.models.ltx.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
@@ -497,19 +497,50 @@ class LTXModel(nn.Module):
|
||||
|
||||
def sanitize(self, weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
||||
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
new_key = new_key.replace("model.diffusion_model.", "")
|
||||
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
# Handle common remappings
|
||||
# transformer_blocks.X -> transformer_blocks[X]
|
||||
if "transformer_blocks." in new_key:
|
||||
# Keep as-is for now, MLX handles this
|
||||
pass
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: [Path, List[Path]], config: LTXModelConfig, strict: bool = True) -> None:
|
||||
model = cls(config)
|
||||
|
||||
weights = {}
|
||||
if isinstance(model_path, Path):
|
||||
model_path = [model_path]
|
||||
for weight_file in model_path:
|
||||
weights.update(mx.load(str(weight_file)))
|
||||
|
||||
|
||||
sanitized = model.sanitize(weights)
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
model.load_weights(list(sanitized.items()), strict=strict)
|
||||
mx.eval(model.parameters())
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class X0Model(nn.Module):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user