From 849cc45d8442e518bd0237916814b400db32b857 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 28 Feb 2026 14:11:13 +0100 Subject: [PATCH] feat(wan): Add LoRA with improved quantization pipeline --- README.md | 18 +- mlx_video/convert_wan.py | 235 ++++++++++++++-- mlx_video/generate_wan.py | 250 ++++++++++++----- mlx_video/lora/__init__.py | 25 ++ mlx_video/lora/apply.py | 393 +++++++++++++++++++++++++++ mlx_video/lora/loader.py | 122 +++++++++ mlx_video/lora/types.py | 74 +++++ mlx_video/models/wan/attention.py | 21 +- mlx_video/models/wan/loading.py | 26 +- mlx_video/models/wan/model.py | 14 +- mlx_video/models/wan/text_encoder.py | 2 +- mlx_video/models/wan/transformer.py | 8 +- mlx_video/models/wan/vae22.py | 10 + tests/test_wan_convert.py | 72 +++++ tests/test_wan_i2v.py | 279 +++++++++++++++++++ tests/test_wan_lora.py | 334 +++++++++++++++++++++++ tests/test_wan_vae.py | 80 ++++++ 17 files changed, 1852 insertions(+), 111 deletions(-) create mode 100644 mlx_video/lora/__init__.py create mode 100644 mlx_video/lora/apply.py create mode 100644 mlx_video/lora/loader.py create mode 100644 mlx_video/lora/types.py create mode 100644 tests/test_wan_lora.py diff --git a/README.md b/README.md index 751f48a..f3a05a6 100644 --- a/README.md +++ b/README.md @@ -82,15 +82,15 @@ python -m mlx_video.generate \ Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. 
They share the same model architecture — the difference is in the inference pipeline: -| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | -|---|--------|--------|--------| -| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | -| **Pipeline** | Single model | Dual model | Dual model | -| **Sizes** | 1.3B, 14B | 14B | 14B | -| **Steps** | 50 | 40 | 40 | -| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | -| **Shift** | 5.0 | 12.0 | 5.0 | -| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | +| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B | +|---|--------|--------|--------|--------| +| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video | +| **Pipeline** | Single model | Dual model | Dual model | Single model | +| **Sizes** | 1.3B, 14B | 14B | 14B | 5B | +| **Steps** | 50 | 40 | 40 | 40 | +| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) | +| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 | +| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) | ### Step 1: Download Weights diff --git a/mlx_video/convert_wan.py b/mlx_video/convert_wan.py index f3f9037..a7930c1 100644 --- a/mlx_video/convert_wan.py +++ b/mlx_video/convert_wan.py @@ -1,13 +1,16 @@ """Weight conversion for Wan2.2 models (PyTorch -> MLX).""" +import gc import logging from pathlib import Path -from typing import Dict +from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.utils import numpy as np +logger = logging.getLogger(__name__) + def load_torch_weights(path: str) -> Dict[str, mx.array]: """Load PyTorch .pth weights and convert to MLX arrays. @@ -88,6 +91,7 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, etc. 
""" sanitized = {} + consumed = set() for key, value in weights.items(): new_key = key @@ -99,36 +103,43 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, value = value.reshape(value.shape[0], -1) new_key = "patch_embedding_proj.weight" sanitized[new_key] = value + consumed.add(key) continue if key == "patch_embedding.bias": new_key = "patch_embedding_proj.bias" sanitized[new_key] = value + consumed.add(key) continue # Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear if key.startswith("text_embedding.0."): new_key = key.replace("text_embedding.0.", "text_embedding_0.") sanitized[new_key] = value + consumed.add(key) continue if key.startswith("text_embedding.2."): new_key = key.replace("text_embedding.2.", "text_embedding_1.") sanitized[new_key] = value + consumed.add(key) continue # Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear if key.startswith("time_embedding.0."): new_key = key.replace("time_embedding.0.", "time_embedding_0.") sanitized[new_key] = value + consumed.add(key) continue if key.startswith("time_embedding.2."): new_key = key.replace("time_embedding.2.", "time_embedding_1.") sanitized[new_key] = value + consumed.add(key) continue # Time projection Sequential: 0=SiLU(no params), 1=Linear if key.startswith("time_projection.1."): new_key = key.replace("time_projection.1.", "time_projection.") sanitized[new_key] = value + consumed.add(key) continue # FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2 @@ -137,9 +148,15 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, # Skip the freqs buffer (we compute it in the model) if key == "freqs": + consumed.add(key) continue sanitized[new_key] = value + consumed.add(key) + + unconsumed = set(weights.keys()) - consumed + if unconsumed: + logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed)) return sanitized @@ -171,6 +188,7 @@ def sanitize_wan_t5_weights(weights: Dict[str, 
mx.array]) -> Dict[str, mx.array] norm.weight """ sanitized = {} + consumed = set() for key, value in weights.items(): new_key = key @@ -179,6 +197,11 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array] new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.") sanitized[new_key] = value + consumed.add(key) + + unconsumed = set(weights.keys()) - consumed + if unconsumed: + logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed)) return sanitized @@ -189,6 +212,7 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array Handles Conv3d and Conv2d weight transpositions for MLX format. """ sanitized = {} + consumed = set() for key, value in weights.items(): new_key = key @@ -206,10 +230,78 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array # Need to adapt naming for our simplified structure sanitized[new_key] = value + consumed.add(key) + + unconsumed = set(weights.keys()) - consumed + if unconsumed: + logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed)) return sanitized +def _load_lora_configs( + lora_configs: List[Tuple[str, float]], +) -> Dict[str, list]: + """Load LoRA weights from config tuples, returning module_to_loras dict. + + Shared between weight-merging and runtime-wrapping paths. 
+ """ + from mlx_video.lora import LoRAConfig, load_multiple_loras + from mlx_video.utils import Colors + + print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}") + + configs = [] + for lora_path, strength in lora_configs: + try: + config = LoRAConfig(path=lora_path, strength=strength) + configs.append(config) + print(f" - {Path(lora_path).name} (strength: {strength})") + except Exception as e: + print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}") + raise + + module_to_loras = load_multiple_loras(configs) + + if not module_to_loras: + print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}") + + return module_to_loras + + +def load_and_apply_loras( + model_weights: Dict[str, mx.array], + lora_configs: Optional[List[Tuple[str, float]]] = None, + verbose: bool = False, + quantization_bits: int = 0, +) -> Dict[str, mx.array]: + """Load and apply LoRA weights to model weights by merging into weight dict. + + For non-quantized (bf16) models. For quantized models, use apply_loras_to_model(). 
+ """ + from mlx_video.lora import apply_loras_to_weights + from mlx_video.utils import Colors + + if not lora_configs: + return model_weights + + module_to_loras = _load_lora_configs(lora_configs) + if not module_to_loras: + return model_weights + + print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}") + if verbose: + print(f" Model has {len(model_weights)} weight keys") + + modified_weights = apply_loras_to_weights( + model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits + ) + + print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}") + + return modified_weights + + def convert_wan_checkpoint( checkpoint_dir: str, output_dir: str, @@ -464,30 +556,45 @@ def _quantize_saved_model( is_dual: bool, bits: int, group_size: int, + source_dir: Path = None, ): - """Load saved bf16 model, quantize, and re-save.""" + """Load saved bf16 model, quantize, and re-save. + + Args: + output_dir: Directory to write quantized weights to. + config: WanModelConfig for creating the model. + is_dual: Whether this is a dual-expert model. + bits: Quantization bits. + group_size: Quantization group size. + source_dir: Directory to read bf16 weights from. Defaults to output_dir. 
+ """ import json import mlx.nn as nn from mlx_video.models.wan.model import WanModel - model_files = [] + if source_dir is None: + source_dir = output_dir + + model_names = [] if is_dual: for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]: - p = output_dir / name - if p.exists(): - model_files.append(p) + if (source_dir / name).exists(): + model_names.append(name) else: - p = output_dir / "model.safetensors" - if p.exists(): - model_files.append(p) + if (source_dir / "model.safetensors").exists(): + model_names.append("model.safetensors") - for model_path in model_files: - print(f" Quantizing {model_path.name}...") + for name in model_names: + print(f" Quantizing {name}...") model = WanModel(config) - weights = mx.load(str(model_path)) + weights = mx.load(str(source_dir / name)) model.load_weights(list(weights.items()), strict=False) + mx.eval(model.parameters()) + del weights + gc.collect() + mx.clear_cache() # Apply quantization to targeted layers nn.quantize( @@ -499,10 +606,30 @@ def _quantize_saved_model( # Save quantized weights weights_dict = dict(mlx.utils.tree_flatten(model.parameters())) - mx.save_safetensors(str(model_path), weights_dict) + + # Validate: check for NaN/Inf in bias tensors (corruption canary) + bad_keys = [] + for k, v in weights_dict.items(): + if k.endswith(".bias") and not k.endswith(".biases"): + mx.eval(v) + if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item(): + bad_keys.append(k) + if bad_keys: + raise RuntimeError( + f"Quantization produced corrupted weights in {name}: " + f"{len(bad_keys)} bias tensors contain NaN/Inf " + f"(e.g. {bad_keys[0]}). Try re-running with more available memory." 
+ ) + + mx.save_safetensors(str(output_dir / name), weights_dict) n_quantized = sum(1 for k in weights_dict if ".scales" in k) print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved") + # Free model before processing next file + del model, weights_dict + gc.collect() + mx.clear_cache() + # Update config.json with quantization metadata config_path = output_dir / "config.json" with open(config_path) as f: @@ -516,6 +643,68 @@ def _quantize_saved_model( print(f" Updated config.json with quantization metadata") +def quantize_mlx_model( + mlx_model_dir: str, + output_dir: str, + bits: int = 4, + group_size: int = 64, +): + """Quantize an already-converted MLX model (skips PyTorch conversion). + + Args: + mlx_model_dir: Path to existing MLX model directory (bf16/fp16). + output_dir: Path to output quantized model directory. + bits: Quantization bits (4 or 8). + group_size: Quantization group size (32, 64, or 128). + """ + import json + import shutil + + src = Path(mlx_model_dir) + dst = Path(output_dir) + + config_path = src / "config.json" + if not config_path.exists(): + raise FileNotFoundError(f"No config.json found in {src}") + + with open(config_path) as f: + cfg = json.load(f) + + if cfg.get("quantization"): + raise ValueError( + f"Model at {src} is already quantized " + f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source." 
+ ) + + # Detect dual vs single expert + is_dual = (src / "low_noise_model.safetensors").exists() and ( + src / "high_noise_model.safetensors" + ).exists() + + # Build model config + from mlx_video.models.wan.config import WanModelConfig + + config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__} + for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"): + if key in config_dict and isinstance(config_dict[key], list): + config_dict[key] = tuple(config_dict[key]) + config = WanModelConfig(**config_dict) + + # Copy non-transformer files to output dir (skip large model weights) + transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"} + if dst.resolve() != src.resolve(): + dst.mkdir(parents=True, exist_ok=True) + for f in src.iterdir(): + if f.is_file() and f.name not in transformer_files: + shutil.copy2(f, dst / f.name) + print(f"Copied non-transformer files from {src} to {dst}") + + print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...") + _quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src) + + print(f"\nQuantization complete! 
Output: {dst}") + + if __name__ == "__main__": import argparse @@ -551,6 +740,11 @@ if __name__ == "__main__": action="store_true", help="Quantize transformer weights for faster inference", ) + parser.add_argument( + "--quantize-only", + action="store_true", + help="Quantize an already-converted MLX model (skips PyTorch conversion)", + ) parser.add_argument( "--bits", type=int, @@ -566,7 +760,14 @@ if __name__ == "__main__": help="Quantization group size (default: 64)", ) args = parser.parse_args() - convert_wan_checkpoint( - args.checkpoint_dir, args.output_dir, args.dtype, args.model_version, - quantize=args.quantize, bits=args.bits, group_size=args.group_size, - ) + + if args.quantize_only: + quantize_mlx_model( + args.checkpoint_dir, args.output_dir, + bits=args.bits, group_size=args.group_size, + ) + else: + convert_wan_checkpoint( + args.checkpoint_dir, args.output_dir, args.dtype, args.model_version, + quantize=args.quantize, bits=args.bits, group_size=args.group_size, + ) diff --git a/mlx_video/generate_wan.py b/mlx_video/generate_wan.py index 1bd4fe7..a1875d5 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -43,6 +43,10 @@ def generate_video( seed: int = -1, output_path: str = "output.mp4", scheduler: str = "unipc", + loras: list | None = None, + loras_high: list | None = None, + loras_low: list | None = None, + ): """Generate video using Wan pipeline (supports T2V and I2V). 
@@ -60,6 +64,10 @@ def generate_video( seed: Random seed (-1 for random) output_path: Output video path scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default) + loras: Optional list of (path, strength) tuples applied to all models + loras_high: Optional list of (path, strength) tuples for high-noise model only + loras_low: Optional list of (path, strength) tuples for low-noise model only + """ import json @@ -156,6 +164,12 @@ def generate_video( parts = [float(x) for x in guide_scale.split(",")] guide_scale = tuple(parts) if len(parts) > 1 else parts[0] + # Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup) + if isinstance(guide_scale, tuple): + cfg_disabled = all(gs <= 1.0 for gs in guide_scale) + else: + cfg_disabled = guide_scale <= 1.0 + # Validate frame count assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}" @@ -181,6 +195,8 @@ def generate_video( print(f" Neg prompt: {neg_display}") print(f" Size: {width}x{height}, Frames: {num_frames}") print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}") + if cfg_disabled: + print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)") print(f"{Colors.RESET}") # Seed @@ -233,8 +249,12 @@ def generate_video( # Encode prompts print(f"{Colors.BLUE}Encoding text...{Colors.RESET}") context = encode_text(t5_encoder, tokenizer, prompt, config.text_len) - context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len) - mx.eval(context, context_null) + if cfg_disabled: + context_null = None + mx.eval(context) + else: + context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len) + mx.eval(context, context_null) # Free T5 from memory del t5_encoder @@ -319,48 +339,78 @@ def generate_video( print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}") t2 = time.time() + # Merge per-model 
LoRAs with shared LoRAs + _loras_low = (loras or []) + (loras_low or []) or None + _loras_high = (loras or []) + (loras_high or []) or None + _loras_single = loras + if is_dual: low_noise_path = model_dir / "low_noise_model.safetensors" high_noise_path = model_dir / "high_noise_model.safetensors" - low_noise_model = load_wan_model(low_noise_path, config, quantization) - high_noise_model = load_wan_model(high_noise_path, config, quantization) + low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low) + high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high) else: - single_model = load_wan_model(model_dir / "model.safetensors", config, quantization) + single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single) print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}") # Precompute text embeddings once (avoids redundant MLP in every step) # Each model has its own text_embedding weights, so dual models need separate embeddings - if is_dual: - context_emb_low = low_noise_model.embed_text([context, context_null]) - context_emb_high = high_noise_model.embed_text([context, context_null]) - mx.eval(context_emb_low, context_emb_high) - context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0) - context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0) + if cfg_disabled: + # No CFG: only compute cond embeddings (B=1 forward pass, 2x faster) + if is_dual: + context_emb_low = low_noise_model.embed_text([context]) + context_emb_high = high_noise_model.embed_text([context]) + mx.eval(context_emb_low, context_emb_high) + context_cond_low = context_emb_low[0:1] + context_cond_high = context_emb_high[0:1] + else: + context_emb = single_model.embed_text([context]) + mx.eval(context_emb) + context_cond = context_emb[0:1] else: - context_emb = single_model.embed_text([context, context_null]) - 
mx.eval(context_emb) - context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0) + if is_dual: + context_emb_low = low_noise_model.embed_text([context, context_null]) + context_emb_high = high_noise_model.embed_text([context, context_null]) + mx.eval(context_emb_low, context_emb_high) + context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0) + context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0) + else: + context_emb = single_model.embed_text([context, context_null]) + mx.eval(context_emb) + context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0) # Precompute cross-attention K/V caches (constant across all steps) - if is_dual: - cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low) - cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high) - mx.eval(cross_kv_low, cross_kv_high) + if cfg_disabled: + if is_dual: + cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low) + cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high) + mx.eval(cross_kv_low, cross_kv_high) + else: + cross_kv = single_model.prepare_cross_kv(context_cond) + mx.eval(cross_kv) else: - cross_kv = single_model.prepare_cross_kv(context_cfg) - mx.eval(cross_kv) + if is_dual: + cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low) + cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high) + mx.eval(cross_kv_low, cross_kv_high) + else: + cross_kv = single_model.prepare_cross_kv(context_cfg) + mx.eval(cross_kv) # Precompute RoPE frequencies (grid sizes are constant across all steps) f_grid = t_latent // patch_size[0] h_grid = h_latent // patch_size[1] w_grid = w_latent // patch_size[2] - cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)] + if cfg_disabled: + rope_grid_sizes = [(f_grid, h_grid, w_grid)] + else: + rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)] if is_dual: - rope_cos_sin_low = 
low_noise_model.prepare_rope(cfg_grid_sizes) - rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes) + rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes) + rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes) mx.eval(rope_cos_sin_low, rope_cos_sin_high) else: - rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes) + rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes) mx.eval(rope_cos_sin) # Setup scheduler @@ -395,58 +445,86 @@ def generate_video( for i, t in enumerate(tqdm(range(steps), desc="Diffusion")): timestep_val = timestep_list[i] - # Select model, guide scale, cached K/V, and precomputed RoPE + # Select model, cached K/V, and precomputed RoPE if is_dual: if timestep_val >= boundary: model = high_noise_model - gs = guide_scale[1] kv = cross_kv_high rcs = rope_cos_sin_high else: model = low_noise_model - gs = guide_scale[0] kv = cross_kv_low rcs = rope_cos_sin_low else: model = single_model - gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] kv = cross_kv rcs = rope_cos_sin - # Build per-token timesteps for TI2V-5B (first-frame patches get t=0) - if is_i2v_mask_blend: - t_tokens = i2v_mask_tokens * timestep_val # [1, L] - # Pad to seq_len if needed - pad_len = seq_len - t_tokens.shape[1] - if pad_len > 0: - t_tokens = mx.concatenate( - [t_tokens, mx.full((1, pad_len), timestep_val)], axis=1 - ) - # Batch for CFG: both cond and uncond get same timesteps - t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L] + if cfg_disabled: + # No CFG: B=1 forward pass (2x faster than B=2 CFG batch) + if is_i2v_mask_blend: + t_tokens = i2v_mask_tokens * timestep_val + pad_len = seq_len - t_tokens.shape[1] + if pad_len > 0: + t_tokens = mx.concatenate( + [t_tokens, mx.full((1, pad_len), timestep_val)], axis=1 + ) + t_batch = t_tokens # [1, L] + else: + t_batch = mx.array([timestep_val]) + + y_arg = [y_i2v] if is_i2v_channel_concat else None + + if is_dual: + ctx = context_cond_high if 
timestep_val >= boundary else context_cond_low + else: + ctx = context_cond + preds = model( + [latents], + t=t_batch, + context=ctx, + seq_len=seq_len, + cross_kv_caches=kv, + y=y_arg, + rope_cos_sin=rcs, + ) + noise_pred = preds[0] + del preds else: - t_batch = mx.array([timestep_val, timestep_val]) + # CFG: batch cond + uncond into single B=2 forward pass + if is_dual: + gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0] + else: + gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] - # I2V-14B: pass y conditioning to model (same y for cond and uncond) - y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None + if is_i2v_mask_blend: + t_tokens = i2v_mask_tokens * timestep_val + pad_len = seq_len - t_tokens.shape[1] + if pad_len > 0: + t_tokens = mx.concatenate( + [t_tokens, mx.full((1, pad_len), timestep_val)], axis=1 + ) + t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) + else: + t_batch = mx.array([timestep_val, timestep_val]) - # CFG: batch cond + uncond into single B=2 forward pass - ctx = context_cfg if not is_dual else ( - context_cfg_high if timestep_val >= boundary else context_cfg_low - ) - preds = model( - [latents, latents], - t=t_batch, - context=ctx, - seq_len=seq_len, - cross_kv_caches=kv, - y=y_arg, - rope_cos_sin=rcs, - ) - noise_pred_cond, noise_pred_uncond = preds[0], preds[1] + y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None - # Classifier-free guidance + scheduler step - noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) + ctx = context_cfg if not is_dual else ( + context_cfg_high if timestep_val >= boundary else context_cfg_low + ) + preds = model( + [latents, latents], + t=t_batch, + context=ctx, + seq_len=seq_len, + cross_kv_caches=kv, + y=y_arg, + rope_cos_sin=rcs, + ) + noise_pred_cond, noise_pred_uncond = preds[0], preds[1] + noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) + del noise_pred_cond, noise_pred_uncond, preds 
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) @@ -455,7 +533,7 @@ def generate_video( latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents # Release temporaries before eval to free memory for graph execution - del noise_pred_cond, noise_pred_uncond, noise_pred, preds + del noise_pred mx.eval(latents) print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}") @@ -463,11 +541,19 @@ def generate_video( # Free transformer models and text embeddings if is_dual: del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high - del context_cfg_low, context_cfg_high + if cfg_disabled: + del context_cond_low, context_cond_high + else: + del context_cfg_low, context_cfg_high else: del single_model, cross_kv - del context_cfg - del model, kv, context, context_null + if cfg_disabled: + del context_cond + else: + del context_cfg + del model, kv, context + if context_null is not None: + del context_null gc.collect(); mx.clear_cache() # Load VAE and decode @@ -478,25 +564,36 @@ def generate_video( is_wan22_vae = config.vae_z_dim == 48 + # Warm-up: prepend a copy of the first latent frame to provide temporal + # context for the real first frame. Causal convolutions in the VAE decoder + # pad with zeros on the left, so the first few output frames have degraded + # quality (no temporal context). By duplicating the first latent, the real + # first frame sees its own features as left context instead of zeros. + # We trim the extra output frames after decoding. 
+ warmup_trim = vae_stride[0] # 4 frames per latent temporal position + latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1) + if is_wan22_vae: from mlx_video.models.wan.vae22 import denormalize_latents # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE) - z = latents.transpose(1, 2, 3, 0)[None] # [1, T, H, W, C] + z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C] z = denormalize_latents(z) video = vae(z) # [1, T', H', W', 3] mx.eval(video) print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") video = np.array(video[0]) # [T', H', W', 3] + video = video[warmup_trim:] # Trim warm-up frames video = (video + 1.0) / 2.0 video = np.clip(video * 255.0, 0, 255).astype(np.uint8) else: - video = vae.decode(latents[None]) # [1, 3, T, H, W] + video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W] mx.eval(video) print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") - video = np.array(video[0]) # [3, T, H, W] + video = np.array(video[0]) # [3, T', H, W] + video = video[:, warmup_trim:] # Trim warm-up frames (channels-first) video = (video + 1.0) / 2.0 video = np.clip(video * 255.0, 0, 255).astype(np.uint8) video = video.transpose(1, 2, 3, 0) # [T, H, W, 3] @@ -529,6 +626,19 @@ def main(): choices=["euler", "dpm++", "unipc"], help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)", ) + parser.add_argument( + "--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"), + help="Apply a LoRA to all models (repeatable). 
Format: --lora path.safetensors 0.8", + ) + parser.add_argument( + "--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"), + help="Apply a LoRA to high-noise model only (dual-model, repeatable)", + ) + parser.add_argument( + "--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"), + help="Apply a LoRA to low-noise model only (dual-model, repeatable)", + ) + args = parser.parse_args() # Parse guide scale @@ -542,6 +652,12 @@ def main(): if args.no_negative_prompt: neg_prompt = "" + # Parse LoRA configs: convert [path, strength_str] → (path, float) + def _parse_lora_args(lora_list): + if not lora_list: + return None + return [(path, float(strength)) for path, strength in lora_list] + generate_video( model_dir=args.model_dir, prompt=args.prompt, @@ -556,6 +672,10 @@ def main(): seed=args.seed, output_path=args.output_path, scheduler=args.scheduler, + loras=_parse_lora_args(args.lora), + loras_high=_parse_lora_args(args.lora_high), + loras_low=_parse_lora_args(args.lora_low), + ) diff --git a/mlx_video/lora/__init__.py b/mlx_video/lora/__init__.py new file mode 100644 index 0000000..4c0d81b --- /dev/null +++ b/mlx_video/lora/__init__.py @@ -0,0 +1,25 @@ +"""LoRA support for mlx-video.""" + +from mlx_video.lora.apply import ( + LoRALinear, + apply_lora_to_linear, + apply_loras_to_model, + apply_loras_to_weights, +) +from mlx_video.lora.loader import ( + load_lora_weights, + load_multiple_loras, +) +from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights + +__all__ = [ + "LoRAConfig", + "LoRAWeights", + "AppliedLoRA", + "load_lora_weights", + "load_multiple_loras", + "apply_lora_to_linear", + "apply_loras_to_weights", + "apply_loras_to_model", + "LoRALinear", +] diff --git a/mlx_video/lora/apply.py b/mlx_video/lora/apply.py new file mode 100644 index 0000000..97b694e --- /dev/null +++ b/mlx_video/lora/apply.py @@ -0,0 +1,393 @@ +"""Apply LoRA weights to model layers.""" + +from typing import Dict, List, Tuple + +import mlx.core 
as mx +import mlx.nn as nn + +from mlx_video.lora.types import LoRAWeights + + +def apply_lora_to_linear( + linear_weight: mx.array, + lora_weights_and_strengths: List[Tuple[LoRAWeights, float]], +) -> mx.array: + """Apply one or more LoRAs to a linear layer weight. + + Args: + linear_weight: Original weight matrix [out_features, in_features] + lora_weights_and_strengths: List of (LoRAWeights, strength) tuples + + Returns: + Modified weight with LoRA deltas applied (preserves original dtype) + """ + orig_dtype = linear_weight.dtype + modified_weight = linear_weight + + for weights, strength in lora_weights_and_strengths: + scale = weights.scale + # Compute delta in float32 for precision, then cast back to avoid + # promoting model weights (e.g. bfloat16 → float32 causes ~1.5x slowdown) + delta = (weights.lora_B.astype(mx.float32) @ weights.lora_A.astype(mx.float32)) * (scale * strength) + modified_weight = modified_weight + delta.astype(orig_dtype) + + return modified_weight + + +def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str: + """Normalize LoRA module name to match Wan2.2 MLX model weight keys. + + Handles: + - Stripping common prefixes (diffusion_model., model., etc.) + - FFN key mapping: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2 + - Embedding key mapping: text_embedding.0 → text_embedding_0, etc. 
+ - Time projection: time_projection.1 → time_projection + - Patch embedding: patch_embedding → patch_embedding_proj + + Args: + lora_key: Original LoRA module name + model_keys: Set of all model weight keys + + Returns: + Normalized key that matches model weights + """ + # Try the key as-is first + if f"{lora_key}.weight" in model_keys or lora_key in model_keys: + return lora_key + + # Common prefixes to strip + prefixes_to_strip = [ + "model.diffusion_model.", + "diffusion_model.", + "base_model.model.", + "model.", + ] + + candidates = [lora_key] + for prefix in prefixes_to_strip: + if lora_key.startswith(prefix): + candidates.append(lora_key[len(prefix):]) + + for candidate in candidates: + # Try as-is + if f"{candidate}.weight" in model_keys or candidate in model_keys: + return candidate + + # Apply Wan2.2 key transformations + transformed = candidate + + # FFN: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2 + transformed = transformed.replace(".ffn.0.", ".ffn.fc1.") + transformed = transformed.replace(".ffn.2.", ".ffn.fc2.") + if transformed.endswith(".ffn.0"): + transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1" + if transformed.endswith(".ffn.2"): + transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2" + + # Text embedding: text_embedding.0 → text_embedding_0 + transformed = transformed.replace("text_embedding.0.", "text_embedding_0.") + transformed = transformed.replace("text_embedding.2.", "text_embedding_1.") + if transformed.endswith("text_embedding.0"): + transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0" + if transformed.endswith("text_embedding.2"): + transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1" + + # Time embedding: time_embedding.0 → time_embedding_0 + transformed = transformed.replace("time_embedding.0.", "time_embedding_0.") + transformed = transformed.replace("time_embedding.2.", "time_embedding_1.") + if transformed.endswith("time_embedding.0"): + transformed = transformed[:-len("time_embedding.0")] 
+ "time_embedding_0" + if transformed.endswith("time_embedding.2"): + transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1" + + # Time projection: time_projection.1 → time_projection + transformed = transformed.replace("time_projection.1.", "time_projection.") + if transformed.endswith("time_projection.1"): + transformed = transformed[:-len("time_projection.1")] + "time_projection" + + # Patch embedding: patch_embedding → patch_embedding_proj + if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed: + transformed = transformed.replace("patch_embedding", "patch_embedding_proj") + + if f"{transformed}.weight" in model_keys or transformed in model_keys: + return transformed + + # Return best attempt with prefix stripped + for prefix in prefixes_to_strip: + if lora_key.startswith(prefix): + return lora_key[len(prefix):] + + return lora_key + + +# Also support LTX-style key normalization +def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str: + """Normalize LoRA module name to match LTX MLX model weight keys.""" + if f"{lora_key}.weight" in model_keys or lora_key in model_keys: + return lora_key + + prefixes_to_strip = [ + "model.diffusion_model.", + "diffusion_model.", + "model.", + ] + + for prefix in prefixes_to_strip: + if lora_key.startswith(prefix): + normalized = lora_key[len(prefix):] + + if f"{normalized}.weight" in model_keys or normalized in model_keys: + return normalized + + transformed = normalized + if transformed.endswith(".to_out.0"): + transformed = transformed[:-len(".to_out.0")] + ".to_out" + transformed = transformed.replace(".to_out.0.", ".to_out.") + transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.") + transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in") + transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.") + transformed = transformed.replace(".ff.net.2", ".ff.proj_out") + transformed = transformed.replace(".audio_ff.net.0.proj.", 
".audio_ff.proj_in.") + transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in") + transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") + transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out") + + if f"{transformed}.weight" in model_keys or transformed in model_keys: + return transformed + + # Try transformations on the original key + transformed = lora_key + if transformed.endswith(".to_out.0"): + transformed = transformed[:-len(".to_out.0")] + ".to_out" + transformed = transformed.replace(".to_out.0.", ".to_out.") + transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.") + transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in") + transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.") + transformed = transformed.replace(".ff.net.2", ".ff.proj_out") + + if f"{transformed}.weight" in model_keys or transformed in model_keys: + return transformed + + for prefix in prefixes_to_strip: + if lora_key.startswith(prefix): + return lora_key[len(prefix):] + + return lora_key + + +def _normalize_lora_key(lora_key: str, model_keys: set) -> str: + """Normalize LoRA module name to match model weight keys. + + Auto-detects whether to use Wan2.2 or LTX key normalization based + on the presence of architecture-specific keys in the model. + """ + # Detect model architecture from keys + is_wan = any("self_attn.q.weight" in k for k in model_keys) + + if is_wan: + return _normalize_wan_lora_key(lora_key, model_keys) + else: + return _normalize_ltx_lora_key(lora_key, model_keys) + + +def apply_loras_to_weights( + model_weights: Dict[str, mx.array], + module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]], + verbose: bool = False, + quantization_bits: int = 0, +) -> Dict[str, mx.array]: + """Apply LoRAs to model weights. 
+ + Args: + model_weights: Original model state dictionary + module_to_loras: Dictionary mapping module names to lists of + (LoRAWeights, strength) tuples + verbose: If True, print detailed debug information + quantization_bits: If >0, weights are quantized at this bit width. + Quantized layers are dequantized before LoRA application + and re-quantized after. + + Returns: + New state dictionary with LoRA-modified weights + """ + modified_weights = dict(model_weights) + model_keys = set(model_weights.keys()) + + applied_count = 0 + skipped_count = 0 + skipped_modules = [] + + for module_name, loras in module_to_loras.items(): + normalized_name = _normalize_lora_key(module_name, model_keys) + weight_key = f"{normalized_name}.weight" + + if weight_key not in modified_weights: + if normalized_name not in modified_weights: + skipped_count += 1 + skipped_modules.append(module_name) + if verbose and skipped_count <= 5: + print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND") + similar = [ + k + for k in list(model_keys)[:1000] + if normalized_name.split(".")[-1] in k + ][:3] + if similar: + print(f" Similar keys: {similar}") + continue + weight_key = normalized_name + + original_weight = modified_weights[weight_key] + + # Handle quantized weights: dequantize → apply delta → re-quantize + scales_key = f"{normalized_name}.scales" + biases_key = f"{normalized_name}.biases" + is_quantized = ( + original_weight.dtype == mx.uint32 + and scales_key in modified_weights + and biases_key in modified_weights + ) + + if is_quantized: + scales = modified_weights[scales_key] + biases = modified_weights[biases_key] + group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits) + dequantized = mx.dequantize( + original_weight, scales, biases, group_size=group_size, bits=quantization_bits + ) + modified = apply_lora_to_linear(dequantized, loras) + # Re-quantize with same parameters + new_w, new_scales, new_biases = mx.quantize(modified, 
group_size=group_size, bits=quantization_bits) + modified_weights[weight_key] = new_w + modified_weights[scales_key] = new_scales + modified_weights[biases_key] = new_biases + else: + modified_weights[weight_key] = apply_lora_to_linear(original_weight, loras) + + applied_count += 1 + + if applied_count > 0: + print(f" ✓ Applied to {applied_count} modules") + if skipped_count > 0: + print(f" ⚠ Skipped {skipped_count} incompatible modules") + + return modified_weights + + +class LoRALinear(nn.Module): + """Linear layer with on-the-fly LoRA application. + + Wraps nn.Linear or nn.QuantizedLinear, computing LoRA delta at runtime: + output = base_linear(x) + (x @ lora_A.T @ lora_B.T) * scale * strength + """ + + def __init__( + self, + linear: nn.Module, + lora_weights_and_strengths: List[Tuple[LoRAWeights, float]], + ): + super().__init__() + self.linear = linear + self.lora_weights_and_strengths = lora_weights_and_strengths + + def __call__(self, x: mx.array) -> mx.array: + output = self.linear(x) + for weights, strength in self.lora_weights_and_strengths: + scale = weights.scale + lora_out = x @ weights.lora_A.T @ weights.lora_B.T + output = output + (scale * strength * lora_out) + return output + + +def apply_loras_to_model( + model: nn.Module, + module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]], + verbose: bool = False, +) -> int: + """Apply LoRAs to a model by merging into weights. + + For QuantizedLinear layers: dequantizes to bf16, merges LoRA delta, and + replaces with a regular nn.Linear (no per-step overhead, no re-quantization + precision loss). Non-LoRA layers stay quantized. + + For nn.Linear layers: merges LoRA delta directly into the weight. 
+ + Args: + model: The model to apply LoRAs to + module_to_loras: Dictionary mapping module names to (LoRAWeights, strength) lists + verbose: Print debug info + + Returns: + Number of modules modified + """ + # Build a set of model module paths for key normalization + module_paths = set() + for name, _ in model.named_modules(): + module_paths.add(name) + module_paths.add(f"{name}.weight") + + # Map LoRA keys → model module paths + lora_to_module = {} + for lora_key in module_to_loras: + normalized = _normalize_lora_key(lora_key, module_paths) + if normalized.endswith(".weight"): + normalized = normalized[: -len(".weight")] + lora_to_module[lora_key] = normalized + + applied_count = 0 + dequant_count = 0 + skipped = [] + + for lora_key, loras in module_to_loras.items(): + module_path = lora_to_module[lora_key] + parts = module_path.split(".") + + # Traverse to the parent module + parent = model + try: + for part in parts[:-1]: + parent = getattr(parent, part) if not part.isdigit() else parent[int(part)] + leaf_name = parts[-1] + target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)] + except (AttributeError, IndexError, TypeError): + skipped.append(lora_key) + if verbose: + print(f" DEBUG: '{lora_key}' -> '{module_path}' -> module not found") + continue + + if isinstance(target, nn.QuantizedLinear): + # Dequantize → merge LoRA → replace with bf16 Linear + weight = mx.dequantize( + target.weight, target.scales, target.biases, + group_size=target.group_size, bits=target.bits, + ) + merged = apply_lora_to_linear(weight, loras) + new_linear = nn.Linear(merged.shape[1], merged.shape[0]) + new_linear.weight = merged + if "bias" in target: + new_linear.bias = target.bias + if leaf_name.isdigit(): + parent[int(leaf_name)] = new_linear + else: + setattr(parent, leaf_name, new_linear) + dequant_count += 1 + applied_count += 1 + elif isinstance(target, nn.Linear): + # Merge directly into weight + target.weight = 
apply_lora_to_linear(target.weight, loras) + applied_count += 1 + else: + skipped.append(lora_key) + if verbose: + print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear") + continue + + if applied_count > 0: + msg = f" ✓ Applied to {applied_count} modules" + if dequant_count > 0: + msg += f" ({dequant_count} dequantized to bf16)" + print(msg) + if skipped: + print(f" ⚠ Skipped {len(skipped)} incompatible modules") + + return applied_count diff --git a/mlx_video/lora/loader.py b/mlx_video/lora/loader.py new file mode 100644 index 0000000..adf11b1 --- /dev/null +++ b/mlx_video/lora/loader.py @@ -0,0 +1,122 @@ +"""LoRA weight loading utilities.""" + +import re +from pathlib import Path +from typing import Dict, List, Optional + +import mlx.core as mx + +from mlx_video.lora.types import LoRAConfig, LoRAWeights + + +def load_lora_weights(lora_path: Path) -> Dict[str, LoRAWeights]: + """Load LoRA weights from a safetensors file. + + Supports both key conventions: + - {module_name}.lora_A.weight / {module_name}.lora_B.weight + - {module_name}.lora_down.weight / {module_name}.lora_up.weight + + Args: + lora_path: Path to the LoRA safetensors file + + Returns: + Dictionary mapping module names to LoRAWeights objects + + Raises: + FileNotFoundError: If the LoRA file doesn't exist + ValueError: If the LoRA file format is invalid + """ + if not lora_path.exists(): + raise FileNotFoundError(f"LoRA file not found: {lora_path}") + + all_weights = mx.load(str(lora_path)) + + # Group weights by module name, handling both naming conventions + lora_weights = {} + module_names = set() + + for key in all_weights.keys(): + # Format 1: {module}.lora_A.weight / {module}.lora_B.weight + match = re.match(r"(.+)\.lora_([AB])\.weight$", key) + if match: + module_names.add(match.group(1)) + continue + # Format 2: {module}.lora_down.weight / {module}.lora_up.weight + match = re.match(r"(.+)\.lora_(down|up)\.weight$", key) + if match: + module_names.add(match.group(1)) + + for 
module_name in module_names: + # Try both key conventions + key_a = f"{module_name}.lora_A.weight" + key_b = f"{module_name}.lora_B.weight" + if key_a not in all_weights or key_b not in all_weights: + key_a = f"{module_name}.lora_down.weight" + key_b = f"{module_name}.lora_up.weight" + if key_a not in all_weights or key_b not in all_weights: + continue + + lora_a = all_weights[key_a] + lora_b = all_weights[key_b] + + if lora_a.ndim != 2 or lora_b.ndim != 2: + raise ValueError( + f"Invalid LoRA shape for {module_name}: " + f"lora_A={lora_a.shape}, lora_B={lora_b.shape}" + ) + + rank = lora_a.shape[0] + if lora_b.shape[1] != rank: + raise ValueError( + f"LoRA rank mismatch for {module_name}: " + f"lora_A rank={rank}, lora_B rank={lora_b.shape[1]}" + ) + + # Check for per-module alpha stored as a scalar tensor + alpha_key = f"{module_name}.alpha" + if alpha_key in all_weights: + alpha = float(all_weights[alpha_key].item()) + else: + alpha = float(rank) + + lora_weights[module_name] = LoRAWeights( + lora_A=lora_a, + lora_B=lora_b, + rank=rank, + alpha=alpha, + module_name=module_name, + ) + + if not lora_weights: + raise ValueError(f"No valid LoRA weights found in {lora_path}") + + return lora_weights + + +def load_multiple_loras( + configs: List[LoRAConfig], +) -> Dict[str, List[tuple]]: + """Load multiple LoRA configurations. + + Args: + configs: List of LoRAConfig objects + + Returns: + Dictionary mapping module names to lists of (LoRAWeights, strength) tuples. 
+ """ + module_to_loras: Dict[str, list] = {} + + for config in configs: + lora_weights = load_lora_weights(config.path) + + for module_name, weights in lora_weights.items(): + if config.target_modules is not None: + if module_name not in config.target_modules: + continue + + if module_name not in module_to_loras: + module_to_loras[module_name] = [] + + module_to_loras[module_name].append((weights, config.strength)) + + return module_to_loras diff --git a/mlx_video/lora/types.py b/mlx_video/lora/types.py new file mode 100644 index 0000000..c1aa3cc --- /dev/null +++ b/mlx_video/lora/types.py @@ -0,0 +1,74 @@ +"""Data structures for LoRA support.""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import mlx.core as mx + + +@dataclass +class LoRAWeights: + """Container for LoRA weight matrices. + + Attributes: + lora_A: Low-rank matrix A of shape [rank, in_features] + lora_B: Low-rank matrix B of shape [out_features, rank] + rank: Rank of the LoRA decomposition + alpha: LoRA scaling parameter (default: rank) + module_name: Target module name in the model + """ + + lora_A: mx.array + lora_B: mx.array + rank: int + alpha: float + module_name: str + + @property + def scale(self) -> float: + """Compute the scale factor: alpha / rank.""" + return self.alpha / self.rank + + +@dataclass +class LoRAConfig: + """Configuration for a single LoRA. + + Attributes: + path: Path to the LoRA safetensors file + strength: Strength/weight to apply this LoRA (typically 0.0-2.0) + target_modules: Optional list of module names to apply LoRA to. + If None, applies to all available modules in the LoRA. 
+ """ + + path: Path + strength: float = 1.0 + target_modules: Optional[list[str]] = None + + def __post_init__(self): + """Validate and normalize the configuration.""" + self.path = Path(self.path) + if not self.path.exists(): + raise FileNotFoundError(f"LoRA file not found: {self.path}") + if self.strength < 0: + raise ValueError(f"LoRA strength must be non-negative, got {self.strength}") + + +@dataclass +class AppliedLoRA: + """Represents a LoRA applied to a specific module. + + Attributes: + weights: The LoRA weight matrices + strength: Application strength for this LoRA + """ + + weights: LoRAWeights + strength: float + + def compute_delta(self) -> mx.array: + """Compute the weight delta: strength * scale * (lora_B @ lora_A).""" + scale = self.weights.scale + delta = self.weights.lora_B @ self.weights.lora_A + return scale * self.strength * delta diff --git a/mlx_video/models/wan/attention.py b/mlx_video/models/wan/attention.py index e3fe24a..b0a6f2f 100644 --- a/mlx_video/models/wan/attention.py +++ b/mlx_video/models/wan/attention.py @@ -4,6 +4,15 @@ import mlx.nn as nn from .rope import rope_apply +def _linear_dtype(layer) -> mx.Dtype: + """Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers.""" + # Unwrap LoRA wrapper to get the underlying linear layer + inner = getattr(layer, "linear", layer) + if isinstance(inner, nn.QuantizedLinear): + return inner.scales.dtype + return inner.weight.dtype + + class WanRMSNorm(nn.Module): """RMS normalization with learnable scale.""" @@ -73,8 +82,8 @@ class WanSelfAttention(nn.Module): b, s, _ = x.shape n, d = self.num_heads, self.head_dim - # Cast to weight dtype for efficient matmul (bfloat16 matching official autocast) - w_dtype = self.q.weight.dtype + # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast) + w_dtype = _linear_dtype(self.q) x_w = x.astype(w_dtype) q = self.q(x_w) @@ -154,8 +163,8 @@ class WanCrossAttention(nn.Module): """ b = context.shape[0] 
n, d = self.num_heads, self.head_dim - # Cast to weight dtype for efficient matmul - w_dtype = self.k.weight.dtype + # Cast to compute dtype for efficient matmul + w_dtype = _linear_dtype(self.k) ctx = context.astype(w_dtype) k = self.k(ctx) if self.norm_k is not None: @@ -174,8 +183,8 @@ class WanCrossAttention(nn.Module): b = x.shape[0] n, d = self.num_heads, self.head_dim - # Cast to weight dtype for efficient matmul (bfloat16 matching official autocast) - w_dtype = self.q.weight.dtype + # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast) + w_dtype = _linear_dtype(self.q) q = self.q(x.astype(w_dtype)) if self.norm_q is not None: q = self.norm_q(q) diff --git a/mlx_video/models/wan/loading.py b/mlx_video/models/wan/loading.py index 4ef795b..35e3d12 100644 --- a/mlx_video/models/wan/loading.py +++ b/mlx_video/models/wan/loading.py @@ -6,14 +6,15 @@ import mlx.core as mx import mlx.nn as nn -def load_wan_model(model_path: Path, config, quantization: dict | None = None): - """Load and initialize WanModel, with optional quantization support. +def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None): + """Load and initialize WanModel, with optional quantization and LoRA support. Args: model_path: Path to model safetensors file config: WanModelConfig quantization: Optional dict with 'bits' and 'group_size' keys. If provided, creates QuantizedLinear stubs before loading. + loras: Optional list of (lora_path, strength) tuples to apply. """ from mlx_video.models.wan.model import WanModel @@ -30,6 +31,27 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None): ) weights = mx.load(str(model_path)) + + # Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16 + if loras: + if quantization: + # Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear. + # Non-LoRA layers stay 4-bit. Zero per-step overhead. 
+ from mlx_video.convert_wan import _load_lora_configs + from mlx_video.lora import apply_loras_to_model + + model.load_weights(list(weights.items()), strict=False) + mx.eval(model.parameters()) + module_to_loras = _load_lora_configs(loras) + apply_loras_to_model(model, module_to_loras) + mx.eval(model.parameters()) + return model + else: + # Weight merging: fold LoRA into bf16 weights before loading + from mlx_video.convert_wan import load_and_apply_loras + + weights = load_and_apply_loras(dict(weights), loras) + model.load_weights(list(weights.items()), strict=False) mx.eval(model.parameters()) return model diff --git a/mlx_video/models/wan/model.py b/mlx_video/models/wan/model.py index a196c05..e6a3a40 100644 --- a/mlx_video/models/wan/model.py +++ b/mlx_video/models/wan/model.py @@ -4,7 +4,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .attention import WanLayerNorm +from .attention import WanLayerNorm, _linear_dtype from .config import WanModelConfig from .rope import rope_params, rope_precompute_cos_sin from .transformer import WanAttentionBlock @@ -54,7 +54,7 @@ class Head(nn.Module): e1 = mod[:, :, 1, :] # [B, L_e, dim] scale x_norm = self.norm(x) x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32 - return self.head(x_mod.astype(self.head.weight.dtype)) + return self.head(x_mod.astype(_linear_dtype(self.head))) class WanModel(nn.Module): @@ -79,7 +79,7 @@ class WanModel(nn.Module): # Text embedding MLP self.text_embedding_0 = nn.Linear(config.text_dim, dim) - self.text_embedding_act = nn.GELU(approx="precise") + self.text_embedding_act = nn.GELU(approx="tanh") self.text_embedding_1 = nn.Linear(dim, dim) # Time embedding MLP @@ -149,7 +149,7 @@ class WanModel(nn.Module): # Project and cast to model dtype to prevent float32 cascade from input latents patches = self.patch_embedding_proj(x) # [L, dim] - patches = patches.astype(self.patch_embedding_proj.weight.dtype) + patches = 
patches.astype(_linear_dtype(self.patch_embedding_proj)) patches = patches[None, :, :] # [1, L, dim] return patches, (f_out, h_out, w_out) @@ -186,7 +186,7 @@ class WanModel(nn.Module): Returns: Embedded context [B, text_len, dim] in model dtype """ - model_dtype = self.patch_embedding_proj.weight.dtype + model_dtype = _linear_dtype(self.patch_embedding_proj) context_padded = [] for ctx in context: pad_len = self.text_len - ctx.shape[0] @@ -231,7 +231,7 @@ class WanModel(nn.Module): Returns: (cos_f, sin_f) precomputed frequency tensors """ - w_dtype = self.patch_embedding_proj.weight.dtype + w_dtype = _linear_dtype(self.patch_embedding_proj) return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype) def __call__( @@ -348,7 +348,7 @@ class WanModel(nn.Module): # Pre-compute attention mask from seq_lens (constant across all blocks) attn_mask = None - w_dtype = self.patch_embedding_proj.weight.dtype + w_dtype = _linear_dtype(self.patch_embedding_proj) if any(sl < seq_len for sl in seq_lens_list): attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype) for i, sl in enumerate(seq_lens_list): diff --git a/mlx_video/models/wan/text_encoder.py b/mlx_video/models/wan/text_encoder.py index d325ed5..b81a072 100644 --- a/mlx_video/models/wan/text_encoder.py +++ b/mlx_video/models/wan/text_encoder.py @@ -146,7 +146,7 @@ class T5FeedForward(nn.Module): self.dim = dim self.dim_ffn = dim_ffn self.gate_proj = nn.Linear(dim, dim_ffn, bias=False) - self.gate_act = nn.GELU(approx="precise") + self.gate_act = nn.GELU(approx="tanh") self.fc1 = nn.Linear(dim, dim_ffn, bias=False) self.fc2 = nn.Linear(dim_ffn, dim, bias=False) diff --git a/mlx_video/models/wan/transformer.py b/mlx_video/models/wan/transformer.py index c85c90e..857bcae 100644 --- a/mlx_video/models/wan/transformer.py +++ b/mlx_video/models/wan/transformer.py @@ -1,7 +1,7 @@ import mlx.core as mx import mlx.nn as nn -from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention +from 
.attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype class WanAttentionBlock(nn.Module): @@ -84,10 +84,10 @@ class WanFFN(nn.Module): def __init__(self, dim: int, ffn_dim: int): super().__init__() self.fc1 = nn.Linear(dim, ffn_dim) - self.act = nn.GELU(approx="precise") + self.act = nn.GELU(approx="tanh") self.fc2 = nn.Linear(ffn_dim, dim) def __call__(self, x: mx.array) -> mx.array: - # Cast to weight dtype for efficient matmul (bfloat16 matching official autocast) - x_w = x.astype(self.fc1.weight.dtype) + # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast) + x_w = x.astype(_linear_dtype(self.fc1)) return self.fc2(self.act(self.fc1(x_w))) diff --git a/mlx_video/models/wan/vae22.py b/mlx_video/models/wan/vae22.py index 36fe8e4..8c31d3b 100644 --- a/mlx_video/models/wan/vae22.py +++ b/mlx_video/models/wan/vae22.py @@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format conversion (channels-first → channels-last) is needed. """ +import logging import math import mlx.core as mx import mlx.nn as nn import numpy as np +logger = logging.getLogger(__name__) + CACHE_T = 2 # Per-channel normalization for z_dim=48 latent space @@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> Maps PyTorch nn.Sequential indices to our named layers. 
""" sanitized = {} + consumed = set() for key, value in weights.items(): # Skip encoder and conv1 unless requested if not include_encoder: if key.startswith("encoder.") or key.startswith("conv1."): + consumed.add(key) continue new_key = key @@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> value = mx.array(np.array(value).squeeze()) sanitized[new_key] = value + consumed.add(key) + + unconsumed = set(weights.keys()) - consumed + if unconsumed: + logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed)) return sanitized diff --git a/tests/test_wan_convert.py b/tests/test_wan_convert.py index 8b1213c..81630ce 100644 --- a/tests/test_wan_convert.py +++ b/tests/test_wan_convert.py @@ -1,5 +1,7 @@ """Tests for Wan weight conversion utilities.""" +import logging + import mlx.core as mx import numpy as np import pytest @@ -94,6 +96,27 @@ class TestSanitizeTransformerWeights: for key in weights: assert key in out + def test_no_unconsumed_keys(self, caplog): + from mlx_video.convert_wan import sanitize_wan_transformer_weights + weights = { + "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)), + "patch_embedding.bias": mx.random.normal((5120,)), + "text_embedding.0.weight": mx.zeros((64, 32)), + "text_embedding.2.weight": mx.zeros((64, 64)), + "time_embedding.0.weight": mx.zeros((64, 32)), + "time_embedding.2.weight": mx.zeros((64, 64)), + "time_projection.1.weight": mx.zeros((384, 64)), + "blocks.0.ffn.0.weight": mx.zeros((128, 64)), + "blocks.0.ffn.2.weight": mx.zeros((64, 128)), + "blocks.0.self_attn.q.weight": mx.zeros((64, 64)), + "blocks.0.modulation": mx.zeros((1, 6, 64)), + "head.head.weight": mx.zeros((64, 64)), + "freqs": mx.zeros((1024, 64, 2)), + } + with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"): + sanitize_wan_transformer_weights(weights) + assert "Unconsumed" not in caplog.text + class TestSanitizeT5Weights: def test_gate_rename(self): @@ -119,6 +142,19 @@ 
class TestSanitizeT5Weights: for key in weights: assert key in out + def test_no_unconsumed_keys(self, caplog): + from mlx_video.convert_wan import sanitize_wan_t5_weights + weights = { + "token_embedding.weight": mx.zeros((100, 64)), + "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)), + "blocks.0.ffn.fc1.weight": mx.zeros((128, 64)), + "blocks.0.ffn.fc2.weight": mx.zeros((64, 128)), + "norm.weight": mx.zeros((64,)), + } + with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"): + sanitize_wan_t5_weights(weights) + assert "Unconsumed" not in caplog.text + class TestSanitizeVAEWeights: def test_conv3d_transpose(self): @@ -161,6 +197,18 @@ class TestSanitizeVAEWeights: assert out["linear.weight"].shape == (8, 4) assert out["norm.weight"].shape == (8,) + def test_no_unconsumed_keys(self, caplog): + from mlx_video.convert_wan import sanitize_wan_vae_weights + weights = { + "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), + "decoder.proj.weight": mx.zeros((16, 8, 3, 3)), + "decoder.norm.weight": mx.zeros((64,)), + "decoder.bias": mx.zeros((16,)), + } + with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"): + sanitize_wan_vae_weights(weights) + assert "Unconsumed" not in caplog.text + # --------------------------------------------------------------------------- # Wan2.1 Conversion Tests @@ -233,3 +281,27 @@ class TestSanitizeEncoderWeights: assert "encoder.conv1.weight" in out assert "conv1.weight" in out assert "conv2.weight" in out + + def test_no_unconsumed_keys(self, caplog): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + + weights = { + "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), + "conv1.weight": mx.zeros((8, 1, 1, 1, 8)), + "conv2.weight": mx.zeros((8, 1, 1, 1, 8)), + } + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"): + sanitize_wan22_vae_weights(weights, include_encoder=True) + assert "Unconsumed" not in caplog.text + + def test_no_unconsumed_keys_exclude_encoder(self, 
caplog): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + + weights = { + "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)), + "conv1.weight": mx.zeros((8, 1, 1, 1, 8)), + "conv2.weight": mx.zeros((8, 1, 1, 1, 8)), + } + with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"): + sanitize_wan22_vae_weights(weights, include_encoder=False) + assert "Unconsumed" not in caplog.text diff --git a/tests/test_wan_i2v.py b/tests/test_wan_i2v.py index 53077a0..1843715 100644 --- a/tests/test_wan_i2v.py +++ b/tests/test_wan_i2v.py @@ -291,3 +291,282 @@ class TestI2VMaskConstruction: encoded = mx.zeros((16, 5, 10, 18)) y = mx.concatenate([mask, encoded], axis=0) assert y.shape == (20, 5, 10, 18) + + +# --------------------------------------------------------------------------- +# Integration: I2V end-to-end pipeline +# --------------------------------------------------------------------------- + + +class TestI2VEndToEndPipeline: + """Full I2V pipeline: image → preprocess → VAE encode → y tensor → denoise → VAE decode.""" + + def test_full_i2v_pipeline(self): + """End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode.""" + from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan.vae import WanVAE + + mx.random.seed(0) + + # --- Tiny I2V model config (z_dim=16 to match VAE normalization stats) --- + config = _make_tiny_i2v_config() + config.vae_z_dim = 16 + config.out_dim = 16 # must match VAE z_dim for decode + config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36 + model = WanModel(config) + + # --- Tiny VAE (with encoder) --- + vae = WanVAE(z_dim=config.vae_z_dim, encoder=True) + + # --- Synthetic image: [B=1, 3, T=1, H=32, W=32] in [-1, 1] --- + height, width = 32, 32 + num_frames = 5 # small temporal extent + img = mx.random.uniform(-1, 1, (1, 3, 1, height, width)) + + # Build video: first frame = 
image, rest = zeros -> [1, 3, F, H, W] + video = mx.concatenate([ + img, + mx.zeros((1, 3, num_frames - 1, height, width)), + ], axis=2) + + # --- VAE encode --- + z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat] + mx.eval(z_video) + assert z_video.ndim == 5 + assert z_video.shape[1] == config.vae_z_dim + + z_video = z_video[0] # [z_dim, T_lat, H_lat, W_lat] + t_latent = z_video.shape[1] + h_latent = z_video.shape[2] + w_latent = z_video.shape[3] + + # --- Build I2V mask (4 channels) --- + msk = mx.ones((1, num_frames, h_latent, w_latent)) + msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) + msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1) + msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) + msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] + + # --- Build y tensor: [mask(4ch) + encoded(z_dim ch)] --- + y_i2v = mx.concatenate([msk, z_video], axis=0) + mx.eval(y_i2v) + assert y_i2v.shape[0] == 4 + config.vae_z_dim + + # --- Denoising loop (2 steps) --- + C_noise = config.out_dim # noise channels + pt, ph, pw = config.patch_size + seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw) + + sched = FlowMatchEulerScheduler() + num_steps = 2 + sched.set_timesteps(num_steps, shift=config.sample_shift) + + latents = mx.random.normal((C_noise, t_latent, h_latent, w_latent)) + context = mx.random.normal((4, config.text_dim)) + + for i in range(num_steps): + t_val = sched.timesteps[i].item() + pred = model( + [latents], + mx.array([t_val]), + [context], + seq_len, + y=[y_i2v], + )[0] + latents = sched.step(pred[None], t_val, latents[None]).squeeze(0) + mx.eval(latents) + + assert latents.shape == (C_noise, t_latent, h_latent, w_latent) + assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents" + assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents" + + # --- VAE decode --- + decoded = vae.decode(latents[None]) # [1, 3, 
T_out, H_out, W_out] + mx.eval(decoded) + assert decoded.ndim == 5 + assert decoded.shape[0] == 1 + assert decoded.shape[1] == 3 # RGB output + assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video" + assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video" + # VAE decode clips to [-1, 1] + assert float(decoded.max()) <= 1.0 + assert float(decoded.min()) >= -1.0 + + +class TestDualModelSwitching: + """Test dual-model selection logic: high_noise vs low_noise based on boundary.""" + + def test_model_selection_by_timestep(self): + """Verify high_noise model used for timesteps >= boundary, low_noise otherwise.""" + from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + + mx.random.seed(1) + config = _make_tiny_i2v_config() + assert config.dual_model is True + + high_noise_model = WanModel(config) + low_noise_model = WanModel(config) + + boundary = config.boundary * config.num_train_timesteps # 0.9 * 1000 = 900 + + C_noise = config.out_dim # 4 + C_y = config.in_dim - config.out_dim # 9 - 4 = 5 + F, H, W = 1, 4, 4 + pt, ph, pw = config.patch_size + seq_len = (F // pt) * (H // ph) * (W // pw) + + sched = FlowMatchEulerScheduler() + num_steps = 5 + sched.set_timesteps(num_steps, shift=config.sample_shift) + + guide_scale = config.sample_guide_scale # (3.5, 3.5) + assert isinstance(guide_scale, tuple) and len(guide_scale) == 2 + + latents = mx.random.normal((C_noise, F, H, W)) + y_i2v = mx.random.normal((C_y, F, H, W)) + context = mx.random.normal((4, config.text_dim)) + + high_used_steps = [] + low_used_steps = [] + + timestep_list = sched.timesteps.tolist() + for i in range(num_steps): + timestep_val = timestep_list[i] + + if timestep_val >= boundary: + model = high_noise_model + gs = guide_scale[1] + high_used_steps.append(i) + else: + model = low_noise_model + gs = guide_scale[0] + low_used_steps.append(i) + + # CFG pass: cond + uncond + preds = model( + [latents, latents], + 
mx.array([timestep_val, timestep_val]), + [context, context], + seq_len, + y=[y_i2v, y_i2v], + ) + noise_pred_cond, noise_pred_uncond = preds[0], preds[1] + noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) + + latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) + mx.eval(latents) + + # With shift=5.0, early timesteps should be high (>=900), later ones low + assert len(high_used_steps) > 0, "High-noise model was never selected" + assert len(low_used_steps) > 0, "Low-noise model was never selected" + # High-noise steps should come before low-noise steps (timesteps decrease) + if high_used_steps and low_used_steps: + assert max(high_used_steps) < min(low_used_steps) or \ + min(high_used_steps) < max(low_used_steps), \ + "Model switching should happen during the loop" + + assert latents.shape == (C_noise, F, H, W) + assert not mx.any(mx.isnan(latents)).item() + + def test_guide_scale_tuple_applied_per_model(self): + """Verify (low_gs, high_gs) tuple applies different scales per model.""" + from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + + mx.random.seed(2) + config = _make_tiny_i2v_config() + config.sample_guide_scale = (2.0, 5.0) # distinct values + + model = WanModel(config) + boundary = config.boundary * config.num_train_timesteps + + C_noise = config.out_dim + F, H, W = 1, 4, 4 + pt, ph, pw = config.patch_size + seq_len = (F // pt) * (H // ph) * (W // pw) + + sched = FlowMatchEulerScheduler() + sched.set_timesteps(5, shift=config.sample_shift) + + latents = mx.random.normal((C_noise, F, H, W)) + context = mx.random.normal((4, config.text_dim)) + guide_scale = config.sample_guide_scale + C_y = config.in_dim - config.out_dim # y channels + y_i2v = mx.random.normal((C_y, F, H, W)) + + # Track which guide scale was used at each step + gs_per_step = [] + + timestep_list = sched.timesteps.tolist() + for i in range(5): + timestep_val = 
timestep_list[i] + + if timestep_val >= boundary: + gs = guide_scale[1] # high_gs = 5.0 + else: + gs = guide_scale[0] # low_gs = 2.0 + gs_per_step.append(gs) + + pred = model( + [latents, latents], + mx.array([timestep_val, timestep_val]), + [context, context], + seq_len, + y=[y_i2v, y_i2v], + ) + noise_pred = pred[1] + gs * (pred[0] - pred[1]) + latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) + mx.eval(latents) + + # Verify both guide scales were used + assert 5.0 in gs_per_step, "High guide scale (5.0) was never used" + assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used" + # High gs should appear first (high timesteps come first) + first_high = gs_per_step.index(5.0) + last_low = len(gs_per_step) - 1 - gs_per_step[::-1].index(2.0) + assert first_high < last_low, "High gs steps should precede low gs steps" + + def test_single_model_fallback_with_tuple_guide_scale(self): + """When dual_model=False, guide_scale tuple should use first element.""" + from mlx_video.models.wan.model import WanModel + from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + + mx.random.seed(3) + config = _make_tiny_config() + config.dual_model = False + config.sample_guide_scale = (3.0, 5.0) + + model = WanModel(config) + guide_scale = config.sample_guide_scale + + C, F, H, W = config.in_dim, 1, 4, 4 + pt, ph, pw = config.patch_size + seq_len = (F // pt) * (H // ph) * (W // pw) + + sched = FlowMatchEulerScheduler() + sched.set_timesteps(3, shift=3.0) + + latents = mx.random.normal((C, F, H, W)) + context = mx.random.normal((4, config.text_dim)) + + # Mimic generate_wan.py single-model logic: + # gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] + gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] + assert gs == 3.0, "Single model should use first element of guide_scale tuple" + + for i in range(3): + t_val = sched.timesteps[i].item() + pred = model( + [latents, latents], + 
mx.array([t_val, t_val]), + [context, context], + seq_len, + ) + noise_pred = pred[1] + gs * (pred[0] - pred[1]) + latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0) + mx.eval(latents) + + assert latents.shape == (C, F, H, W) + assert not mx.any(mx.isnan(latents)).item() diff --git a/tests/test_wan_lora.py b/tests/test_wan_lora.py new file mode 100644 index 0000000..1670d84 --- /dev/null +++ b/tests/test_wan_lora.py @@ -0,0 +1,334 @@ +"""Tests for LoRA loading and application.""" + +import tempfile +from pathlib import Path + +import mlx.core as mx +import numpy as np +import pytest + + +class TestLoRATypes: + """Test LoRA data structures.""" + + def test_lora_weights_scale(self): + from mlx_video.lora.types import LoRAWeights + + w = LoRAWeights( + lora_A=mx.zeros((16, 64)), + lora_B=mx.zeros((128, 16)), + rank=16, + alpha=32.0, + module_name="test", + ) + assert w.scale == 2.0 + + def test_lora_weights_scale_default(self): + from mlx_video.lora.types import LoRAWeights + + w = LoRAWeights( + lora_A=mx.zeros((16, 64)), + lora_B=mx.zeros((128, 16)), + rank=16, + alpha=16.0, + module_name="test", + ) + assert w.scale == 1.0 + + def test_applied_lora_delta(self): + from mlx_video.lora.types import AppliedLoRA, LoRAWeights + + lora_a = mx.ones((2, 4)) + lora_b = mx.ones((8, 2)) + w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + applied = AppliedLoRA(weights=w, strength=0.5) + delta = applied.compute_delta() + # scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones) + expected = 0.5 * mx.ones((8, 4)) * 2.0 + assert mx.allclose(delta, expected).item() + + +class TestLoRALoader: + """Test LoRA weight loading from safetensors.""" + + def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"): + """Helper to create a mock LoRA safetensors file.""" + weights = {} + for name in module_names: + if key_format == "AB": + weights[f"{name}.lora_A.weight"] = 
mx.random.normal((rank, in_dim)) + weights[f"{name}.lora_B.weight"] = mx.random.normal((out_dim, rank)) + else: + weights[f"{name}.lora_down.weight"] = mx.random.normal((rank, in_dim)) + weights[f"{name}.lora_up.weight"] = mx.random.normal((out_dim, rank)) + path = Path(tmp_dir) / "test_lora.safetensors" + mx.save_safetensors(str(path), weights) + return path + + def test_load_lora_a_b_format(self): + from mlx_video.lora.loader import load_lora_weights + + with tempfile.TemporaryDirectory() as tmp: + path = self._make_lora_file(tmp, ["blocks.0.self_attn.q"], key_format="AB") + lora_weights = load_lora_weights(path) + assert "blocks.0.self_attn.q" in lora_weights + w = lora_weights["blocks.0.self_attn.q"] + assert w.rank == 4 + assert w.alpha == 4.0 # default: alpha == rank + assert w.lora_A.shape == (4, 64) + assert w.lora_B.shape == (128, 4) + + def test_load_lora_down_up_format(self): + from mlx_video.lora.loader import load_lora_weights + + with tempfile.TemporaryDirectory() as tmp: + path = self._make_lora_file( + tmp, ["blocks.0.self_attn.q"], key_format="down_up" + ) + lora_weights = load_lora_weights(path) + assert "blocks.0.self_attn.q" in lora_weights + + def test_load_multiple_modules(self): + from mlx_video.lora.loader import load_lora_weights + + modules = [ + "blocks.0.self_attn.q", + "blocks.0.self_attn.k", + "blocks.0.ffn.fc1", + ] + with tempfile.TemporaryDirectory() as tmp: + path = self._make_lora_file(tmp, modules) + lora_weights = load_lora_weights(path) + assert len(lora_weights) == 3 + for name in modules: + assert name in lora_weights + + def test_load_with_alpha(self): + from mlx_video.lora.loader import load_lora_weights + + with tempfile.TemporaryDirectory() as tmp: + weights = { + "test.lora_A.weight": mx.random.normal((8, 64)), + "test.lora_B.weight": mx.random.normal((128, 8)), + "test.alpha": mx.array(16.0), + } + path = Path(tmp) / "lora.safetensors" + mx.save_safetensors(str(path), weights) + lora_weights = load_lora_weights(path) + 
assert lora_weights["test"].alpha == 16.0 + assert lora_weights["test"].rank == 8 + assert lora_weights["test"].scale == 2.0 + + def test_file_not_found(self): + from mlx_video.lora.loader import load_lora_weights + + with pytest.raises(FileNotFoundError): + load_lora_weights(Path("/nonexistent/lora.safetensors")) + + +class TestWanKeyNormalization: + """Test Wan2.2 LoRA key normalization.""" + + def _wan_model_keys(self): + """Simulate typical Wan2.2 MLX model weight keys.""" + keys = set() + for i in range(2): + for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o", + "cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]: + keys.add(f"blocks.{i}.{layer}.weight") + keys.add(f"blocks.{i}.ffn.fc1.weight") + keys.add(f"blocks.{i}.ffn.fc2.weight") + keys.add("text_embedding_0.weight") + keys.add("text_embedding_1.weight") + keys.add("time_embedding_0.weight") + keys.add("time_embedding_1.weight") + keys.add("time_projection.weight") + keys.add("patch_embedding_proj.weight") + return keys + + def test_direct_match(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q" + + def test_strip_diffusion_model_prefix(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + result = _normalize_wan_lora_key("diffusion_model.blocks.0.self_attn.q", keys) + assert result == "blocks.0.self_attn.q" + + def test_strip_model_prefix(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys) + assert result == "blocks.0.self_attn.k" + + def test_ffn_key_mapping(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + assert _normalize_wan_lora_key("blocks.0.ffn.0", keys) == "blocks.0.ffn.fc1" + assert 
_normalize_wan_lora_key("blocks.0.ffn.2", keys) == "blocks.0.ffn.fc2" + + def test_text_embedding_mapping(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + assert _normalize_wan_lora_key("text_embedding.0", keys) == "text_embedding_0" + assert _normalize_wan_lora_key("text_embedding.2", keys) == "text_embedding_1" + + def test_time_embedding_mapping(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + assert _normalize_wan_lora_key("time_embedding.0", keys) == "time_embedding_0" + assert _normalize_wan_lora_key("time_embedding.2", keys) == "time_embedding_1" + + def test_time_projection_mapping(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + assert _normalize_wan_lora_key("time_projection.1", keys) == "time_projection" + + def test_patch_embedding_mapping(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj" + + def test_combined_prefix_and_ffn(self): + from mlx_video.lora.apply import _normalize_wan_lora_key + + keys = self._wan_model_keys() + result = _normalize_wan_lora_key("diffusion_model.blocks.1.ffn.0", keys) + assert result == "blocks.1.ffn.fc1" + + +class TestApplyLoRA: + """Test LoRA delta application to weights.""" + + def test_preserves_bfloat16_dtype(self): + """LoRA delta must not promote bfloat16 weights to float32.""" + from mlx_video.lora.apply import apply_lora_to_linear + from mlx_video.lora.types import LoRAWeights + + original = mx.ones((8, 4), dtype=mx.bfloat16) + # LoRA weights in float32 (typical when loaded from safetensors) + lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1 + lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1 + w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + result = apply_lora_to_linear(original, [(w, 
1.0)]) + assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}" + + def test_preserves_float16_dtype(self): + from mlx_video.lora.apply import apply_lora_to_linear + from mlx_video.lora.types import LoRAWeights + + original = mx.ones((8, 4), dtype=mx.float16) + lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1 + lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1 + w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + result = apply_lora_to_linear(original, [(w, 1.0)]) + assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}" + + def test_apply_single_lora(self): + from mlx_video.lora.apply import apply_lora_to_linear + from mlx_video.lora.types import LoRAWeights + + original = mx.ones((8, 4)) + lora_a = mx.ones((2, 4)) * 0.1 + lora_b = mx.ones((8, 2)) * 0.1 + w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test") + result = apply_lora_to_linear(original, [(w, 1.0)]) + # delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4) + expected = original + 0.02 * mx.ones((8, 4)) + assert mx.allclose(result, expected, atol=1e-6).item() + + def test_apply_multiple_loras(self): + from mlx_video.lora.apply import apply_lora_to_linear + from mlx_video.lora.types import LoRAWeights + + original = mx.zeros((8, 4)) + w1 = LoRAWeights( + lora_A=mx.ones((2, 4)), + lora_B=mx.ones((8, 2)), + rank=2, alpha=2.0, module_name="a", + ) + w2 = LoRAWeights( + lora_A=mx.ones((2, 4)) * 2, + lora_B=mx.ones((8, 2)) * 2, + rank=2, alpha=4.0, module_name="b", + ) + result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)]) + # w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4) + # w2 delta: 2.0 * 0.5 * (2*ones(8,2) @ 2*ones(2,4)) = 1.0 * 8*ones(8,4) = 8 + delta1 = mx.ones((8, 4)) * 2.0 + delta2 = mx.ones((8, 4)) * 8.0 + expected = delta1 + delta2 + assert mx.allclose(result, expected, atol=1e-5).item() + + def test_apply_loras_to_weights_dict(self): + from 
mlx_video.lora.apply import apply_loras_to_weights + from mlx_video.lora.types import LoRAWeights + + model_weights = { + "blocks.0.self_attn.q.weight": mx.ones((128, 64)), + "blocks.0.self_attn.k.weight": mx.ones((128, 64)), + "blocks.0.ffn.fc1.weight": mx.ones((256, 64)), + } + w = LoRAWeights( + lora_A=mx.ones((4, 64)) * 0.01, + lora_B=mx.ones((128, 4)) * 0.01, + rank=4, alpha=4.0, module_name="blocks.0.self_attn.q", + ) + module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]} + result = apply_loras_to_weights(model_weights, module_to_loras) + # Only q should be modified + assert not mx.array_equal( + result["blocks.0.self_attn.q.weight"], + model_weights["blocks.0.self_attn.q.weight"], + ).item() + assert mx.array_equal( + result["blocks.0.self_attn.k.weight"], + model_weights["blocks.0.self_attn.k.weight"], + ).item() + + +class TestEndToEnd: + """End-to-end LoRA loading and application.""" + + def test_load_and_apply_loras(self): + from mlx_video.convert_wan import load_and_apply_loras + + with tempfile.TemporaryDirectory() as tmp: + # Create mock LoRA safetensors + rank = 4 + weights = { + "blocks.0.self_attn.q.lora_A.weight": mx.random.normal((rank, 64)), + "blocks.0.self_attn.q.lora_B.weight": mx.random.normal((128, rank)), + } + lora_path = Path(tmp) / "test.safetensors" + mx.save_safetensors(str(lora_path), weights) + + # Create mock model weights + model_weights = { + "blocks.0.self_attn.q.weight": mx.ones((128, 64)), + "blocks.0.self_attn.k.weight": mx.ones((128, 64)), + } + + result = load_and_apply_loras( + model_weights, [(str(lora_path), 1.0)] + ) + + # q weight should be modified, k unchanged + assert not mx.array_equal( + result["blocks.0.self_attn.q.weight"], + model_weights["blocks.0.self_attn.q.weight"], + ).item() + assert mx.array_equal( + result["blocks.0.self_attn.k.weight"], + model_weights["blocks.0.self_attn.k.weight"], + ).item() diff --git a/tests/test_wan_vae.py b/tests/test_wan_vae.py index 2d7fefb..cd2cf94 100644 --- 
a/tests/test_wan_vae.py +++ b/tests/test_wan_vae.py @@ -868,4 +868,84 @@ class TestVAEEncoderTemporalOrder: assert out_wrong.shape[1] == 2 +# --------------------------------------------------------------------------- +# VAE Encode → Decode Round-Trip Tests +# --------------------------------------------------------------------------- + + +class TestVAE21RoundTrip: + """Encode→decode round-trip for Wan 2.1 VAE (channels-first).""" + + def test_encode_decode_shape_and_values(self): + """Encoder3d → Decoder3d: output shape matches input, values are finite.""" + from mlx_video.models.wan.vae import Decoder3d, Encoder3d + + z_dim = 4 + dim = 8 + # No temporal up/downsampling to keep the test simple + enc = Encoder3d( + dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False] + ) + dec = Decoder3d( + dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False] + ) + mx.eval(enc.parameters(), dec.parameters()) + + # [B=1, C=3, T=1, H=8, W=8] + x = mx.random.normal((1, 3, 1, 8, 8)) * 0.5 + + z = enc(x) + mx.eval(z) + # 3 spatial downsamples (÷8): H=1, W=1 + assert z.shape == (1, z_dim, 1, 1, 1) + + x_hat = dec(z) + mx.eval(x_hat) + # 3 spatial upsamples (×8): should recover original shape + assert x_hat.shape == x.shape + + out_np = np.array(x_hat) + assert np.all(np.isfinite(out_np)) + assert np.abs(out_np).max() < 1000 + + +class TestVAE22RoundTrip: + """Encode→decode round-trip for Wan 2.2 VAE (channels-last).""" + + def test_encode_decode_shape_and_values(self): + """Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range.""" + from mlx_video.models.wan.vae22 import ( + Wan22VAEDecoder, + Wan22VAEEncoder, + denormalize_latents, + ) + + enc = Wan22VAEEncoder(z_dim=48, dim=16) + dec = Wan22VAEDecoder(z_dim=48, dec_dim=8) + mx.eval(enc.parameters(), dec.parameters()) + + # [B=1, T=1, H=32, W=32, C=3] + img = mx.random.normal((1, 1, 32, 32, 3)) * 0.5 + + z_norm = enc(img) + mx.eval(z_norm) + # patchify(÷2) + 3 spatial downsamples(÷8) = ÷16 + assert 
z_norm.shape == (1, 1, 2, 2, 48) + + z = denormalize_latents(z_norm) + out = dec(z) + mx.eval(out) + + # 3 spatial upsamples(×8) + unpatchify(×2) = ×16 + assert out.shape[0] == 1 # batch + assert out.shape[2] == 32 # H recovered + assert out.shape[3] == 32 # W recovered + assert out.shape[-1] == 3 # RGB + + out_np = np.array(out) + assert np.all(np.isfinite(out_np)) + assert out_np.min() >= -1.0 - 1e-6 + assert out_np.max() <= 1.0 + 1e-6 + +