feat(wan): Add Wan2.1/2.2 T2V with quantization support

This commit is contained in:
Daniel
2026-02-26 16:16:07 +01:00
parent 7a74946c57
commit e64483a66a
21 changed files with 5309 additions and 35 deletions

556
mlx_video/convert_wan.py Normal file
View File

@@ -0,0 +1,556 @@
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
import logging
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.utils
import numpy as np
def load_torch_weights(path: str) -> Dict[str, mx.array]:
"""Load PyTorch .pth weights and convert to MLX arrays.
Args:
path: Path to .pth file
Returns:
Dictionary of MLX arrays
"""
try:
import torch
except ImportError:
raise ImportError("PyTorch is required to load .pth weights: pip install torch")
logging.info(f"Loading weights from {path}")
state_dict = torch.load(path, map_location="cpu", weights_only=True)
weights = {}
for key, value in state_dict.items():
if isinstance(value, torch.Tensor):
np_val = value.detach().float().numpy()
weights[key] = mx.array(np_val)
return weights
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
"""Load safetensors weights as MLX arrays.
Args:
path: Path to directory with safetensors files or single file
Returns:
Dictionary of MLX arrays
"""
path = Path(path)
weights = {}
if path.is_file():
weights = mx.load(str(path))
elif path.is_dir():
for sf in sorted(path.glob("*.safetensors")):
weights.update(mx.load(str(sf)))
return weights
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 transformer weight keys to MLX model structure.
Wan2.2 keys follow the pattern:
patch_embedding.weight/bias
text_embedding.{0,2}.weight/bias
time_embedding.{0,2}.weight/bias
time_projection.1.weight/bias
blocks.{i}.norm1.weight
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
blocks.{i}.self_attn.norm_q.weight
blocks.{i}.self_attn.norm_k.weight
blocks.{i}.norm3.weight/bias (if cross_attn_norm)
blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
blocks.{i}.cross_attn.norm_q.weight
blocks.{i}.cross_attn.norm_k.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.{0,2}.weight/bias
blocks.{i}.modulation
head.norm.weight
head.head.weight/bias
head.modulation
freqs (buffer)
MLX model uses:
patch_embedding_proj.weight/bias (after patchify reshape)
text_embedding_0.weight/bias, text_embedding_1.weight/bias
time_embedding_0.weight/bias, time_embedding_1.weight/bias
time_projection.weight/bias
blocks.{i}.norm1.weight
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
etc.
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
# MLX Linear expects [O, I*D*H*W] after we flatten in patchify
if key == "patch_embedding.weight":
# Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
value = value.reshape(value.shape[0], -1)
new_key = "patch_embedding_proj.weight"
sanitized[new_key] = value
continue
if key == "patch_embedding.bias":
new_key = "patch_embedding_proj.bias"
sanitized[new_key] = value
continue
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
if key.startswith("text_embedding.0."):
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
sanitized[new_key] = value
continue
if key.startswith("text_embedding.2."):
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
sanitized[new_key] = value
continue
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
if key.startswith("time_embedding.0."):
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
sanitized[new_key] = value
continue
if key.startswith("time_embedding.2."):
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
sanitized[new_key] = value
continue
# Time projection Sequential: 0=SiLU(no params), 1=Linear
if key.startswith("time_projection.1."):
new_key = key.replace("time_projection.1.", "time_projection.")
sanitized[new_key] = value
continue
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")
# Skip the freqs buffer (we compute it in the model)
if key == "freqs":
continue
sanitized[new_key] = value
return sanitized
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.
Wan2.2 T5 keys:
token_embedding.weight
pos_embedding.embedding.weight (if shared_pos)
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate.0.weight (gate linear)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
norm.weight
MLX T5Encoder structure:
token_embedding.weight
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight
norm.weight
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
sanitized[new_key] = value
return sanitized
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 VAE weight keys to MLX WanVAE structure.
Handles Conv3d and Conv2d weight transpositions for MLX format.
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
if "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
if "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Map decoder keys to MLX decoder structure
# Wan2.2 uses encoder/decoder with downsamples/upsamples
# Need to adapt naming for our simplified structure
sanitized[new_key] = value
return sanitized
def convert_wan_checkpoint(
checkpoint_dir: str,
output_dir: str,
dtype: str = "bfloat16",
model_version: str = "auto",
quantize: bool = False,
bits: int = 4,
group_size: int = 64,
):
"""Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.
Wan2.2 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
low_noise_model/ (safetensors)
high_noise_model/ (safetensors)
Wan2.1 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
diffusion_pytorch_model*.safetensors (single model)
Args:
checkpoint_dir: Path to Wan checkpoint directory
output_dir: Path to output MLX model directory
dtype: Target dtype
model_version: "2.1", "2.2", or "auto" (detect from directory)
quantize: Whether to quantize the transformer weights
bits: Quantization bits (4 or 8)
group_size: Quantization group size (32, 64, or 128)
"""
import json
checkpoint_dir = Path(checkpoint_dir)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
dtype_map = {
"float16": mx.float16,
"float32": mx.float32,
"bfloat16": mx.bfloat16,
}
target_dtype = dtype_map.get(dtype, mx.bfloat16)
# Auto-detect version
if model_version == "auto":
if (checkpoint_dir / "low_noise_model").exists():
model_version = "2.2"
elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
model_version = "2.2"
else:
model_version = "2.1"
print(f"Auto-detected Wan{model_version} checkpoint")
is_dual = (checkpoint_dir / "low_noise_model").exists()
if is_dual:
# Wan2.2: Convert dual transformer models
low_noise_path = checkpoint_dir / "low_noise_model"
if low_noise_path.exists():
print("Converting low-noise transformer...")
weights = load_safetensors_weights(str(low_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "low_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
high_noise_path = checkpoint_dir / "high_noise_model"
if high_noise_path.exists():
print("Converting high-noise transformer...")
weights = load_safetensors_weights(str(high_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "high_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
# Wan2.1: Convert single transformer model
# Try safetensors in the checkpoint dir itself
print("Converting transformer (single model)...")
weights = load_safetensors_weights(str(checkpoint_dir))
if not weights:
# Fallback: look for .pth files
for pth in sorted(checkpoint_dir.glob("*.pth")):
if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
print(f" Loading from {pth.name}...")
weights = load_torch_weights(str(pth))
break
if weights:
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan.config import WanModelConfig
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
if is_dual:
return WanModelConfig.wan22_t2v_14b()
# Try reading source config.json first (most reliable)
src_cfg_path = checkpoint_dir / "config.json"
src_config = None
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
if src_config and "dim" in src_config:
src_dim = src_config.get("dim", 5120)
src_in_dim = src_config.get("in_dim", 16)
src_out_dim = src_config.get("out_dim", 16)
src_ffn_dim = src_config.get("ffn_dim", 13824)
src_num_heads = src_config.get("num_heads", 40)
src_num_layers = src_config.get("num_layers", 40)
src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512)
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
is_22 = model_version == "2.2"
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
vae_z = 48 if is_22 else 16
vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
fps = 24 if is_22 else 16
return WanModelConfig(
model_type=src_model_type,
model_version=model_version,
dim=src_dim,
ffn_dim=src_ffn_dim,
in_dim=src_in_dim,
out_dim=src_out_dim,
num_heads=src_num_heads,
num_layers=src_num_layers,
text_len=src_text_len,
vae_z_dim=vae_z,
vae_stride=vae_s,
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
sample_fps=fps,
)
# Fallback: detect from saved transformer weight shapes
saved_model = output_dir / "model.safetensors"
if saved_model.exists():
det_weights = mx.load(str(saved_model))
dim = None
for k, v in det_weights.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
break
del det_weights
if dim is not None and dim <= 2048:
print(f" Auto-detected 1.3B model (dim={dim})")
return WanModelConfig.wan21_t2v_1_3b()
return WanModelConfig.wan21_t2v_14b()
config = _detect_config()
config_path = output_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config.to_dict(), f, indent=2)
print(f" Saved config to {config_path}")
# Convert T5 encoder
t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
if t5_path.exists():
print("Converting T5 encoder...")
weights = load_torch_weights(str(t5_path))
weights = sanitize_wan_t5_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "t5_encoder.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
# Convert VAE (check both naming conventions)
vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
is_wan22_vae = False
if not vae_path.exists():
vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
is_wan22_vae = True
if vae_path.exists():
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = sanitize_wan22_vae_weights(weights)
else:
weights = sanitize_wan_vae_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "vae.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
# Quantize transformer weights if requested
if quantize:
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}")
def _quantize_predicate(path: str, module) -> bool:
"""Return True for layers that should be quantized.
Targets heavyweight Linear layers in attention and FFN blocks.
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
"""
if not hasattr(module, "to_quantized"):
return False
# Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = (
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
".ffn.fc1", ".ffn.fc2",
)
return any(path.endswith(p) for p in quantize_patterns)
def _quantize_saved_model(
output_dir: Path,
config,
is_dual: bool,
bits: int,
group_size: int,
):
"""Load saved bf16 model, quantize, and re-save."""
import json
import mlx.nn as nn
from mlx_video.models.wan.model import WanModel
model_files = []
if is_dual:
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
p = output_dir / name
if p.exists():
model_files.append(p)
else:
p = output_dir / "model.safetensors"
if p.exists():
model_files.append(p)
for model_path in model_files:
print(f" Quantizing {model_path.name}...")
model = WanModel(config)
weights = mx.load(str(model_path))
model.load_weights(list(weights.items()), strict=False)
# Apply quantization to targeted layers
nn.quantize(
model,
group_size=group_size,
bits=bits,
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
# Save quantized weights
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
mx.save_safetensors(str(model_path), weights_dict)
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
# Update config.json with quantization metadata
config_path = output_dir / "config.json"
with open(config_path) as f:
cfg = json.load(f)
cfg["quantization"] = {
"group_size": group_size,
"bits": bits,
}
with open(config_path, "w") as f:
json.dump(cfg, f, indent=2)
print(f" Updated config.json with quantization metadata")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
parser.add_argument(
"--checkpoint-dir",
type=str,
required=True,
help="Path to Wan checkpoint directory",
)
parser.add_argument(
"--output-dir",
type=str,
default="wan_mlx_model",
help="Output path for MLX model",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float16", "float32", "bfloat16"],
default="bfloat16",
help="Target dtype",
)
parser.add_argument(
"--model-version",
type=str,
choices=["2.1", "2.2", "auto"],
default="auto",
help="Wan model version (auto-detect by default)",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize transformer weights for faster inference",
)
parser.add_argument(
"--bits",
type=int,
choices=[4, 8],
default=4,
help="Quantization bits (default: 4)",
)
parser.add_argument(
"--group-size",
type=int,
choices=[32, 64, 128],
default=64,
help="Quantization group size (default: 64)",
)
args = parser.parse_args()
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
)