Files
mlx-video/mlx_video/models/wan/loading.py

184 lines
5.7 KiB
Python

"""Wan model loading utilities."""
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
    """Build a WanModel and populate it from a safetensors checkpoint.

    Args:
        model_path: Path to the model safetensors file.
        config: WanModelConfig used to construct the model.
        quantization: Optional dict with 'bits' and 'group_size' keys. When
            given, the freshly built model is quantized in place before the
            checkpoint is loaded, so the quantized weights have somewhere
            to land.
        loras: Optional list of (lora_path, strength) tuples to apply.

    Returns:
        The loaded (and, if requested, LoRA-merged) model, with its
        parameters already evaluated.
    """
    from mlx_video.models.wan.model import WanModel

    model = WanModel(config)

    if quantization:
        from mlx_video.convert_wan import _quantize_predicate

        # Swap eligible layers for quantized stubs before any weights arrive.
        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=_quantize_predicate,
        )

    checkpoint = mx.load(str(model_path))

    if loras and quantization:
        # Quantized path: the checkpoint must be loaded first, then the LoRA
        # deltas are merged into the live model. Per the original author's
        # note, LoRA-targeted layers are dequantized and replaced with bf16
        # Linear layers while the remaining layers stay quantized, so there
        # is no per-step overhead.
        from mlx_video.convert_wan import _load_lora_configs
        from mlx_video.lora import apply_loras_to_model

        model.load_weights(list(checkpoint.items()), strict=False)
        mx.eval(model.parameters())
        apply_loras_to_model(model, _load_lora_configs(loras))
        mx.eval(model.parameters())
        return model

    if loras:
        # Unquantized path: fold the LoRA deltas into the raw weights before
        # they are handed to the model.
        from mlx_video.convert_wan import load_and_apply_loras

        checkpoint = load_and_apply_loras(dict(checkpoint), loras)

    model.load_weights(list(checkpoint.items()), strict=False)
    mx.eval(model.parameters())
    return model
def load_t5_encoder(model_path: Path, config):
    """Construct the T5 text encoder and load its weights in float32.

    Every checkpoint tensor is upcast to float32 for maximum precision. The
    encoder only runs once per generation, so the performance impact is
    negligible. This matches the official implementation, which computes
    softmax in float32 explicitly.

    Args:
        model_path: Path to the T5 encoder weights file.
        config: Object exposing the ``t5_*`` hyperparameters read below.

    Returns:
        The T5Encoder with evaluated float32 parameters.
    """
    from mlx_video.models.wan.text_encoder import T5Encoder

    encoder_kwargs = {
        "vocab_size": config.t5_vocab_size,
        "dim": config.t5_dim,
        "dim_attn": config.t5_dim_attn,
        "dim_ffn": config.t5_dim_ffn,
        "num_heads": config.t5_num_heads,
        "num_layers": config.t5_num_layers,
        "num_buckets": config.t5_num_buckets,
        "shared_pos": False,
    }
    encoder = T5Encoder(**encoder_kwargs)

    # Upcast while building the (name, array) pairs that load_weights expects.
    fp32_weights = [
        (name, tensor.astype(mx.float32))
        for name, tensor in mx.load(str(model_path)).items()
    ]
    encoder.load_weights(fp32_weights)
    mx.eval(encoder.parameters())
    return encoder
def load_vae_decoder(model_path: Path, config=None):
    """Load the VAE used for decoding latents.

    The checkpoint is loaded with ``strict=False`` so that weights the chosen
    class does not own (e.g. encoder weights) are skipped.

    Model selection:
        * ``config.vae_z_dim == 48`` (Wan2.2)  -> Wan22VAEDecoder
        * anything else, including ``config=None`` (Wan2.1) -> WanVAE with
          ``z_dim=16``

    Args:
        model_path: Path to the VAE weights file.
        config: Optional config object exposing ``vae_z_dim``.

    Returns:
        The VAE module with evaluated float32 parameters.
    """
    if config is not None and config.vae_z_dim == 48:
        from mlx_video.models.wan.vae22 import Wan22VAEDecoder

        vae = Wan22VAEDecoder(z_dim=48)
    else:
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16)

    # Upcast to float32 for quality; the official Wan2.2 runs its VAE in float32.
    fp32_weights = [
        (name, tensor.astype(mx.float32))
        for name, tensor in mx.load(str(model_path)).items()
    ]
    vae.load_weights(fp32_weights, strict=False)
    mx.eval(vae.parameters())
    return vae
def load_vae_encoder(model_path: Path, config=None):
    """Load the VAE encoder used for I2V image encoding.

    Model selection:
        * ``config.vae_z_dim == 16`` (Wan2.1 / I2V-14B) -> WanVAE with
          ``encoder=True``
        * anything else (Wan2.2 TI2V) -> Wan22VAEEncoder, using
          ``config.vae_z_dim`` when a config is supplied and 48 otherwise

    NOTE(review): with ``config=None`` this function picks the Wan2.2 class,
    whereas load_vae_decoder picks the Wan2.1 class. Confirm this asymmetry
    is intended.

    Args:
        model_path: Path to the VAE weights file.
        config: Optional config object exposing ``vae_z_dim``.

    Returns:
        The encoder module with evaluated float32 parameters.
    """
    use_wan21 = config is not None and config.vae_z_dim == 16

    if not use_wan21:
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder

        latent_channels = config.vae_z_dim if config else 48
        vae = Wan22VAEEncoder(z_dim=latent_channels)
    else:
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16, encoder=True)

    # Weights are upcast to float32 before loading; strict=False tolerates
    # checkpoint entries that the chosen class does not define.
    fp32_weights = [
        (name, tensor.astype(mx.float32))
        for name, tensor in mx.load(str(model_path)).items()
    ]
    vae.load_weights(fp32_weights, strict=False)
    mx.eval(vae.parameters())
    return vae
def _clean_text(text: str) -> str:
"""Clean text matching official Wan2.2 tokenizer preprocessing.
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
double HTML unescape, and whitespace normalization. Critical for
correct tokenization of the Chinese negative prompt.
"""
import html
import re
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text).strip()
return text
def encode_text(
    encoder,
    tokenizer,
    prompt: str,
    text_len: int = 512,
) -> mx.array:
    """Encode a text prompt with the T5 encoder.

    The prompt is cleaned, tokenized with padding/truncation to ``text_len``,
    run through the encoder together with its attention mask, and the padding
    positions are then dropped from the result.

    Args:
        encoder: T5Encoder model
        tokenizer: HuggingFace tokenizer
        prompt: Text prompt
        text_len: Maximum text length

    Returns:
        Text embeddings [L, dim]
    """
    tokenizer_options = {
        "max_length": text_len,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "np",
    }
    encoded = tokenizer(_clean_text(prompt), **tokenizer_options)

    input_ids = mx.array(encoded["input_ids"])
    attention_mask = mx.array(encoded["attention_mask"])
    hidden = encoder(input_ids, mask=attention_mask)

    # The mask sum is the number of non-padding tokens; keep only that
    # leading slice of the first (and only) batch entry.
    num_real_tokens = int(attention_mask.sum().item())
    return hidden[0, :num_real_tokens]