feat(wan): Add LoRA with improved quantization pipeline

This commit is contained in:
Daniel
2026-02-28 14:11:13 +01:00
parent dbab95ec45
commit 849cc45d84
17 changed files with 1852 additions and 111 deletions

View File

@@ -1,13 +1,16 @@
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
import gc
import logging
from pathlib import Path
from typing import Dict
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.utils
import numpy as np
logger = logging.getLogger(__name__)
def load_torch_weights(path: str) -> Dict[str, mx.array]:
"""Load PyTorch .pth weights and convert to MLX arrays.
@@ -88,6 +91,7 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
etc.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
@@ -99,36 +103,43 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
value = value.reshape(value.shape[0], -1)
new_key = "patch_embedding_proj.weight"
sanitized[new_key] = value
consumed.add(key)
continue
if key == "patch_embedding.bias":
new_key = "patch_embedding_proj.bias"
sanitized[new_key] = value
consumed.add(key)
continue
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
if key.startswith("text_embedding.0."):
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
sanitized[new_key] = value
consumed.add(key)
continue
if key.startswith("text_embedding.2."):
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
sanitized[new_key] = value
consumed.add(key)
continue
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
if key.startswith("time_embedding.0."):
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
sanitized[new_key] = value
consumed.add(key)
continue
if key.startswith("time_embedding.2."):
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
sanitized[new_key] = value
consumed.add(key)
continue
# Time projection Sequential: 0=SiLU(no params), 1=Linear
if key.startswith("time_projection.1."):
new_key = key.replace("time_projection.1.", "time_projection.")
sanitized[new_key] = value
consumed.add(key)
continue
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
@@ -137,9 +148,15 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
# Skip the freqs buffer (we compute it in the model)
if key == "freqs":
consumed.add(key)
continue
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
return sanitized
@@ -171,6 +188,7 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
norm.weight
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
@@ -179,6 +197,11 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
return sanitized
@@ -189,6 +212,7 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
Handles Conv3d and Conv2d weight transpositions for MLX format.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
@@ -206,10 +230,78 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
# Need to adapt naming for our simplified structure
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
return sanitized
def _load_lora_configs(
    lora_configs: List[Tuple[str, float]],
) -> Dict[str, list]:
    """Turn ``(path, strength)`` pairs into LoRA configs and load their weights.

    Used by both the weight-merging path and the runtime-wrapping path.

    Args:
        lora_configs: One ``(path, strength)`` tuple per LoRA file.

    Returns:
        The value returned by ``load_multiple_loras`` for the built configs
        (the ``module_to_loras`` mapping). A warning is printed when it is
        empty.

    Raises:
        Exception: Anything raised while building a config for one entry is
            reported on stdout and then re-raised unchanged.
    """
    from mlx_video.lora import LoRAConfig, load_multiple_loras
    from mlx_video.utils import Colors

    print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")

    built = []
    for lora_path, strength in lora_configs:
        try:
            built.append(LoRAConfig(path=lora_path, strength=strength))
            # Only announce an entry once its config was created successfully.
            print(f" - {Path(lora_path).name} (strength: {strength})")
        except Exception as e:
            print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
            raise

    loaded = load_multiple_loras(built)
    if loaded:
        return loaded

    print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
    return loaded
def load_and_apply_loras(
    model_weights: Dict[str, mx.array],
    lora_configs: Optional[List[Tuple[str, float]]] = None,
    verbose: bool = False,
    quantization_bits: int = 0,
) -> Dict[str, mx.array]:
    """Load LoRAs and merge them into a weight dictionary.

    Intended for non-quantized (bf16) models. For quantized models, use
    apply_loras_to_model().

    Args:
        model_weights: Model state dictionary to merge the LoRAs into.
        lora_configs: ``(path, strength)`` tuples; ``None`` or empty means
            "no LoRAs".
        verbose: Print extra detail and forward the flag to the merge step.
        quantization_bits: Forwarded unchanged to ``apply_loras_to_weights``.

    Returns:
        ``model_weights`` itself when no LoRAs were requested or none matched;
        otherwise the dictionary returned by ``apply_loras_to_weights``.
    """
    from mlx_video.lora import apply_loras_to_weights
    from mlx_video.utils import Colors

    # Skip loading entirely when nothing was requested.
    module_to_loras = _load_lora_configs(lora_configs) if lora_configs else None
    if not module_to_loras:
        return model_weights

    print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
    if verbose:
        print(f" Model has {len(model_weights)} weight keys")

    merged = apply_loras_to_weights(
        model_weights,
        module_to_loras,
        verbose=verbose,
        quantization_bits=quantization_bits,
    )

    print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
    return merged
def convert_wan_checkpoint(
checkpoint_dir: str,
output_dir: str,
@@ -464,30 +556,45 @@ def _quantize_saved_model(
is_dual: bool,
bits: int,
group_size: int,
source_dir: Path = None,
):
"""Load saved bf16 model, quantize, and re-save."""
"""Load saved bf16 model, quantize, and re-save.
Args:
output_dir: Directory to write quantized weights to.
config: WanModelConfig for creating the model.
is_dual: Whether this is a dual-expert model.
bits: Quantization bits.
group_size: Quantization group size.
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
"""
import json
import mlx.nn as nn
from mlx_video.models.wan.model import WanModel
model_files = []
if source_dir is None:
source_dir = output_dir
model_names = []
if is_dual:
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
p = output_dir / name
if p.exists():
model_files.append(p)
if (source_dir / name).exists():
model_names.append(name)
else:
p = output_dir / "model.safetensors"
if p.exists():
model_files.append(p)
if (source_dir / "model.safetensors").exists():
model_names.append("model.safetensors")
for model_path in model_files:
print(f" Quantizing {model_path.name}...")
for name in model_names:
print(f" Quantizing {name}...")
model = WanModel(config)
weights = mx.load(str(model_path))
weights = mx.load(str(source_dir / name))
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
del weights
gc.collect()
mx.clear_cache()
# Apply quantization to targeted layers
nn.quantize(
@@ -499,10 +606,30 @@ def _quantize_saved_model(
# Save quantized weights
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
mx.save_safetensors(str(model_path), weights_dict)
# Validate: check for NaN/Inf in bias tensors (corruption canary)
bad_keys = []
for k, v in weights_dict.items():
if k.endswith(".bias") and not k.endswith(".biases"):
mx.eval(v)
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
bad_keys.append(k)
if bad_keys:
raise RuntimeError(
f"Quantization produced corrupted weights in {model_path.name}: "
f"{len(bad_keys)} bias tensors contain NaN/Inf "
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
)
mx.save_safetensors(str(output_dir / name), weights_dict)
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
# Free model before processing next file
del model, weights_dict
gc.collect()
mx.clear_cache()
# Update config.json with quantization metadata
config_path = output_dir / "config.json"
with open(config_path) as f:
@@ -516,6 +643,68 @@ def _quantize_saved_model(
print(f" Updated config.json with quantization metadata")
def quantize_mlx_model(
    mlx_model_dir: str,
    output_dir: str,
    bits: int = 4,
    group_size: int = 64,
):
    """Quantize an already-converted MLX model (skips PyTorch conversion).

    All inputs are validated before anything is written, so a bad source
    directory never leaves a half-populated output directory behind.

    Args:
        mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
        output_dir: Path to output quantized model directory. May be the same
            directory as ``mlx_model_dir``.
        bits: Quantization bits (4 or 8).
        group_size: Quantization group size (32, 64, or 128).

    Raises:
        FileNotFoundError: If ``config.json`` is missing, or if the source
            directory contains neither ``model.safetensors`` nor both
            dual-expert weight files.
        ValueError: If the source config already carries quantization
            metadata.
    """
    import json
    import shutil

    src = Path(mlx_model_dir)
    dst = Path(output_dir)

    config_path = src / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"No config.json found in {src}")
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    existing = cfg.get("quantization")
    if existing:
        # Tolerate metadata without a "bits" entry so the intended ValueError
        # is raised instead of a KeyError from building the message.
        existing_bits = existing.get("bits", "unknown") if isinstance(existing, dict) else "unknown"
        raise ValueError(
            f"Model at {src} is already quantized "
            f"({existing_bits}-bit). Use a bf16/fp16 source."
        )

    # Detect dual vs single expert: dual requires BOTH expert files.
    dual_names = ("low_noise_model.safetensors", "high_noise_model.safetensors")
    single_name = "model.safetensors"
    is_dual = all((src / name).exists() for name in dual_names)

    # Without this check a source directory with no (or only one dual) weight
    # file would be "quantized" into an output that has no weights at all.
    if not is_dual and not (src / single_name).exists():
        raise FileNotFoundError(
            f"No transformer weights found in {src}: expected {single_name} "
            f"or both {dual_names[0]} and {dual_names[1]}"
        )

    # Build model config from the fields the dataclass declares.
    from mlx_video.models.wan.config import WanModelConfig

    config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
    # JSON has no tuple type, so these fields come back as lists.
    for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
        if key in config_dict and isinstance(config_dict[key], list):
            config_dict[key] = tuple(config_dict[key])
    config = WanModelConfig(**config_dict)

    # Copy non-transformer files to output dir (skip large model weights)
    transformer_files = {*dual_names, single_name}
    if dst.resolve() != src.resolve():
        dst.mkdir(parents=True, exist_ok=True)
        for entry in src.iterdir():
            if entry.is_file() and entry.name not in transformer_files:
                shutil.copy2(entry, dst / entry.name)
        print(f"Copied non-transformer files from {src} to {dst}")

    print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
    _quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
    print(f"\nQuantization complete! Output: {dst}")
if __name__ == "__main__":
import argparse
@@ -551,6 +740,11 @@ if __name__ == "__main__":
action="store_true",
help="Quantize transformer weights for faster inference",
)
parser.add_argument(
"--quantize-only",
action="store_true",
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
)
parser.add_argument(
"--bits",
type=int,
@@ -566,7 +760,14 @@ if __name__ == "__main__":
help="Quantization group size (default: 64)",
)
args = parser.parse_args()
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
)
if args.quantize_only:
quantize_mlx_model(
args.checkpoint_dir, args.output_dir,
bits=args.bits, group_size=args.group_size,
)
else:
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
)

View File

@@ -43,6 +43,10 @@ def generate_video(
seed: int = -1,
output_path: str = "output.mp4",
scheduler: str = "unipc",
loras: list | None = None,
loras_high: list | None = None,
loras_low: list | None = None,
):
"""Generate video using Wan pipeline (supports T2V and I2V).
@@ -60,6 +64,10 @@ def generate_video(
seed: Random seed (-1 for random)
output_path: Output video path
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
loras: Optional list of (path, strength) tuples applied to all models
loras_high: Optional list of (path, strength) tuples for high-noise model only
loras_low: Optional list of (path, strength) tuples for low-noise model only
"""
import json
@@ -156,6 +164,12 @@ def generate_video(
parts = [float(x) for x in guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
if isinstance(guide_scale, tuple):
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
else:
cfg_disabled = guide_scale <= 1.0
# Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
@@ -181,6 +195,8 @@ def generate_video(
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}")
# Seed
@@ -233,8 +249,12 @@ def generate_video(
# Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
mx.eval(context, context_null)
if cfg_disabled:
context_null = None
mx.eval(context)
else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
mx.eval(context, context_null)
# Free T5 from memory
del t5_encoder
@@ -319,48 +339,78 @@ def generate_video(
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
t2 = time.time()
# Merge per-model LoRAs with shared LoRAs
_loras_low = (loras or []) + (loras_low or []) or None
_loras_high = (loras or []) + (loras_high or []) or None
_loras_single = loras
if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization)
high_noise_model = load_wan_model(high_noise_path, config, quantization)
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization)
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
# Each model has its own text_embedding weights, so dual models need separate embeddings
if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
if cfg_disabled:
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
if is_dual:
context_emb_low = low_noise_model.embed_text([context])
context_emb_high = high_noise_model.embed_text([context])
mx.eval(context_emb_low, context_emb_high)
context_cond_low = context_emb_low[0:1]
context_cond_high = context_emb_high[0:1]
else:
context_emb = single_model.embed_text([context])
mx.eval(context_emb)
context_cond = context_emb[0:1]
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
# Precompute cross-attention K/V caches (constant across all steps)
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
mx.eval(cross_kv_low, cross_kv_high)
if cfg_disabled:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cond)
mx.eval(cross_kv)
else:
cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv)
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv)
# Precompute RoPE frequencies (grid sizes are constant across all steps)
f_grid = t_latent // patch_size[0]
h_grid = h_latent // patch_size[1]
w_grid = w_latent // patch_size[2]
cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if cfg_disabled:
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
else:
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if is_dual:
rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes)
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else:
rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes)
rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin)
# Setup scheduler
@@ -395,58 +445,86 @@ def generate_video(
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = timestep_list[i]
# Select model, guide scale, cached K/V, and precomputed RoPE
# Select model, cached K/V, and precomputed RoPE
if is_dual:
if timestep_val >= boundary:
model = high_noise_model
gs = guide_scale[1]
kv = cross_kv_high
rcs = rope_cos_sin_high
else:
model = low_noise_model
gs = guide_scale[0]
kv = cross_kv_low
rcs = rope_cos_sin_low
else:
model = single_model
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
kv = cross_kv
rcs = rope_cos_sin
# Build per-token timesteps for TI2V-5B (first-frame patches get t=0)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
# Pad to seq_len if needed
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
# Batch for CFG: both cond and uncond get same timesteps
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L]
if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = t_tokens # [1, L]
else:
t_batch = mx.array([timestep_val])
y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
else:
ctx = context_cond
preds = model(
[latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred = preds[0]
del preds
else:
t_batch = mx.array([timestep_val, timestep_val])
# CFG: batch cond + uncond into single B=2 forward pass
if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
# I2V-14B: pass y conditioning to model (same y for cond and uncond)
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
else:
t_batch = mx.array([timestep_val, timestep_val])
# CFG: batch cond + uncond into single B=2 forward pass
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = model(
[latents, latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
# Classifier-free guidance + scheduler step
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = model(
[latents, latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond, preds
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
@@ -455,7 +533,7 @@ def generate_video(
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
del noise_pred
mx.eval(latents)
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
@@ -463,11 +541,19 @@ def generate_video(
# Free transformer models and text embeddings
if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
del context_cfg_low, context_cfg_high
if cfg_disabled:
del context_cond_low, context_cond_high
else:
del context_cfg_low, context_cfg_high
else:
del single_model, cross_kv
del context_cfg
del model, kv, context, context_null
if cfg_disabled:
del context_cond
else:
del context_cfg
del model, kv, context
if context_null is not None:
del context_null
gc.collect(); mx.clear_cache()
# Load VAE and decode
@@ -478,25 +564,36 @@ def generate_video(
is_wan22_vae = config.vae_z_dim == 48
# Warm-up: prepend a copy of the first latent frame to provide temporal
# context for the real first frame. Causal convolutions in the VAE decoder
# pad with zeros on the left, so the first few output frames have degraded
# quality (no temporal context). By duplicating the first latent, the real
# first frame sees its own features as left context instead of zeros.
# We trim the extra output frames after decoding.
warmup_trim = vae_stride[0] # 4 frames per latent temporal position
latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1)
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None] # [1, T, H, W, C]
z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C]
z = denormalize_latents(z)
video = vae(z) # [1, T', H', W', 3]
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3]
video = video[warmup_trim:] # Trim warm-up frames
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else:
video = vae.decode(latents[None]) # [1, 3, T, H, W]
video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W]
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [3, T, H, W]
video = np.array(video[0]) # [3, T', H, W]
video = video[:, warmup_trim:] # Trim warm-up frames (channels-first)
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
@@ -529,6 +626,19 @@ def main():
choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
)
parser.add_argument(
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
)
parser.add_argument(
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
)
args = parser.parse_args()
# Parse guide scale
@@ -542,6 +652,12 @@ def main():
if args.no_negative_prompt:
neg_prompt = ""
# Parse LoRA configs: convert [path, strength_str] → (path, float)
def _parse_lora_args(lora_list):
if not lora_list:
return None
return [(path, float(strength)) for path, strength in lora_list]
generate_video(
model_dir=args.model_dir,
prompt=args.prompt,
@@ -556,6 +672,10 @@ def main():
seed=args.seed,
output_path=args.output_path,
scheduler=args.scheduler,
loras=_parse_lora_args(args.lora),
loras_high=_parse_lora_args(args.lora_high),
loras_low=_parse_lora_args(args.lora_low),
)

View File

@@ -0,0 +1,25 @@
"""LoRA support for mlx-video."""
from mlx_video.lora.apply import (
LoRALinear,
apply_lora_to_linear,
apply_loras_to_model,
apply_loras_to_weights,
)
from mlx_video.lora.loader import (
load_lora_weights,
load_multiple_loras,
)
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
__all__ = [
"LoRAConfig",
"LoRAWeights",
"AppliedLoRA",
"load_lora_weights",
"load_multiple_loras",
"apply_lora_to_linear",
"apply_loras_to_weights",
"apply_loras_to_model",
"LoRALinear",
]

393
mlx_video/lora/apply.py Normal file
View File

@@ -0,0 +1,393 @@
"""Apply LoRA weights to model layers."""
from typing import Dict, List, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.lora.types import LoRAWeights
def apply_lora_to_linear(
    linear_weight: mx.array,
    lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
) -> mx.array:
    """Apply one or more LoRAs to a linear layer weight.

    Each LoRA contributes ``(lora_B @ lora_A) * (scale * strength)``. The
    products are computed and summed in float32, then the combined delta is
    cast to the weight's dtype once and added.

    Args:
        linear_weight: Original weight matrix [out_features, in_features]
        lora_weights_and_strengths: List of (LoRAWeights, strength) tuples

    Returns:
        Modified weight with LoRA deltas applied (preserves original dtype).
        With an empty list the input weight is returned unchanged.
    """
    if not lora_weights_and_strengths:
        return linear_weight

    orig_dtype = linear_weight.dtype
    total_delta = None

    for weights, strength in lora_weights_and_strengths:
        # Upcast the low-rank factors explicitly: without the cast the matmul
        # runs in whatever dtype the LoRA was stored in.
        low_rank = weights.lora_B.astype(mx.float32) @ weights.lora_A.astype(mx.float32)
        delta = low_rank * (weights.scale * strength)
        total_delta = delta if total_delta is None else total_delta + delta

    # Cast the summed delta back once so the model weight is never promoted
    # (e.g. bfloat16 -> float32) and multiple LoRAs incur a single rounding.
    return linear_weight + total_delta.astype(orig_dtype)
def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match Wan2.2 MLX model weight keys.
Handles:
- Stripping common prefixes (diffusion_model., model., etc.)
- FFN key mapping: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
- Embedding key mapping: text_embedding.0 → text_embedding_0, etc.
- Time projection: time_projection.1 → time_projection
- Patch embedding: patch_embedding → patch_embedding_proj
Args:
lora_key: Original LoRA module name
model_keys: Set of all model weight keys
Returns:
Normalized key that matches model weights
"""
# Try the key as-is first
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
# Common prefixes to strip
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"base_model.model.",
"model.",
]
candidates = [lora_key]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
candidates.append(lora_key[len(prefix):])
for candidate in candidates:
# Try as-is
if f"{candidate}.weight" in model_keys or candidate in model_keys:
return candidate
# Apply Wan2.2 key transformations
transformed = candidate
# FFN: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
if transformed.endswith(".ffn.0"):
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
if transformed.endswith(".ffn.2"):
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
# Text embedding: text_embedding.0 → text_embedding_0
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
if transformed.endswith("text_embedding.0"):
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
if transformed.endswith("text_embedding.2"):
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
# Time embedding: time_embedding.0 → time_embedding_0
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
if transformed.endswith("time_embedding.0"):
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
if transformed.endswith("time_embedding.2"):
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
# Time projection: time_projection.1 → time_projection
transformed = transformed.replace("time_projection.1.", "time_projection.")
if transformed.endswith("time_projection.1"):
transformed = transformed[:-len("time_projection.1")] + "time_projection"
# Patch embedding: patch_embedding → patch_embedding_proj
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Return best attempt with prefix stripped
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
# Also support LTX-style key normalization
def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match LTX MLX model weight keys."""
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"model.",
]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
normalized = lora_key[len(prefix):]
if f"{normalized}.weight" in model_keys or normalized in model_keys:
return normalized
transformed = normalized
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Try transformations on the original key
transformed = lora_key
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
def _normalize_lora_key(lora_key: str, model_keys: set) -> str:
    """Map a LoRA module name onto the target model's key naming.

    The architecture is inferred from ``model_keys``: if any key contains
    ``"self_attn.q.weight"`` the model is treated as Wan2.2 and the Wan
    normalizer is used; otherwise the LTX normalizer is used.

    Args:
        lora_key: Module name taken from the LoRA checkpoint.
        model_keys: Set of weight keys (or module paths) of the target model.

    Returns:
        Whatever the selected architecture-specific normalizer returns.
    """
    wan_marker = "self_attn.q.weight"
    for model_key in model_keys:
        if wan_marker in model_key:
            return _normalize_wan_lora_key(lora_key, model_keys)
    return _normalize_ltx_lora_key(lora_key, model_keys)
def apply_loras_to_weights(
    model_weights: Dict[str, mx.array],
    module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
    verbose: bool = False,
    quantization_bits: int = 0,
) -> Dict[str, mx.array]:
    """Apply LoRAs to model weights.

    The input dictionary is not modified; a shallow copy is updated and
    returned. Each LoRA module name is normalized against the model's keys,
    then merged into ``"<name>.weight"`` (or ``"<name>"`` when the model
    stores the tensor under the bare name).

    Args:
        model_weights: Original model state dictionary
        module_to_loras: Dictionary mapping module names to lists of
            (LoRAWeights, strength) tuples
        verbose: If True, print debug details for the first five modules
            that cannot be matched to a model key
        quantization_bits: Bit width the weights were quantized with.
            Required (>0) whenever a targeted layer is quantized: such
            layers are dequantized before LoRA application and
            re-quantized after.

    Returns:
        New state dictionary with LoRA-modified weights

    Raises:
        ValueError: If a targeted layer is stored quantized but
            ``quantization_bits`` is not a positive bit width.
    """
    modified_weights = dict(model_weights)
    model_keys = set(model_weights.keys())
    applied_count = 0
    skipped_count = 0

    for module_name, loras in module_to_loras.items():
        normalized_name = _normalize_lora_key(module_name, model_keys)

        # Prefer "<module>.weight"; fall back to the bare module name.
        weight_key = f"{normalized_name}.weight"
        if weight_key not in modified_weights:
            weight_key = normalized_name
        if weight_key not in modified_weights:
            skipped_count += 1
            if verbose and skipped_count <= 5:
                print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
                # Hint at near misses: keys sharing the last path component.
                leaf = normalized_name.split(".")[-1]
                similar = [k for k in list(model_keys)[:1000] if leaf in k][:3]
                if similar:
                    print(f" Similar keys: {similar}")
            continue

        original_weight = modified_weights[weight_key]

        # A layer counts as quantized when its weight is packed uint32 and
        # the companion scales/biases tensors are present.
        scales_key = f"{normalized_name}.scales"
        biases_key = f"{normalized_name}.biases"
        is_quantized = (
            original_weight.dtype == mx.uint32
            and scales_key in modified_weights
            and biases_key in modified_weights
        )

        if not is_quantized:
            modified_weights[weight_key] = apply_lora_to_linear(original_weight, loras)
            applied_count += 1
            continue

        if quantization_bits <= 0:
            # The bit width cannot be recovered from the tensors alone, and
            # the group-size formula below would divide by zero.
            raise ValueError(
                f"Weight '{weight_key}' is quantized but "
                f"quantization_bits={quantization_bits}; pass the bit width "
                "the weights were quantized with"
            )

        # Quantized path: dequantize -> apply delta -> re-quantize.
        scales = modified_weights[scales_key]
        biases = modified_weights[biases_key]
        # Each uint32 packs 32 // bits values; one scale covers one group.
        group_size = (original_weight.shape[-1] * 32) // (
            scales.shape[-1] * quantization_bits
        )
        dequantized = mx.dequantize(
            original_weight, scales, biases, group_size=group_size, bits=quantization_bits
        )
        merged = apply_lora_to_linear(dequantized, loras)
        # Re-quantize with the same parameters so the layout is unchanged.
        new_w, new_scales, new_biases = mx.quantize(
            merged, group_size=group_size, bits=quantization_bits
        )
        modified_weights[weight_key] = new_w
        modified_weights[scales_key] = new_scales
        modified_weights[biases_key] = new_biases
        applied_count += 1

    if applied_count > 0:
        print(f" ✓ Applied to {applied_count} modules")
    if skipped_count > 0:
        print(f" ⚠ Skipped {skipped_count} incompatible modules")
    return modified_weights
class LoRALinear(nn.Module):
    """Wrapper that adds LoRA contributions to a base layer at call time.

    The base layer (described by the original author as ``nn.Linear`` or
    ``nn.QuantizedLinear``) is left untouched; every call computes

        base(x) + sum_i scale_i * strength_i * (x @ A_i.T @ B_i.T)

    over all attached (LoRAWeights, strength) pairs.
    """

    def __init__(
        self,
        linear: nn.Module,
        lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
    ):
        """Store the wrapped layer and the LoRA pairs applied on each call."""
        super().__init__()
        self.linear = linear
        self.lora_weights_and_strengths = lora_weights_and_strengths

    def __call__(self, x: mx.array) -> mx.array:
        """Return the base layer output plus every LoRA's scaled delta."""
        result = self.linear(x)
        for lora, strength in self.lora_weights_and_strengths:
            # Project down through A, then back up through B.
            projected = x @ lora.lora_A.T
            low_rank_out = projected @ lora.lora_B.T
            factor = lora.scale * strength
            result = result + (factor * low_rank_out)
        return result
def apply_loras_to_model(
    model: nn.Module,
    module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
    verbose: bool = False,
) -> int:
    """Apply LoRAs to a model by merging into weights.

    For QuantizedLinear layers: dequantizes the weight, merges the LoRA
    delta, and replaces the layer with a regular nn.Linear (no per-step
    overhead, no re-quantization precision loss). The replacement has a bias
    only if the quantized layer had one. Non-LoRA layers stay quantized.

    For nn.Linear layers: merges the LoRA delta directly into the weight.

    Args:
        model: The model to apply LoRAs to (modified in place)
        module_to_loras: Dictionary mapping module names to (LoRAWeights, strength) lists
        verbose: Print debug info for modules that cannot be resolved

    Returns:
        Number of modules modified
    """

    def child_of(node, name: str):
        # Numeric path components index into containers; others are attributes.
        return node[int(name)] if name.isdigit() else getattr(node, name)

    # Build a set of model module paths for key normalization.
    module_paths = set()
    for name, _ in model.named_modules():
        module_paths.add(name)
        module_paths.add(f"{name}.weight")

    applied_count = 0
    dequant_count = 0
    skipped = []

    for lora_key, loras in module_to_loras.items():
        # Map the LoRA key onto a model module path.
        module_path = _normalize_lora_key(lora_key, module_paths)
        if module_path.endswith(".weight"):
            module_path = module_path[: -len(".weight")]

        *parent_parts, leaf_name = module_path.split(".")

        # Walk to the parent, then fetch the target. KeyError is included
        # because integer-indexing a dict-like module raises it, and that is
        # also just "module not found".
        parent = model
        try:
            for part in parent_parts:
                parent = child_of(parent, part)
            target = child_of(parent, leaf_name)
        except (AttributeError, IndexError, KeyError, TypeError):
            skipped.append(lora_key)
            if verbose:
                print(f" DEBUG: '{lora_key}' -> '{module_path}' -> module not found")
            continue

        if isinstance(target, nn.QuantizedLinear):
            # Dequantize -> merge LoRA -> replace with an unquantized Linear.
            weight = mx.dequantize(
                target.weight,
                target.scales,
                target.biases,
                group_size=target.group_size,
                bits=target.bits,
            )
            merged = apply_lora_to_linear(weight, loras)
            # nn.Linear defaults to bias=True with a randomly initialised
            # bias; request one only when the source layer has a bias so a
            # bias-free layer does not gain a spurious random offset.
            has_bias = "bias" in target
            new_linear = nn.Linear(merged.shape[1], merged.shape[0], bias=has_bias)
            new_linear.weight = merged
            if has_bias:
                new_linear.bias = target.bias
            if leaf_name.isdigit():
                parent[int(leaf_name)] = new_linear
            else:
                setattr(parent, leaf_name, new_linear)
            dequant_count += 1
            applied_count += 1
        elif isinstance(target, nn.Linear):
            # Merge directly into the existing weight.
            target.weight = apply_lora_to_linear(target.weight, loras)
            applied_count += 1
        else:
            skipped.append(lora_key)
            if verbose:
                print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")

    if applied_count > 0:
        msg = f" ✓ Applied to {applied_count} modules"
        if dequant_count > 0:
            msg += f" ({dequant_count} dequantized to bf16)"
        print(msg)
    if skipped:
        print(f" ⚠ Skipped {len(skipped)} incompatible modules")
    return applied_count

122
mlx_video/lora/loader.py Normal file
View File

@@ -0,0 +1,122 @@
"""LoRA weight loading utilities."""
import re
from pathlib import Path
from typing import Dict, List, Optional
import mlx.core as mx
from mlx_video.lora.types import LoRAConfig, LoRAWeights
def load_lora_weights(lora_path: Path) -> Dict[str, LoRAWeights]:
    """Read a LoRA safetensors file into per-module LoRAWeights.

    Two tensor naming schemes are recognised:
    - {module_name}.lora_A.weight / {module_name}.lora_B.weight
    - {module_name}.lora_down.weight / {module_name}.lora_up.weight

    A module is loaded only when both tensors of one scheme are present
    (A/B is tried first); modules with an incomplete pair are skipped. An
    optional scalar ``{module_name}.alpha`` tensor supplies alpha; otherwise
    alpha defaults to the rank.

    Args:
        lora_path: Path to the LoRA safetensors file

    Returns:
        Dictionary mapping module names to LoRAWeights objects

    Raises:
        FileNotFoundError: If the LoRA file doesn't exist
        ValueError: If a matrix is not 2-D, the A/B ranks disagree, or no
            usable LoRA pair is found in the file
    """
    if not lora_path.exists():
        raise FileNotFoundError(f"LoRA file not found: {lora_path}")

    tensors = mx.load(str(lora_path))

    # Collect every module prefix that owns at least one LoRA matrix.
    key_pattern = re.compile(r"(.+)\.lora_(?:A|B|down|up)\.weight$")
    module_names = {
        found.group(1) for found in map(key_pattern.match, tensors) if found
    }

    # (down-projection tag, up-projection tag), in lookup priority order.
    conventions = (("lora_A", "lora_B"), ("lora_down", "lora_up"))

    loaded = {}
    for module_name in module_names:
        pair = None
        for down_tag, up_tag in conventions:
            down_key = f"{module_name}.{down_tag}.weight"
            up_key = f"{module_name}.{up_tag}.weight"
            if down_key in tensors and up_key in tensors:
                pair = (tensors[down_key], tensors[up_key])
                break
        if pair is None:
            continue
        lora_a, lora_b = pair

        if lora_a.ndim != 2 or lora_b.ndim != 2:
            raise ValueError(
                f"Invalid LoRA shape for {module_name}: "
                f"lora_A={lora_a.shape}, lora_B={lora_b.shape}"
            )

        # A is [rank, in]; B must be [out, rank] for B @ A to be defined.
        rank = lora_a.shape[0]
        if lora_b.shape[1] != rank:
            raise ValueError(
                f"LoRA rank mismatch for {module_name}: "
                f"lora_A rank={rank}, lora_B rank={lora_b.shape[1]}"
            )

        alpha_key = f"{module_name}.alpha"
        alpha = float(tensors[alpha_key].item()) if alpha_key in tensors else float(rank)

        loaded[module_name] = LoRAWeights(
            lora_A=lora_a,
            lora_B=lora_b,
            rank=rank,
            alpha=alpha,
            module_name=module_name,
        )

    if not loaded:
        raise ValueError(f"No valid LoRA weights found in {lora_path}")
    return loaded
def load_multiple_loras(
    configs: List[LoRAConfig],
) -> Dict[str, List[tuple]]:
    """Load several LoRA files and group their weights by module name.

    Each config's file is loaded in list order. When a config specifies
    ``target_modules``, only modules named in it are kept for that config.

    Args:
        configs: List of LoRAConfig objects

    Returns:
        Dictionary mapping module names to lists of (LoRAWeights, strength)
        tuples, one tuple per config that contributes to that module.
    """
    grouped: Dict[str, list] = {}
    for config in configs:
        per_module = load_lora_weights(config.path)
        allowed = config.target_modules
        for module_name, weights in per_module.items():
            if allowed is not None and module_name not in allowed:
                continue
            grouped.setdefault(module_name, []).append((weights, config.strength))
    return grouped

74
mlx_video/lora/types.py Normal file
View File

@@ -0,0 +1,74 @@
"""Data structures for LoRA support."""
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import mlx.core as mx
@dataclass
class LoRAWeights:
    """One module's low-rank adapter matrices and their scaling metadata.

    Attributes:
        lora_A: Down-projection matrix, shape [rank, in_features]
        lora_B: Up-projection matrix, shape [out_features, rank]
        rank: Rank of the decomposition
        alpha: LoRA scaling parameter (loaders fall back to ``rank``)
        module_name: Name of the module this adapter targets
    """

    lora_A: mx.array
    lora_B: mx.array
    rank: int
    alpha: float
    module_name: str

    @property
    def scale(self) -> float:
        """Return the LoRA scale factor, ``alpha / rank``."""
        rank = self.rank
        return self.alpha / rank
@dataclass
class LoRAConfig:
    """Settings describing one LoRA file to apply.

    Attributes:
        path: Location of the LoRA safetensors file (coerced to ``Path``)
        strength: Multiplier applied to this LoRA (typically 0.0-2.0)
        target_modules: Optional list of module names to restrict the LoRA
            to. ``None`` means every module present in the file.

    Raises:
        FileNotFoundError: If ``path`` does not exist at construction time
        ValueError: If ``strength`` is negative
    """

    path: Path
    strength: float = 1.0
    target_modules: Optional[list[str]] = None

    def __post_init__(self):
        """Coerce ``path`` to a Path, then validate path and strength."""
        lora_file = Path(self.path)
        self.path = lora_file
        if not lora_file.exists():
            raise FileNotFoundError(f"LoRA file not found: {self.path}")

        if self.strength < 0:
            raise ValueError(f"LoRA strength must be non-negative, got {self.strength}")
@dataclass
class AppliedLoRA:
    """A LoRA paired with the strength it is applied at.

    Attributes:
        weights: The LoRA weight matrices
        strength: Application strength for this LoRA
    """

    weights: LoRAWeights
    strength: float

    def compute_delta(self) -> mx.array:
        """Return the weight delta ``scale * strength * (lora_B @ lora_A)``."""
        lora = self.weights
        factor = lora.scale
        low_rank_product = lora.lora_B @ lora.lora_A
        return factor * self.strength * low_rank_product

View File

@@ -4,6 +4,15 @@ import mlx.nn as nn
from .rope import rope_apply
def _linear_dtype(layer) -> mx.Dtype:
    """Return the dtype activations should be cast to before calling ``layer``.

    If ``layer`` exposes a ``linear`` attribute (as a LoRA wrapper does), that
    inner layer is inspected instead. For ``nn.QuantizedLinear`` the dtype of
    ``scales`` is returned; for anything else, the dtype of ``weight``.
    """
    try:
        base = layer.linear
    except AttributeError:
        base = layer
    reference = base.scales if isinstance(base, nn.QuantizedLinear) else base.weight
    return reference.dtype
class WanRMSNorm(nn.Module):
"""RMS normalization with learnable scale."""
@@ -73,8 +82,8 @@ class WanSelfAttention(nn.Module):
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
x_w = x.astype(w_dtype)
q = self.q(x_w)
@@ -154,8 +163,8 @@ class WanCrossAttention(nn.Module):
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul
w_dtype = self.k.weight.dtype
# Cast to compute dtype for efficient matmul
w_dtype = _linear_dtype(self.k)
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
@@ -174,8 +183,8 @@ class WanCrossAttention(nn.Module):
b = x.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
q = self.q(x.astype(w_dtype))
if self.norm_q is not None:
q = self.norm_q(q)

View File

@@ -6,14 +6,15 @@ import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
"""Load and initialize WanModel, with optional quantization support.
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
"""Load and initialize WanModel, with optional quantization and LoRA support.
Args:
model_path: Path to model safetensors file
config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading.
loras: Optional list of (lora_path, strength) tuples to apply.
"""
from mlx_video.models.wan.model import WanModel
@@ -30,6 +31,27 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
)
weights = mx.load(str(model_path))
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
if loras:
if quantization:
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
from mlx_video.convert_wan import _load_lora_configs
from mlx_video.lora import apply_loras_to_model
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
module_to_loras = _load_lora_configs(loras)
apply_loras_to_model(model, module_to_loras)
mx.eval(model.parameters())
return model
else:
# Weight merging: fold LoRA into bf16 weights before loading
from mlx_video.convert_wan import load_and_apply_loras
weights = load_and_apply_loras(dict(weights), loras)
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
return model

View File

@@ -4,7 +4,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .attention import WanLayerNorm
from .attention import WanLayerNorm, _linear_dtype
from .config import WanModelConfig
from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock
@@ -54,7 +54,7 @@ class Head(nn.Module):
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
return self.head(x_mod.astype(self.head.weight.dtype))
return self.head(x_mod.astype(_linear_dtype(self.head)))
class WanModel(nn.Module):
@@ -79,7 +79,7 @@ class WanModel(nn.Module):
# Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="precise")
self.text_embedding_act = nn.GELU(approx="tanh")
self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP
@@ -149,7 +149,7 @@ class WanModel(nn.Module):
# Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out)
@@ -186,7 +186,7 @@ class WanModel(nn.Module):
Returns:
Embedded context [B, text_len, dim] in model dtype
"""
model_dtype = self.patch_embedding_proj.weight.dtype
model_dtype = _linear_dtype(self.patch_embedding_proj)
context_padded = []
for ctx in context:
pad_len = self.text_len - ctx.shape[0]
@@ -231,7 +231,7 @@ class WanModel(nn.Module):
Returns:
(cos_f, sin_f) precomputed frequency tensors
"""
w_dtype = self.patch_embedding_proj.weight.dtype
w_dtype = _linear_dtype(self.patch_embedding_proj)
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__(
@@ -348,7 +348,7 @@ class WanModel(nn.Module):
# Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None
w_dtype = self.patch_embedding_proj.weight.dtype
w_dtype = _linear_dtype(self.patch_embedding_proj)
if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list):

View File

@@ -146,7 +146,7 @@ class T5FeedForward(nn.Module):
self.dim = dim
self.dim_ffn = dim_ffn
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = nn.GELU(approx="precise")
self.gate_act = nn.GELU(approx="tanh")
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)

View File

@@ -1,7 +1,7 @@
import mlx.core as mx
import mlx.nn as nn
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
class WanAttentionBlock(nn.Module):
@@ -84,10 +84,10 @@ class WanFFN(nn.Module):
def __init__(self, dim: int, ffn_dim: int):
super().__init__()
self.fc1 = nn.Linear(dim, ffn_dim)
self.act = nn.GELU(approx="precise")
self.act = nn.GELU(approx="tanh")
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(self.fc1.weight.dtype)
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(_linear_dtype(self.fc1))
return self.fc2(self.act(self.fc1(x_w)))

View File

@@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""
import logging
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
logger = logging.getLogger(__name__)
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
@@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
Maps PyTorch nn.Sequential indices to our named layers.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
# Skip encoder and conv1 unless requested
if not include_encoder:
if key.startswith("encoder.") or key.startswith("conv1."):
consumed.add(key)
continue
new_key = key
@@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
value = mx.array(np.array(value).squeeze())
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed))
return sanitized