feat(wan): Add LoRA with improved quantization pipeline

This commit is contained in:
Daniel
2026-02-28 14:11:13 +01:00
parent dbab95ec45
commit 849cc45d84
17 changed files with 1852 additions and 111 deletions

View File

@@ -82,15 +82,15 @@ python -m mlx_video.generate \
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline: Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline:
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | | | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|---|--------|--------|--------| |---|--------|--------|--------|--------|
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | | **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
| **Pipeline** | Single model | Dual model | Dual model | | **Pipeline** | Single model | Dual model | Dual model | Single model |
| **Sizes** | 1.3B, 14B | 14B | 14B | | **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
| **Steps** | 50 | 40 | 40 | | **Steps** | 50 | 40 | 40 | 40 |
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | | **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
| **Shift** | 5.0 | 12.0 | 5.0 | | **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | | **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
### Step 1: Download Weights ### Step 1: Download Weights

View File

@@ -1,13 +1,16 @@
"""Weight conversion for Wan2.2 models (PyTorch -> MLX).""" """Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
import gc
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.utils import mlx.utils
import numpy as np import numpy as np
logger = logging.getLogger(__name__)
def load_torch_weights(path: str) -> Dict[str, mx.array]: def load_torch_weights(path: str) -> Dict[str, mx.array]:
"""Load PyTorch .pth weights and convert to MLX arrays. """Load PyTorch .pth weights and convert to MLX arrays.
@@ -88,6 +91,7 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
etc. etc.
""" """
sanitized = {} sanitized = {}
consumed = set()
for key, value in weights.items(): for key, value in weights.items():
new_key = key new_key = key
@@ -99,36 +103,43 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
value = value.reshape(value.shape[0], -1) value = value.reshape(value.shape[0], -1)
new_key = "patch_embedding_proj.weight" new_key = "patch_embedding_proj.weight"
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
continue continue
if key == "patch_embedding.bias": if key == "patch_embedding.bias":
new_key = "patch_embedding_proj.bias" new_key = "patch_embedding_proj.bias"
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
continue continue
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear # Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
if key.startswith("text_embedding.0."): if key.startswith("text_embedding.0."):
new_key = key.replace("text_embedding.0.", "text_embedding_0.") new_key = key.replace("text_embedding.0.", "text_embedding_0.")
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
continue continue
if key.startswith("text_embedding.2."): if key.startswith("text_embedding.2."):
new_key = key.replace("text_embedding.2.", "text_embedding_1.") new_key = key.replace("text_embedding.2.", "text_embedding_1.")
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
continue continue
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear # Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
if key.startswith("time_embedding.0."): if key.startswith("time_embedding.0."):
new_key = key.replace("time_embedding.0.", "time_embedding_0.") new_key = key.replace("time_embedding.0.", "time_embedding_0.")
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
continue continue
if key.startswith("time_embedding.2."): if key.startswith("time_embedding.2."):
new_key = key.replace("time_embedding.2.", "time_embedding_1.") new_key = key.replace("time_embedding.2.", "time_embedding_1.")
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
continue continue
# Time projection Sequential: 0=SiLU(no params), 1=Linear # Time projection Sequential: 0=SiLU(no params), 1=Linear
if key.startswith("time_projection.1."): if key.startswith("time_projection.1."):
new_key = key.replace("time_projection.1.", "time_projection.") new_key = key.replace("time_projection.1.", "time_projection.")
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
continue continue
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2 # FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
@@ -137,9 +148,15 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
# Skip the freqs buffer (we compute it in the model) # Skip the freqs buffer (we compute it in the model)
if key == "freqs": if key == "freqs":
consumed.add(key)
continue continue
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
return sanitized return sanitized
@@ -171,6 +188,7 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
norm.weight norm.weight
""" """
sanitized = {} sanitized = {}
consumed = set()
for key, value in weights.items(): for key, value in weights.items():
new_key = key new_key = key
@@ -179,6 +197,11 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.") new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
return sanitized return sanitized
@@ -189,6 +212,7 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
Handles Conv3d and Conv2d weight transpositions for MLX format. Handles Conv3d and Conv2d weight transpositions for MLX format.
""" """
sanitized = {} sanitized = {}
consumed = set()
for key, value in weights.items(): for key, value in weights.items():
new_key = key new_key = key
@@ -206,10 +230,78 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
# Need to adapt naming for our simplified structure # Need to adapt naming for our simplified structure
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
return sanitized return sanitized
def _load_lora_configs(
lora_configs: List[Tuple[str, float]],
) -> Dict[str, list]:
"""Load LoRA weights from config tuples, returning module_to_loras dict.
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.lora import LoRAConfig, load_multiple_loras
from mlx_video.utils import Colors
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
configs = []
for lora_path, strength in lora_configs:
try:
config = LoRAConfig(path=lora_path, strength=strength)
configs.append(config)
print(f" - {Path(lora_path).name} (strength: {strength})")
except Exception as e:
print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
raise
module_to_loras = load_multiple_loras(configs)
if not module_to_loras:
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
return module_to_loras
def load_and_apply_loras(
model_weights: Dict[str, mx.array],
lora_configs: Optional[List[Tuple[str, float]]] = None,
verbose: bool = False,
quantization_bits: int = 0,
) -> Dict[str, mx.array]:
"""Load and apply LoRA weights to model weights by merging into weight dict.
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.lora import apply_loras_to_weights
from mlx_video.utils import Colors
if not lora_configs:
return model_weights
module_to_loras = _load_lora_configs(lora_configs)
if not module_to_loras:
return model_weights
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
if verbose:
print(f" Model has {len(model_weights)} weight keys")
modified_weights = apply_loras_to_weights(
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
)
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
return modified_weights
def convert_wan_checkpoint( def convert_wan_checkpoint(
checkpoint_dir: str, checkpoint_dir: str,
output_dir: str, output_dir: str,
@@ -464,30 +556,45 @@ def _quantize_saved_model(
is_dual: bool, is_dual: bool,
bits: int, bits: int,
group_size: int, group_size: int,
source_dir: Path = None,
): ):
"""Load saved bf16 model, quantize, and re-save.""" """Load saved bf16 model, quantize, and re-save.
Args:
output_dir: Directory to write quantized weights to.
config: WanModelConfig for creating the model.
is_dual: Whether this is a dual-expert model.
bits: Quantization bits.
group_size: Quantization group size.
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
"""
import json import json
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
model_files = [] if source_dir is None:
source_dir = output_dir
model_names = []
if is_dual: if is_dual:
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]: for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
p = output_dir / name if (source_dir / name).exists():
if p.exists(): model_names.append(name)
model_files.append(p)
else: else:
p = output_dir / "model.safetensors" if (source_dir / "model.safetensors").exists():
if p.exists(): model_names.append("model.safetensors")
model_files.append(p)
for model_path in model_files: for name in model_names:
print(f" Quantizing {model_path.name}...") print(f" Quantizing {name}...")
model = WanModel(config) model = WanModel(config)
weights = mx.load(str(model_path)) weights = mx.load(str(source_dir / name))
model.load_weights(list(weights.items()), strict=False) model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
del weights
gc.collect()
mx.clear_cache()
# Apply quantization to targeted layers # Apply quantization to targeted layers
nn.quantize( nn.quantize(
@@ -499,10 +606,30 @@ def _quantize_saved_model(
# Save quantized weights # Save quantized weights
weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
mx.save_safetensors(str(model_path), weights_dict)
# Validate: check for NaN/Inf in bias tensors (corruption canary)
bad_keys = []
for k, v in weights_dict.items():
if k.endswith(".bias") and not k.endswith(".biases"):
mx.eval(v)
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
bad_keys.append(k)
if bad_keys:
raise RuntimeError(
f"Quantization produced corrupted weights in {model_path.name}: "
f"{len(bad_keys)} bias tensors contain NaN/Inf "
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
)
mx.save_safetensors(str(output_dir / name), weights_dict)
n_quantized = sum(1 for k in weights_dict if ".scales" in k) n_quantized = sum(1 for k in weights_dict if ".scales" in k)
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved") print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
# Free model before processing next file
del model, weights_dict
gc.collect()
mx.clear_cache()
# Update config.json with quantization metadata # Update config.json with quantization metadata
config_path = output_dir / "config.json" config_path = output_dir / "config.json"
with open(config_path) as f: with open(config_path) as f:
@@ -516,6 +643,68 @@ def _quantize_saved_model(
print(f" Updated config.json with quantization metadata") print(f" Updated config.json with quantization metadata")
def quantize_mlx_model(
mlx_model_dir: str,
output_dir: str,
bits: int = 4,
group_size: int = 64,
):
"""Quantize an already-converted MLX model (skips PyTorch conversion).
Args:
mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
output_dir: Path to output quantized model directory.
bits: Quantization bits (4 or 8).
group_size: Quantization group size (32, 64, or 128).
"""
import json
import shutil
src = Path(mlx_model_dir)
dst = Path(output_dir)
config_path = src / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"No config.json found in {src}")
with open(config_path) as f:
cfg = json.load(f)
if cfg.get("quantization"):
raise ValueError(
f"Model at {src} is already quantized "
f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source."
)
# Detect dual vs single expert
is_dual = (src / "low_noise_model.safetensors").exists() and (
src / "high_noise_model.safetensors"
).exists()
# Build model config
from mlx_video.models.wan.config import WanModelConfig
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**config_dict)
# Copy non-transformer files to output dir (skip large model weights)
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
if dst.resolve() != src.resolve():
dst.mkdir(parents=True, exist_ok=True)
for f in src.iterdir():
if f.is_file() and f.name not in transformer_files:
shutil.copy2(f, dst / f.name)
print(f"Copied non-transformer files from {src} to {dst}")
print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
_quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
print(f"\nQuantization complete! Output: {dst}")
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@@ -551,6 +740,11 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Quantize transformer weights for faster inference", help="Quantize transformer weights for faster inference",
) )
parser.add_argument(
"--quantize-only",
action="store_true",
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
)
parser.add_argument( parser.add_argument(
"--bits", "--bits",
type=int, type=int,
@@ -566,6 +760,13 @@ if __name__ == "__main__":
help="Quantization group size (default: 64)", help="Quantization group size (default: 64)",
) )
args = parser.parse_args() args = parser.parse_args()
if args.quantize_only:
quantize_mlx_model(
args.checkpoint_dir, args.output_dir,
bits=args.bits, group_size=args.group_size,
)
else:
convert_wan_checkpoint( convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version, args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size, quantize=args.quantize, bits=args.bits, group_size=args.group_size,

View File

@@ -43,6 +43,10 @@ def generate_video(
seed: int = -1, seed: int = -1,
output_path: str = "output.mp4", output_path: str = "output.mp4",
scheduler: str = "unipc", scheduler: str = "unipc",
loras: list | None = None,
loras_high: list | None = None,
loras_low: list | None = None,
): ):
"""Generate video using Wan pipeline (supports T2V and I2V). """Generate video using Wan pipeline (supports T2V and I2V).
@@ -60,6 +64,10 @@ def generate_video(
seed: Random seed (-1 for random) seed: Random seed (-1 for random)
output_path: Output video path output_path: Output video path
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default) scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
loras: Optional list of (path, strength) tuples applied to all models
loras_high: Optional list of (path, strength) tuples for high-noise model only
loras_low: Optional list of (path, strength) tuples for low-noise model only
""" """
import json import json
@@ -156,6 +164,12 @@ def generate_video(
parts = [float(x) for x in guide_scale.split(",")] parts = [float(x) for x in guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0] guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
if isinstance(guide_scale, tuple):
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
else:
cfg_disabled = guide_scale <= 1.0
# Validate frame count # Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}" assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
@@ -181,6 +195,8 @@ def generate_video(
print(f" Neg prompt: {neg_display}") print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}") print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}") print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}") print(f"{Colors.RESET}")
# Seed # Seed
@@ -233,6 +249,10 @@ def generate_video(
# Encode prompts # Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}") print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len) context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
if cfg_disabled:
context_null = None
mx.eval(context)
else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len) context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
mx.eval(context, context_null) mx.eval(context, context_null)
@@ -319,17 +339,35 @@ def generate_video(
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}") print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
t2 = time.time() t2 = time.time()
# Merge per-model LoRAs with shared LoRAs
_loras_low = (loras or []) + (loras_low or []) or None
_loras_high = (loras or []) + (loras_high or []) or None
_loras_single = loras
if is_dual: if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors" low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors" high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization) low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
high_noise_model = load_wan_model(high_noise_path, config, quantization) high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
else: else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization) single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}") print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step) # Precompute text embeddings once (avoids redundant MLP in every step)
# Each model has its own text_embedding weights, so dual models need separate embeddings # Each model has its own text_embedding weights, so dual models need separate embeddings
if cfg_disabled:
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
if is_dual:
context_emb_low = low_noise_model.embed_text([context])
context_emb_high = high_noise_model.embed_text([context])
mx.eval(context_emb_low, context_emb_high)
context_cond_low = context_emb_low[0:1]
context_cond_high = context_emb_high[0:1]
else:
context_emb = single_model.embed_text([context])
mx.eval(context_emb)
context_cond = context_emb[0:1]
else:
if is_dual: if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null]) context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null]) context_emb_high = high_noise_model.embed_text([context, context_null])
@@ -342,6 +380,15 @@ def generate_video(
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0) context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
# Precompute cross-attention K/V caches (constant across all steps) # Precompute cross-attention K/V caches (constant across all steps)
if cfg_disabled:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cond)
mx.eval(cross_kv)
else:
if is_dual: if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low) cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high) cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
@@ -354,13 +401,16 @@ def generate_video(
f_grid = t_latent // patch_size[0] f_grid = t_latent // patch_size[0]
h_grid = h_latent // patch_size[1] h_grid = h_latent // patch_size[1]
w_grid = w_latent // patch_size[2] w_grid = w_latent // patch_size[2]
cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)] if cfg_disabled:
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
else:
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if is_dual: if is_dual:
rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes) rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes) rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high) mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else: else:
rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes) rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin) mx.eval(rope_cos_sin)
# Setup scheduler # Setup scheduler
@@ -395,42 +445,71 @@ def generate_video(
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")): for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = timestep_list[i] timestep_val = timestep_list[i]
# Select model, guide scale, cached K/V, and precomputed RoPE # Select model, cached K/V, and precomputed RoPE
if is_dual: if is_dual:
if timestep_val >= boundary: if timestep_val >= boundary:
model = high_noise_model model = high_noise_model
gs = guide_scale[1]
kv = cross_kv_high kv = cross_kv_high
rcs = rope_cos_sin_high rcs = rope_cos_sin_high
else: else:
model = low_noise_model model = low_noise_model
gs = guide_scale[0]
kv = cross_kv_low kv = cross_kv_low
rcs = rope_cos_sin_low rcs = rope_cos_sin_low
else: else:
model = single_model model = single_model
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
kv = cross_kv kv = cross_kv
rcs = rope_cos_sin rcs = rope_cos_sin
# Build per-token timesteps for TI2V-5B (first-frame patches get t=0) if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
if is_i2v_mask_blend: if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val # [1, L] t_tokens = i2v_mask_tokens * timestep_val
# Pad to seq_len if needed
pad_len = seq_len - t_tokens.shape[1] pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0: if pad_len > 0:
t_tokens = mx.concatenate( t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1 [t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
) )
# Batch for CFG: both cond and uncond get same timesteps t_batch = t_tokens # [1, L]
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L] else:
t_batch = mx.array([timestep_val])
y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
else:
ctx = context_cond
preds = model(
[latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred = preds[0]
del preds
else:
# CFG: batch cond + uncond into single B=2 forward pass
if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
else: else:
t_batch = mx.array([timestep_val, timestep_val]) t_batch = mx.array([timestep_val, timestep_val])
# I2V-14B: pass y conditioning to model (same y for cond and uncond)
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
# CFG: batch cond + uncond into single B=2 forward pass
ctx = context_cfg if not is_dual else ( ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low context_cfg_high if timestep_val >= boundary else context_cfg_low
) )
@@ -444,9 +523,8 @@ def generate_video(
rope_cos_sin=rcs, rope_cos_sin=rcs,
) )
noise_pred_cond, noise_pred_uncond = preds[0], preds[1] noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
# Classifier-free guidance + scheduler step
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond, preds
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
@@ -455,7 +533,7 @@ def generate_video(
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution # Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds del noise_pred
mx.eval(latents) mx.eval(latents)
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}") print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
@@ -463,11 +541,19 @@ def generate_video(
# Free transformer models and text embeddings # Free transformer models and text embeddings
if is_dual: if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
if cfg_disabled:
del context_cond_low, context_cond_high
else:
del context_cfg_low, context_cfg_high del context_cfg_low, context_cfg_high
else: else:
del single_model, cross_kv del single_model, cross_kv
if cfg_disabled:
del context_cond
else:
del context_cfg del context_cfg
del model, kv, context, context_null del model, kv, context
if context_null is not None:
del context_null
gc.collect(); mx.clear_cache() gc.collect(); mx.clear_cache()
# Load VAE and decode # Load VAE and decode
@@ -478,25 +564,36 @@ def generate_video(
is_wan22_vae = config.vae_z_dim == 48 is_wan22_vae = config.vae_z_dim == 48
# Warm-up: prepend a copy of the first latent frame to provide temporal
# context for the real first frame. Causal convolutions in the VAE decoder
# pad with zeros on the left, so the first few output frames have degraded
# quality (no temporal context). By duplicating the first latent, the real
# first frame sees its own features as left context instead of zeros.
# We trim the extra output frames after decoding.
warmup_trim = vae_stride[0] # 4 frames per latent temporal position
latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1)
if is_wan22_vae: if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE) # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None] # [1, T, H, W, C] z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C]
z = denormalize_latents(z) z = denormalize_latents(z)
video = vae(z) # [1, T', H', W', 3] video = vae(z) # [1, T', H', W', 3]
mx.eval(video) mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3] video = np.array(video[0]) # [T', H', W', 3]
video = video[warmup_trim:] # Trim warm-up frames
video = (video + 1.0) / 2.0 video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8) video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else: else:
video = vae.decode(latents[None]) # [1, 3, T, H, W] video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W]
mx.eval(video) mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [3, T, H, W] video = np.array(video[0]) # [3, T', H, W]
video = video[:, warmup_trim:] # Trim warm-up frames (channels-first)
video = (video + 1.0) / 2.0 video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8) video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3] video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
@@ -529,6 +626,19 @@ def main():
choices=["euler", "dpm++", "unipc"], choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)", help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
) )
parser.add_argument(
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
)
parser.add_argument(
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
)
args = parser.parse_args() args = parser.parse_args()
# Parse guide scale # Parse guide scale
@@ -542,6 +652,12 @@ def main():
if args.no_negative_prompt: if args.no_negative_prompt:
neg_prompt = "" neg_prompt = ""
# Parse LoRA configs: convert [path, strength_str] → (path, float)
def _parse_lora_args(lora_list):
if not lora_list:
return None
return [(path, float(strength)) for path, strength in lora_list]
generate_video( generate_video(
model_dir=args.model_dir, model_dir=args.model_dir,
prompt=args.prompt, prompt=args.prompt,
@@ -556,6 +672,10 @@ def main():
seed=args.seed, seed=args.seed,
output_path=args.output_path, output_path=args.output_path,
scheduler=args.scheduler, scheduler=args.scheduler,
loras=_parse_lora_args(args.lora),
loras_high=_parse_lora_args(args.lora_high),
loras_low=_parse_lora_args(args.lora_low),
) )

View File

@@ -0,0 +1,25 @@
"""LoRA support for mlx-video."""
from mlx_video.lora.apply import (
LoRALinear,
apply_lora_to_linear,
apply_loras_to_model,
apply_loras_to_weights,
)
from mlx_video.lora.loader import (
load_lora_weights,
load_multiple_loras,
)
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
__all__ = [
"LoRAConfig",
"LoRAWeights",
"AppliedLoRA",
"load_lora_weights",
"load_multiple_loras",
"apply_lora_to_linear",
"apply_loras_to_weights",
"apply_loras_to_model",
"LoRALinear",
]

393
mlx_video/lora/apply.py Normal file
View File

@@ -0,0 +1,393 @@
"""Apply LoRA weights to model layers."""
from typing import Dict, List, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.lora.types import LoRAWeights
def apply_lora_to_linear(
linear_weight: mx.array,
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
) -> mx.array:
"""Apply one or more LoRAs to a linear layer weight.
Args:
linear_weight: Original weight matrix [out_features, in_features]
lora_weights_and_strengths: List of (LoRAWeights, strength) tuples
Returns:
Modified weight with LoRA deltas applied (preserves original dtype)
"""
orig_dtype = linear_weight.dtype
modified_weight = linear_weight
for weights, strength in lora_weights_and_strengths:
scale = weights.scale
# Compute delta in float32 for precision, then cast back to avoid
# promoting model weights (e.g. bfloat16 → float32 causes ~1.5x slowdown)
delta = (weights.lora_B @ weights.lora_A) * (scale * strength)
modified_weight = modified_weight + delta.astype(orig_dtype)
return modified_weight
def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match Wan2.2 MLX model weight keys.
Handles:
- Stripping common prefixes (diffusion_model., model., etc.)
- FFN key mapping: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
- Embedding key mapping: text_embedding.0 → text_embedding_0, etc.
- Time projection: time_projection.1 → time_projection
- Patch embedding: patch_embedding → patch_embedding_proj
Args:
lora_key: Original LoRA module name
model_keys: Set of all model weight keys
Returns:
Normalized key that matches model weights
"""
# Try the key as-is first
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
# Common prefixes to strip
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"base_model.model.",
"model.",
]
candidates = [lora_key]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
candidates.append(lora_key[len(prefix):])
for candidate in candidates:
# Try as-is
if f"{candidate}.weight" in model_keys or candidate in model_keys:
return candidate
# Apply Wan2.2 key transformations
transformed = candidate
# FFN: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
if transformed.endswith(".ffn.0"):
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
if transformed.endswith(".ffn.2"):
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
# Text embedding: text_embedding.0 → text_embedding_0
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
if transformed.endswith("text_embedding.0"):
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
if transformed.endswith("text_embedding.2"):
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
# Time embedding: time_embedding.0 → time_embedding_0
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
if transformed.endswith("time_embedding.0"):
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
if transformed.endswith("time_embedding.2"):
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
# Time projection: time_projection.1 → time_projection
transformed = transformed.replace("time_projection.1.", "time_projection.")
if transformed.endswith("time_projection.1"):
transformed = transformed[:-len("time_projection.1")] + "time_projection"
# Patch embedding: patch_embedding → patch_embedding_proj
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Return best attempt with prefix stripped
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
# Also support LTX-style key normalization
def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match LTX MLX model weight keys."""
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"model.",
]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
normalized = lora_key[len(prefix):]
if f"{normalized}.weight" in model_keys or normalized in model_keys:
return normalized
transformed = normalized
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Try transformations on the original key
transformed = lora_key
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
def _normalize_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match model weight keys.
Auto-detects whether to use Wan2.2 or LTX key normalization based
on the presence of architecture-specific keys in the model.
"""
# Detect model architecture from keys
is_wan = any("self_attn.q.weight" in k for k in model_keys)
if is_wan:
return _normalize_wan_lora_key(lora_key, model_keys)
else:
return _normalize_ltx_lora_key(lora_key, model_keys)
def apply_loras_to_weights(
model_weights: Dict[str, mx.array],
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
verbose: bool = False,
quantization_bits: int = 0,
) -> Dict[str, mx.array]:
"""Apply LoRAs to model weights.
Args:
model_weights: Original model state dictionary
module_to_loras: Dictionary mapping module names to lists of
(LoRAWeights, strength) tuples
verbose: If True, print detailed debug information
quantization_bits: If >0, weights are quantized at this bit width.
Quantized layers are dequantized before LoRA application
and re-quantized after.
Returns:
New state dictionary with LoRA-modified weights
"""
modified_weights = dict(model_weights)
model_keys = set(model_weights.keys())
applied_count = 0
skipped_count = 0
skipped_modules = []
for module_name, loras in module_to_loras.items():
normalized_name = _normalize_lora_key(module_name, model_keys)
weight_key = f"{normalized_name}.weight"
if weight_key not in modified_weights:
if normalized_name not in modified_weights:
skipped_count += 1
skipped_modules.append(module_name)
if verbose and skipped_count <= 5:
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
similar = [
k
for k in list(model_keys)[:1000]
if normalized_name.split(".")[-1] in k
][:3]
if similar:
print(f" Similar keys: {similar}")
continue
weight_key = normalized_name
original_weight = modified_weights[weight_key]
# Handle quantized weights: dequantize → apply delta → re-quantize
scales_key = f"{normalized_name}.scales"
biases_key = f"{normalized_name}.biases"
is_quantized = (
original_weight.dtype == mx.uint32
and scales_key in modified_weights
and biases_key in modified_weights
)
if is_quantized:
scales = modified_weights[scales_key]
biases = modified_weights[biases_key]
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits)
dequantized = mx.dequantize(
original_weight, scales, biases, group_size=group_size, bits=quantization_bits
)
modified = apply_lora_to_linear(dequantized, loras)
# Re-quantize with same parameters
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits)
modified_weights[weight_key] = new_w
modified_weights[scales_key] = new_scales
modified_weights[biases_key] = new_biases
else:
modified_weights[weight_key] = apply_lora_to_linear(original_weight, loras)
applied_count += 1
if applied_count > 0:
print(f" ✓ Applied to {applied_count} modules")
if skipped_count > 0:
print(f" ⚠ Skipped {skipped_count} incompatible modules")
return modified_weights
class LoRALinear(nn.Module):
"""Linear layer with on-the-fly LoRA application.
Wraps nn.Linear or nn.QuantizedLinear, computing LoRA delta at runtime:
output = base_linear(x) + (x @ lora_A.T @ lora_B.T) * scale * strength
"""
def __init__(
self,
linear: nn.Module,
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
):
super().__init__()
self.linear = linear
self.lora_weights_and_strengths = lora_weights_and_strengths
def __call__(self, x: mx.array) -> mx.array:
output = self.linear(x)
for weights, strength in self.lora_weights_and_strengths:
scale = weights.scale
lora_out = x @ weights.lora_A.T @ weights.lora_B.T
output = output + (scale * strength * lora_out)
return output
def apply_loras_to_model(
model: nn.Module,
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
verbose: bool = False,
) -> int:
"""Apply LoRAs to a model by merging into weights.
For QuantizedLinear layers: dequantizes to bf16, merges LoRA delta, and
replaces with a regular nn.Linear (no per-step overhead, no re-quantization
precision loss). Non-LoRA layers stay quantized.
For nn.Linear layers: merges LoRA delta directly into the weight.
Args:
model: The model to apply LoRAs to
module_to_loras: Dictionary mapping module names to (LoRAWeights, strength) lists
verbose: Print debug info
Returns:
Number of modules modified
"""
# Build a set of model module paths for key normalization
module_paths = set()
for name, _ in model.named_modules():
module_paths.add(name)
module_paths.add(f"{name}.weight")
# Map LoRA keys → model module paths
lora_to_module = {}
for lora_key in module_to_loras:
normalized = _normalize_lora_key(lora_key, module_paths)
if normalized.endswith(".weight"):
normalized = normalized[: -len(".weight")]
lora_to_module[lora_key] = normalized
applied_count = 0
dequant_count = 0
skipped = []
for lora_key, loras in module_to_loras.items():
module_path = lora_to_module[lora_key]
parts = module_path.split(".")
# Traverse to the parent module
parent = model
try:
for part in parts[:-1]:
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)]
leaf_name = parts[-1]
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)]
except (AttributeError, IndexError, TypeError):
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{lora_key}' -> '{module_path}' -> module not found")
continue
if isinstance(target, nn.QuantizedLinear):
# Dequantize → merge LoRA → replace with bf16 Linear
weight = mx.dequantize(
target.weight, target.scales, target.biases,
group_size=target.group_size, bits=target.bits,
)
merged = apply_lora_to_linear(weight, loras)
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
new_linear.weight = merged
if "bias" in target:
new_linear.bias = target.bias
if leaf_name.isdigit():
parent[int(leaf_name)] = new_linear
else:
setattr(parent, leaf_name, new_linear)
dequant_count += 1
applied_count += 1
elif isinstance(target, nn.Linear):
# Merge directly into weight
target.weight = apply_lora_to_linear(target.weight, loras)
applied_count += 1
else:
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
continue
if applied_count > 0:
msg = f" ✓ Applied to {applied_count} modules"
if dequant_count > 0:
msg += f" ({dequant_count} dequantized to bf16)"
print(msg)
if skipped:
print(f" ⚠ Skipped {len(skipped)} incompatible modules")
return applied_count

122
mlx_video/lora/loader.py Normal file
View File

@@ -0,0 +1,122 @@
"""LoRA weight loading utilities."""
import re
from pathlib import Path
from typing import Dict, List, Optional
import mlx.core as mx
from mlx_video.lora.types import LoRAConfig, LoRAWeights
def load_lora_weights(lora_path: Path) -> Dict[str, LoRAWeights]:
"""Load LoRA weights from a safetensors file.
Supports both key conventions:
- {module_name}.lora_A.weight / {module_name}.lora_B.weight
- {module_name}.lora_down.weight / {module_name}.lora_up.weight
Args:
lora_path: Path to the LoRA safetensors file
Returns:
Dictionary mapping module names to LoRAWeights objects
Raises:
FileNotFoundError: If the LoRA file doesn't exist
ValueError: If the LoRA file format is invalid
"""
if not lora_path.exists():
raise FileNotFoundError(f"LoRA file not found: {lora_path}")
all_weights = mx.load(str(lora_path))
# Group weights by module name, handling both naming conventions
lora_weights = {}
module_names = set()
for key in all_weights.keys():
# Format 1: {module}.lora_A.weight / {module}.lora_B.weight
match = re.match(r"(.+)\.lora_([AB])\.weight$", key)
if match:
module_names.add(match.group(1))
continue
# Format 2: {module}.lora_down.weight / {module}.lora_up.weight
match = re.match(r"(.+)\.lora_(down|up)\.weight$", key)
if match:
module_names.add(match.group(1))
for module_name in module_names:
# Try both key conventions
key_a = f"{module_name}.lora_A.weight"
key_b = f"{module_name}.lora_B.weight"
if key_a not in all_weights or key_b not in all_weights:
key_a = f"{module_name}.lora_down.weight"
key_b = f"{module_name}.lora_up.weight"
if key_a not in all_weights or key_b not in all_weights:
continue
lora_a = all_weights[key_a]
lora_b = all_weights[key_b]
if lora_a.ndim != 2 or lora_b.ndim != 2:
raise ValueError(
f"Invalid LoRA shape for {module_name}: "
f"lora_A={lora_a.shape}, lora_B={lora_b.shape}"
)
rank = lora_a.shape[0]
if lora_b.shape[1] != rank:
raise ValueError(
f"LoRA rank mismatch for {module_name}: "
f"lora_A rank={rank}, lora_B rank={lora_b.shape[1]}"
)
# Check for per-module alpha stored as a scalar tensor
alpha_key = f"{module_name}.alpha"
if alpha_key in all_weights:
alpha = float(all_weights[alpha_key].item())
else:
alpha = float(rank)
lora_weights[module_name] = LoRAWeights(
lora_A=lora_a,
lora_B=lora_b,
rank=rank,
alpha=alpha,
module_name=module_name,
)
if not lora_weights:
raise ValueError(f"No valid LoRA weights found in {lora_path}")
return lora_weights
def load_multiple_loras(
configs: List[LoRAConfig],
) -> Dict[str, List[tuple]]:
"""Load multiple LoRA configurations.
Args:
configs: List of LoRAConfig objects
Returns:
Dictionary mapping module names to lists of (LoRAWeights, strength) tuples.
"""
module_to_loras: Dict[str, list] = {}
for config in configs:
lora_weights = load_lora_weights(config.path)
for module_name, weights in lora_weights.items():
if config.target_modules is not None:
if module_name not in config.target_modules:
continue
if module_name not in module_to_loras:
module_to_loras[module_name] = []
module_to_loras[module_name].append((weights, config.strength))
return module_to_loras

74
mlx_video/lora/types.py Normal file
View File

@@ -0,0 +1,74 @@
"""Data structures for LoRA support."""
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import mlx.core as mx
@dataclass
class LoRAWeights:
"""Container for LoRA weight matrices.
Attributes:
lora_A: Low-rank matrix A of shape [rank, in_features]
lora_B: Low-rank matrix B of shape [out_features, rank]
rank: Rank of the LoRA decomposition
alpha: LoRA scaling parameter (default: rank)
module_name: Target module name in the model
"""
lora_A: mx.array
lora_B: mx.array
rank: int
alpha: float
module_name: str
@property
def scale(self) -> float:
"""Compute the scale factor: alpha / rank."""
return self.alpha / self.rank
@dataclass
class LoRAConfig:
"""Configuration for a single LoRA.
Attributes:
path: Path to the LoRA safetensors file
strength: Strength/weight to apply this LoRA (typically 0.0-2.0)
target_modules: Optional list of module names to apply LoRA to.
If None, applies to all available modules in the LoRA.
"""
path: Path
strength: float = 1.0
target_modules: Optional[list[str]] = None
def __post_init__(self):
"""Validate and normalize the configuration."""
self.path = Path(self.path)
if not self.path.exists():
raise FileNotFoundError(f"LoRA file not found: {self.path}")
if self.strength < 0:
raise ValueError(f"LoRA strength must be non-negative, got {self.strength}")
@dataclass
class AppliedLoRA:
"""Represents a LoRA applied to a specific module.
Attributes:
weights: The LoRA weight matrices
strength: Application strength for this LoRA
"""
weights: LoRAWeights
strength: float
def compute_delta(self) -> mx.array:
"""Compute the weight delta: strength * scale * (lora_B @ lora_A)."""
scale = self.weights.scale
delta = self.weights.lora_B @ self.weights.lora_A
return scale * self.strength * delta

View File

@@ -4,6 +4,15 @@ import mlx.nn as nn
from .rope import rope_apply from .rope import rope_apply
def _linear_dtype(layer) -> mx.Dtype:
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
# Unwrap LoRA wrapper to get the underlying linear layer
inner = getattr(layer, "linear", layer)
if isinstance(inner, nn.QuantizedLinear):
return inner.scales.dtype
return inner.weight.dtype
class WanRMSNorm(nn.Module): class WanRMSNorm(nn.Module):
"""RMS normalization with learnable scale.""" """RMS normalization with learnable scale."""
@@ -73,8 +82,8 @@ class WanSelfAttention(nn.Module):
b, s, _ = x.shape b, s, _ = x.shape
n, d = self.num_heads, self.head_dim n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast) # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype w_dtype = _linear_dtype(self.q)
x_w = x.astype(w_dtype) x_w = x.astype(w_dtype)
q = self.q(x_w) q = self.q(x_w)
@@ -154,8 +163,8 @@ class WanCrossAttention(nn.Module):
""" """
b = context.shape[0] b = context.shape[0]
n, d = self.num_heads, self.head_dim n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul # Cast to compute dtype for efficient matmul
w_dtype = self.k.weight.dtype w_dtype = _linear_dtype(self.k)
ctx = context.astype(w_dtype) ctx = context.astype(w_dtype)
k = self.k(ctx) k = self.k(ctx)
if self.norm_k is not None: if self.norm_k is not None:
@@ -174,8 +183,8 @@ class WanCrossAttention(nn.Module):
b = x.shape[0] b = x.shape[0]
n, d = self.num_heads, self.head_dim n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast) # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype w_dtype = _linear_dtype(self.q)
q = self.q(x.astype(w_dtype)) q = self.q(x.astype(w_dtype))
if self.norm_q is not None: if self.norm_q is not None:
q = self.norm_q(q) q = self.norm_q(q)

View File

@@ -6,14 +6,15 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None): def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
"""Load and initialize WanModel, with optional quantization support. """Load and initialize WanModel, with optional quantization and LoRA support.
Args: Args:
model_path: Path to model safetensors file model_path: Path to model safetensors file
config: WanModelConfig config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys. quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading. If provided, creates QuantizedLinear stubs before loading.
loras: Optional list of (lora_path, strength) tuples to apply.
""" """
from mlx_video.models.wan.model import WanModel from mlx_video.models.wan.model import WanModel
@@ -30,6 +31,27 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
) )
weights = mx.load(str(model_path)) weights = mx.load(str(model_path))
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
if loras:
if quantization:
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
from mlx_video.convert_wan import _load_lora_configs
from mlx_video.lora import apply_loras_to_model
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
module_to_loras = _load_lora_configs(loras)
apply_loras_to_model(model, module_to_loras)
mx.eval(model.parameters())
return model
else:
# Weight merging: fold LoRA into bf16 weights before loading
from mlx_video.convert_wan import load_and_apply_loras
weights = load_and_apply_loras(dict(weights), loras)
model.load_weights(list(weights.items()), strict=False) model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters()) mx.eval(model.parameters())
return model return model

View File

@@ -4,7 +4,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .attention import WanLayerNorm from .attention import WanLayerNorm, _linear_dtype
from .config import WanModelConfig from .config import WanModelConfig
from .rope import rope_params, rope_precompute_cos_sin from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock from .transformer import WanAttentionBlock
@@ -54,7 +54,7 @@ class Head(nn.Module):
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x) x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32 x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
return self.head(x_mod.astype(self.head.weight.dtype)) return self.head(x_mod.astype(_linear_dtype(self.head)))
class WanModel(nn.Module): class WanModel(nn.Module):
@@ -79,7 +79,7 @@ class WanModel(nn.Module):
# Text embedding MLP # Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim) self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="precise") self.text_embedding_act = nn.GELU(approx="tanh")
self.text_embedding_1 = nn.Linear(dim, dim) self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP # Time embedding MLP
@@ -149,7 +149,7 @@ class WanModel(nn.Module):
# Project and cast to model dtype to prevent float32 cascade from input latents # Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim] patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(self.patch_embedding_proj.weight.dtype) patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
patches = patches[None, :, :] # [1, L, dim] patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out) return patches, (f_out, h_out, w_out)
@@ -186,7 +186,7 @@ class WanModel(nn.Module):
Returns: Returns:
Embedded context [B, text_len, dim] in model dtype Embedded context [B, text_len, dim] in model dtype
""" """
model_dtype = self.patch_embedding_proj.weight.dtype model_dtype = _linear_dtype(self.patch_embedding_proj)
context_padded = [] context_padded = []
for ctx in context: for ctx in context:
pad_len = self.text_len - ctx.shape[0] pad_len = self.text_len - ctx.shape[0]
@@ -231,7 +231,7 @@ class WanModel(nn.Module):
Returns: Returns:
(cos_f, sin_f) precomputed frequency tensors (cos_f, sin_f) precomputed frequency tensors
""" """
w_dtype = self.patch_embedding_proj.weight.dtype w_dtype = _linear_dtype(self.patch_embedding_proj)
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype) return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__( def __call__(
@@ -348,7 +348,7 @@ class WanModel(nn.Module):
# Pre-compute attention mask from seq_lens (constant across all blocks) # Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None attn_mask = None
w_dtype = self.patch_embedding_proj.weight.dtype w_dtype = _linear_dtype(self.patch_embedding_proj)
if any(sl < seq_len for sl in seq_lens_list): if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype) attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list): for i, sl in enumerate(seq_lens_list):

View File

@@ -146,7 +146,7 @@ class T5FeedForward(nn.Module):
self.dim = dim self.dim = dim
self.dim_ffn = dim_ffn self.dim_ffn = dim_ffn
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False) self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = nn.GELU(approx="precise") self.gate_act = nn.GELU(approx="tanh")
self.fc1 = nn.Linear(dim, dim_ffn, bias=False) self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False) self.fc2 = nn.Linear(dim_ffn, dim, bias=False)

View File

@@ -1,7 +1,7 @@
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
class WanAttentionBlock(nn.Module): class WanAttentionBlock(nn.Module):
@@ -84,10 +84,10 @@ class WanFFN(nn.Module):
def __init__(self, dim: int, ffn_dim: int): def __init__(self, dim: int, ffn_dim: int):
super().__init__() super().__init__()
self.fc1 = nn.Linear(dim, ffn_dim) self.fc1 = nn.Linear(dim, ffn_dim)
self.act = nn.GELU(approx="precise") self.act = nn.GELU(approx="tanh")
self.fc2 = nn.Linear(ffn_dim, dim) self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast) # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(self.fc1.weight.dtype) x_w = x.astype(_linear_dtype(self.fc1))
return self.fc2(self.act(self.fc1(x_w))) return self.fc2(self.act(self.fc1(x_w)))

View File

@@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed. conversion (channels-first → channels-last) is needed.
""" """
import logging
import math import math
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
logger = logging.getLogger(__name__)
CACHE_T = 2 CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space # Per-channel normalization for z_dim=48 latent space
@@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
Maps PyTorch nn.Sequential indices to our named layers. Maps PyTorch nn.Sequential indices to our named layers.
""" """
sanitized = {} sanitized = {}
consumed = set()
for key, value in weights.items(): for key, value in weights.items():
# Skip encoder and conv1 unless requested # Skip encoder and conv1 unless requested
if not include_encoder: if not include_encoder:
if key.startswith("encoder.") or key.startswith("conv1."): if key.startswith("encoder.") or key.startswith("conv1."):
consumed.add(key)
continue continue
new_key = key new_key = key
@@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
value = mx.array(np.array(value).squeeze()) value = mx.array(np.array(value).squeeze())
sanitized[new_key] = value sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed))
return sanitized return sanitized

View File

@@ -1,5 +1,7 @@
"""Tests for Wan weight conversion utilities.""" """Tests for Wan weight conversion utilities."""
import logging
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
import pytest import pytest
@@ -94,6 +96,27 @@ class TestSanitizeTransformerWeights:
for key in weights: for key in weights:
assert key in out assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
"text_embedding.0.weight": mx.zeros((64, 32)),
"text_embedding.2.weight": mx.zeros((64, 64)),
"time_embedding.0.weight": mx.zeros((64, 32)),
"time_embedding.2.weight": mx.zeros((64, 64)),
"time_projection.1.weight": mx.zeros((384, 64)),
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
"blocks.0.modulation": mx.zeros((1, 6, 64)),
"head.head.weight": mx.zeros((64, 64)),
"freqs": mx.zeros((1024, 64, 2)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
sanitize_wan_transformer_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeT5Weights: class TestSanitizeT5Weights:
def test_gate_rename(self): def test_gate_rename(self):
@@ -119,6 +142,19 @@ class TestSanitizeT5Weights:
for key in weights: for key in weights:
assert key in out assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
"norm.weight": mx.zeros((64,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
sanitize_wan_t5_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeVAEWeights: class TestSanitizeVAEWeights:
def test_conv3d_transpose(self): def test_conv3d_transpose(self):
@@ -161,6 +197,18 @@ class TestSanitizeVAEWeights:
assert out["linear.weight"].shape == (8, 4) assert out["linear.weight"].shape == (8, 4)
assert out["norm.weight"].shape == (8,) assert out["norm.weight"].shape == (8,)
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
"decoder.norm.weight": mx.zeros((64,)),
"decoder.bias": mx.zeros((16,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
sanitize_wan_vae_weights(weights)
assert "Unconsumed" not in caplog.text
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Wan2.1 Conversion Tests # Wan2.1 Conversion Tests
@@ -233,3 +281,27 @@ class TestSanitizeEncoderWeights:
assert "encoder.conv1.weight" in out assert "encoder.conv1.weight" in out
assert "conv1.weight" in out assert "conv1.weight" in out
assert "conv2.weight" in out assert "conv2.weight" in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "Unconsumed" not in caplog.text
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "Unconsumed" not in caplog.text

View File

@@ -291,3 +291,282 @@ class TestI2VMaskConstruction:
encoded = mx.zeros((16, 5, 10, 18)) encoded = mx.zeros((16, 5, 10, 18))
y = mx.concatenate([mask, encoded], axis=0) y = mx.concatenate([mask, encoded], axis=0)
assert y.shape == (20, 5, 10, 18) assert y.shape == (20, 5, 10, 18)
# ---------------------------------------------------------------------------
# Integration: I2V end-to-end pipeline
# ---------------------------------------------------------------------------
class TestI2VEndToEndPipeline:
"""Full I2V pipeline: image → preprocess → VAE encode → y tensor → denoise → VAE decode."""
def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan.vae import WanVAE
mx.random.seed(0)
# --- Tiny I2V model config (z_dim=16 to match VAE normalization stats) ---
config = _make_tiny_i2v_config()
config.vae_z_dim = 16
config.out_dim = 16 # must match VAE z_dim for decode
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
model = WanModel(config)
# --- Tiny VAE (with encoder) ---
vae = WanVAE(z_dim=config.vae_z_dim, encoder=True)
# --- Synthetic image: [B=1, 3, T=1, H=32, W=32] in [-1, 1] ---
height, width = 32, 32
num_frames = 5 # small temporal extent
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
video = mx.concatenate([
img,
mx.zeros((1, 3, num_frames - 1, height, width)),
], axis=2)
# --- VAE encode ---
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
mx.eval(z_video)
assert z_video.ndim == 5
assert z_video.shape[1] == config.vae_z_dim
z_video = z_video[0] # [z_dim, T_lat, H_lat, W_lat]
t_latent = z_video.shape[1]
h_latent = z_video.shape[2]
w_latent = z_video.shape[3]
# --- Build I2V mask (4 channels) ---
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
# --- Build y tensor: [mask(4ch) + encoded(z_dim ch)] ---
y_i2v = mx.concatenate([msk, z_video], axis=0)
mx.eval(y_i2v)
assert y_i2v.shape[0] == 4 + config.vae_z_dim
# --- Denoising loop (2 steps) ---
C_noise = config.out_dim # noise channels
pt, ph, pw = config.patch_size
seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw)
sched = FlowMatchEulerScheduler()
num_steps = 2
sched.set_timesteps(num_steps, shift=config.sample_shift)
latents = mx.random.normal((C_noise, t_latent, h_latent, w_latent))
context = mx.random.normal((4, config.text_dim))
for i in range(num_steps):
t_val = sched.timesteps[i].item()
pred = model(
[latents],
mx.array([t_val]),
[context],
seq_len,
y=[y_i2v],
)[0]
latents = sched.step(pred[None], t_val, latents[None]).squeeze(0)
mx.eval(latents)
assert latents.shape == (C_noise, t_latent, h_latent, w_latent)
assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents"
assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents"
# --- VAE decode ---
decoded = vae.decode(latents[None]) # [1, 3, T_out, H_out, W_out]
mx.eval(decoded)
assert decoded.ndim == 5
assert decoded.shape[0] == 1
assert decoded.shape[1] == 3 # RGB output
assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video"
assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video"
# VAE decode clips to [-1, 1]
assert float(decoded.max()) <= 1.0
assert float(decoded.min()) >= -1.0
class TestDualModelSwitching:
"""Test dual-model selection logic: high_noise vs low_noise based on boundary."""
def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(1)
config = _make_tiny_i2v_config()
assert config.dual_model is True
high_noise_model = WanModel(config)
low_noise_model = WanModel(config)
boundary = config.boundary * config.num_train_timesteps # 0.9 * 1000 = 900
C_noise = config.out_dim # 4
C_y = config.in_dim - config.out_dim # 9 - 4 = 5
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
num_steps = 5
sched.set_timesteps(num_steps, shift=config.sample_shift)
guide_scale = config.sample_guide_scale # (3.5, 3.5)
assert isinstance(guide_scale, tuple) and len(guide_scale) == 2
latents = mx.random.normal((C_noise, F, H, W))
y_i2v = mx.random.normal((C_y, F, H, W))
context = mx.random.normal((4, config.text_dim))
high_used_steps = []
low_used_steps = []
timestep_list = sched.timesteps.tolist()
for i in range(num_steps):
timestep_val = timestep_list[i]
if timestep_val >= boundary:
model = high_noise_model
gs = guide_scale[1]
high_used_steps.append(i)
else:
model = low_noise_model
gs = guide_scale[0]
low_used_steps.append(i)
# CFG pass: cond + uncond
preds = model(
[latents, latents],
mx.array([timestep_val, timestep_val]),
[context, context],
seq_len,
y=[y_i2v, y_i2v],
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
mx.eval(latents)
# With shift=5.0, early timesteps should be high (>=900), later ones low
assert len(high_used_steps) > 0, "High-noise model was never selected"
assert len(low_used_steps) > 0, "Low-noise model was never selected"
# High-noise steps should come before low-noise steps (timesteps decrease)
if high_used_steps and low_used_steps:
assert max(high_used_steps) < min(low_used_steps) or \
min(high_used_steps) < max(low_used_steps), \
"Model switching should happen during the loop"
assert latents.shape == (C_noise, F, H, W)
assert not mx.any(mx.isnan(latents)).item()
def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(2)
config = _make_tiny_i2v_config()
config.sample_guide_scale = (2.0, 5.0) # distinct values
model = WanModel(config)
boundary = config.boundary * config.num_train_timesteps
C_noise = config.out_dim
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=config.sample_shift)
latents = mx.random.normal((C_noise, F, H, W))
context = mx.random.normal((4, config.text_dim))
guide_scale = config.sample_guide_scale
C_y = config.in_dim - config.out_dim # y channels
y_i2v = mx.random.normal((C_y, F, H, W))
# Track which guide scale was used at each step
gs_per_step = []
timestep_list = sched.timesteps.tolist()
for i in range(5):
timestep_val = timestep_list[i]
if timestep_val >= boundary:
gs = guide_scale[1] # high_gs = 5.0
else:
gs = guide_scale[0] # low_gs = 2.0
gs_per_step.append(gs)
pred = model(
[latents, latents],
mx.array([timestep_val, timestep_val]),
[context, context],
seq_len,
y=[y_i2v, y_i2v],
)
noise_pred = pred[1] + gs * (pred[0] - pred[1])
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
mx.eval(latents)
# Verify both guide scales were used
assert 5.0 in gs_per_step, "High guide scale (5.0) was never used"
assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used"
# High gs should appear first (high timesteps come first)
first_high = gs_per_step.index(5.0)
last_low = len(gs_per_step) - 1 - gs_per_step[::-1].index(2.0)
assert first_high < last_low, "High gs steps should precede low gs steps"
def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(3)
config = _make_tiny_config()
config.dual_model = False
config.sample_guide_scale = (3.0, 5.0)
model = WanModel(config)
guide_scale = config.sample_guide_scale
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
sched.set_timesteps(3, shift=3.0)
latents = mx.random.normal((C, F, H, W))
context = mx.random.normal((4, config.text_dim))
# Mimic generate_wan.py single-model logic:
# gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
assert gs == 3.0, "Single model should use first element of guide_scale tuple"
for i in range(3):
t_val = sched.timesteps[i].item()
pred = model(
[latents, latents],
mx.array([t_val, t_val]),
[context, context],
seq_len,
)
noise_pred = pred[1] + gs * (pred[0] - pred[1])
latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0)
mx.eval(latents)
assert latents.shape == (C, F, H, W)
assert not mx.any(mx.isnan(latents)).item()

334
tests/test_wan_lora.py Normal file
View File

@@ -0,0 +1,334 @@
"""Tests for LoRA loading and application."""
import tempfile
from pathlib import Path
import mlx.core as mx
import numpy as np
import pytest
class TestLoRATypes:
"""Test LoRA data structures."""
def test_lora_weights_scale(self):
from mlx_video.lora.types import LoRAWeights
w = LoRAWeights(
lora_A=mx.zeros((16, 64)),
lora_B=mx.zeros((128, 16)),
rank=16,
alpha=32.0,
module_name="test",
)
assert w.scale == 2.0
def test_lora_weights_scale_default(self):
from mlx_video.lora.types import LoRAWeights
w = LoRAWeights(
lora_A=mx.zeros((16, 64)),
lora_B=mx.zeros((128, 16)),
rank=16,
alpha=16.0,
module_name="test",
)
assert w.scale == 1.0
def test_applied_lora_delta(self):
from mlx_video.lora.types import AppliedLoRA, LoRAWeights
lora_a = mx.ones((2, 4))
lora_b = mx.ones((8, 2))
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
applied = AppliedLoRA(weights=w, strength=0.5)
delta = applied.compute_delta()
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
expected = 0.5 * mx.ones((8, 4)) * 2.0
assert mx.allclose(delta, expected).item()
class TestLoRALoader:
"""Test LoRA weight loading from safetensors."""
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
"""Helper to create a mock LoRA safetensors file."""
weights = {}
for name in module_names:
if key_format == "AB":
weights[f"{name}.lora_A.weight"] = mx.random.normal((rank, in_dim))
weights[f"{name}.lora_B.weight"] = mx.random.normal((out_dim, rank))
else:
weights[f"{name}.lora_down.weight"] = mx.random.normal((rank, in_dim))
weights[f"{name}.lora_up.weight"] = mx.random.normal((out_dim, rank))
path = Path(tmp_dir) / "test_lora.safetensors"
mx.save_safetensors(str(path), weights)
return path
def test_load_lora_a_b_format(self):
from mlx_video.lora.loader import load_lora_weights
with tempfile.TemporaryDirectory() as tmp:
path = self._make_lora_file(tmp, ["blocks.0.self_attn.q"], key_format="AB")
lora_weights = load_lora_weights(path)
assert "blocks.0.self_attn.q" in lora_weights
w = lora_weights["blocks.0.self_attn.q"]
assert w.rank == 4
assert w.alpha == 4.0 # default: alpha == rank
assert w.lora_A.shape == (4, 64)
assert w.lora_B.shape == (128, 4)
def test_load_lora_down_up_format(self):
from mlx_video.lora.loader import load_lora_weights
with tempfile.TemporaryDirectory() as tmp:
path = self._make_lora_file(
tmp, ["blocks.0.self_attn.q"], key_format="down_up"
)
lora_weights = load_lora_weights(path)
assert "blocks.0.self_attn.q" in lora_weights
def test_load_multiple_modules(self):
from mlx_video.lora.loader import load_lora_weights
modules = [
"blocks.0.self_attn.q",
"blocks.0.self_attn.k",
"blocks.0.ffn.fc1",
]
with tempfile.TemporaryDirectory() as tmp:
path = self._make_lora_file(tmp, modules)
lora_weights = load_lora_weights(path)
assert len(lora_weights) == 3
for name in modules:
assert name in lora_weights
def test_load_with_alpha(self):
from mlx_video.lora.loader import load_lora_weights
with tempfile.TemporaryDirectory() as tmp:
weights = {
"test.lora_A.weight": mx.random.normal((8, 64)),
"test.lora_B.weight": mx.random.normal((128, 8)),
"test.alpha": mx.array(16.0),
}
path = Path(tmp) / "lora.safetensors"
mx.save_safetensors(str(path), weights)
lora_weights = load_lora_weights(path)
assert lora_weights["test"].alpha == 16.0
assert lora_weights["test"].rank == 8
assert lora_weights["test"].scale == 2.0
def test_file_not_found(self):
from mlx_video.lora.loader import load_lora_weights
with pytest.raises(FileNotFoundError):
load_lora_weights(Path("/nonexistent/lora.safetensors"))
class TestWanKeyNormalization:
"""Test Wan2.2 LoRA key normalization."""
def _wan_model_keys(self):
"""Simulate typical Wan2.2 MLX model weight keys."""
keys = set()
for i in range(2):
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
keys.add(f"blocks.{i}.{layer}.weight")
keys.add(f"blocks.{i}.ffn.fc1.weight")
keys.add(f"blocks.{i}.ffn.fc2.weight")
keys.add("text_embedding_0.weight")
keys.add("text_embedding_1.weight")
keys.add("time_embedding_0.weight")
keys.add("time_embedding_1.weight")
keys.add("time_projection.weight")
keys.add("patch_embedding_proj.weight")
return keys
def test_direct_match(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
def test_strip_diffusion_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("diffusion_model.blocks.0.self_attn.q", keys)
assert result == "blocks.0.self_attn.q"
def test_strip_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
assert result == "blocks.0.self_attn.k"
def test_ffn_key_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.ffn.0", keys) == "blocks.0.ffn.fc1"
assert _normalize_wan_lora_key("blocks.0.ffn.2", keys) == "blocks.0.ffn.fc2"
def test_text_embedding_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("text_embedding.0", keys) == "text_embedding_0"
assert _normalize_wan_lora_key("text_embedding.2", keys) == "text_embedding_1"
def test_time_embedding_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("time_embedding.0", keys) == "time_embedding_0"
assert _normalize_wan_lora_key("time_embedding.2", keys) == "time_embedding_1"
def test_time_projection_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("time_projection.1", keys) == "time_projection"
def test_patch_embedding_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
def test_combined_prefix_and_ffn(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("diffusion_model.blocks.1.ffn.0", keys)
assert result == "blocks.1.ffn.fc1"
class TestApplyLoRA:
"""Test LoRA delta application to weights."""
def test_preserves_bfloat16_dtype(self):
"""LoRA delta must not promote bfloat16 weights to float32."""
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.ones((8, 4), dtype=mx.bfloat16)
# LoRA weights in float32 (typical when loaded from safetensors)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
def test_preserves_float16_dtype(self):
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.ones((8, 4), dtype=mx.float16)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
def test_apply_single_lora(self):
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.ones((8, 4))
lora_a = mx.ones((2, 4)) * 0.1
lora_b = mx.ones((8, 2)) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
result = apply_lora_to_linear(original, [(w, 1.0)])
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
expected = original + 0.02 * mx.ones((8, 4))
assert mx.allclose(result, expected, atol=1e-6).item()
def test_apply_multiple_loras(self):
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.zeros((8, 4))
w1 = LoRAWeights(
lora_A=mx.ones((2, 4)),
lora_B=mx.ones((8, 2)),
rank=2, alpha=2.0, module_name="a",
)
w2 = LoRAWeights(
lora_A=mx.ones((2, 4)) * 2,
lora_B=mx.ones((8, 2)) * 2,
rank=2, alpha=4.0, module_name="b",
)
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
# w2 delta: 2.0 * 0.5 * (2*ones(8,2) @ 2*ones(2,4)) = 1.0 * 8*ones(8,4) = 8
delta1 = mx.ones((8, 4)) * 2.0
delta2 = mx.ones((8, 4)) * 8.0
expected = delta1 + delta2
assert mx.allclose(result, expected, atol=1e-5).item()
def test_apply_loras_to_weights_dict(self):
from mlx_video.lora.apply import apply_loras_to_weights
from mlx_video.lora.types import LoRAWeights
model_weights = {
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
"blocks.0.ffn.fc1.weight": mx.ones((256, 64)),
}
w = LoRAWeights(
lora_A=mx.ones((4, 64)) * 0.01,
lora_B=mx.ones((128, 4)) * 0.01,
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
)
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
result = apply_loras_to_weights(model_weights, module_to_loras)
# Only q should be modified
assert not mx.array_equal(
result["blocks.0.self_attn.q.weight"],
model_weights["blocks.0.self_attn.q.weight"],
).item()
assert mx.array_equal(
result["blocks.0.self_attn.k.weight"],
model_weights["blocks.0.self_attn.k.weight"],
).item()
class TestEndToEnd:
"""End-to-end LoRA loading and application."""
def test_load_and_apply_loras(self):
from mlx_video.convert_wan import load_and_apply_loras
with tempfile.TemporaryDirectory() as tmp:
# Create mock LoRA safetensors
rank = 4
weights = {
"blocks.0.self_attn.q.lora_A.weight": mx.random.normal((rank, 64)),
"blocks.0.self_attn.q.lora_B.weight": mx.random.normal((128, rank)),
}
lora_path = Path(tmp) / "test.safetensors"
mx.save_safetensors(str(lora_path), weights)
# Create mock model weights
model_weights = {
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
}
result = load_and_apply_loras(
model_weights, [(str(lora_path), 1.0)]
)
# q weight should be modified, k unchanged
assert not mx.array_equal(
result["blocks.0.self_attn.q.weight"],
model_weights["blocks.0.self_attn.q.weight"],
).item()
assert mx.array_equal(
result["blocks.0.self_attn.k.weight"],
model_weights["blocks.0.self_attn.k.weight"],
).item()

View File

@@ -868,4 +868,84 @@ class TestVAEEncoderTemporalOrder:
assert out_wrong.shape[1] == 2 assert out_wrong.shape[1] == 2
# ---------------------------------------------------------------------------
# VAE Encode → Decode Round-Trip Tests
# ---------------------------------------------------------------------------
class TestVAE21RoundTrip:
"""Encode→decode round-trip for Wan 2.1 VAE (channels-first)."""
def test_encode_decode_shape_and_values(self):
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
z_dim = 4
dim = 8
# No temporal up/downsampling to keep the test simple
enc = Encoder3d(
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
)
dec = Decoder3d(
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
)
mx.eval(enc.parameters(), dec.parameters())
# [B=1, C=3, T=1, H=8, W=8]
x = mx.random.normal((1, 3, 1, 8, 8)) * 0.5
z = enc(x)
mx.eval(z)
# 3 spatial downsamples (÷8): H=1, W=1
assert z.shape == (1, z_dim, 1, 1, 1)
x_hat = dec(z)
mx.eval(x_hat)
# 3 spatial upsamples (×8): should recover original shape
assert x_hat.shape == x.shape
out_np = np.array(x_hat)
assert np.all(np.isfinite(out_np))
assert np.abs(out_np).max() < 1000
class TestVAE22RoundTrip:
"""Encode→decode round-trip for Wan 2.2 VAE (channels-last)."""
def test_encode_decode_shape_and_values(self):
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
from mlx_video.models.wan.vae22 import (
Wan22VAEDecoder,
Wan22VAEEncoder,
denormalize_latents,
)
enc = Wan22VAEEncoder(z_dim=48, dim=16)
dec = Wan22VAEDecoder(z_dim=48, dec_dim=8)
mx.eval(enc.parameters(), dec.parameters())
# [B=1, T=1, H=32, W=32, C=3]
img = mx.random.normal((1, 1, 32, 32, 3)) * 0.5
z_norm = enc(img)
mx.eval(z_norm)
# patchify(÷2) + 3 spatial downsamples(÷8) = ÷16
assert z_norm.shape == (1, 1, 2, 2, 48)
z = denormalize_latents(z_norm)
out = dec(z)
mx.eval(out)
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB
out_np = np.array(out)
assert np.all(np.isfinite(out_np))
assert out_np.min() >= -1.0 - 1e-6
assert out_np.max() <= 1.0 + 1e-6