557 lines
19 KiB
Python
557 lines
19 KiB
Python
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
|
|
import mlx.core as mx
|
|
import mlx.utils
|
|
import numpy as np
|
|
|
|
|
|
def load_torch_weights(path: str) -> Dict[str, mx.array]:
    """Load a PyTorch ``.pth`` state dict and convert its tensors to MLX arrays.

    Floating-point tensors (including bfloat16, which NumPy cannot represent)
    are upcast to float32 on the way through NumPy. Integer and boolean
    tensors keep their original dtype instead of being coerced to float.
    Entries of the loaded mapping that are not tensors are skipped.

    Args:
        path: Path to the .pth file. It is loaded on CPU with
            ``weights_only=True``, which restricts unpickling to tensors and
            plain containers.

    Returns:
        Dictionary mapping state-dict keys to MLX arrays.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    try:
        import torch
    except ImportError as err:
        raise ImportError("PyTorch is required to load .pth weights: pip install torch") from err

    logging.info("Loading weights from %s", path)
    state_dict = torch.load(path, map_location="cpu", weights_only=True)

    weights = {}
    skipped = []
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            skipped.append(key)
            continue
        tensor = value.detach()
        # NumPy has no bfloat16, so floating tensors go through float32;
        # non-floating tensors convert losslessly as-is.
        if tensor.is_floating_point():
            tensor = tensor.float()
        weights[key] = mx.array(tensor.numpy())

    if skipped:
        logging.debug("Skipped %d non-tensor entries in %s", len(skipped), path)
    return weights
|
|
|
|
|
|
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
    """Read one ``.safetensors`` file, or every shard in a directory, as MLX arrays.

    Args:
        path: Either a single safetensors file, or a directory whose top-level
            ``*.safetensors`` files are merged in sorted filename order (a later
            shard overwrites duplicate keys from an earlier one).

    Returns:
        Mapping of tensor name to MLX array. It is empty when ``path`` is
        neither an existing file nor an existing directory, or when the
        directory contains no shards.
    """
    source = Path(path)
    if source.is_file():
        return mx.load(str(source))

    merged = {}
    if source.is_dir():
        for shard in sorted(source.glob("*.safetensors")):
            merged.update(mx.load(str(shard)))
    return merged
|
|
|
|
|
|
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename Wan2.x transformer weights to the MLX model's parameter names.

    Renames performed:
        patch_embedding.weight      -> patch_embedding_proj.weight
                                       (Conv3d kernel flattened to [O, -1])
        patch_embedding.bias        -> patch_embedding_proj.bias
        text_embedding.0.*          -> text_embedding_0.*
        text_embedding.2.*          -> text_embedding_1.*
        time_embedding.0.*          -> time_embedding_0.*
        time_embedding.2.*          -> time_embedding_1.*
        time_projection.1.*         -> time_projection.*
        *.ffn.0.* / *.ffn.2.*       -> *.ffn.fc1.* / *.ffn.fc2.*
        freqs                       -> dropped (the MLX model computes it)

    Every other key (block norms, attention q/k/v/o, norm_q/norm_k,
    modulation, head.*) passes through unchanged.

    Args:
        weights: Wan2.x transformer state dict.

    Returns:
        New dict keyed by MLX parameter names.
    """
    # (source prefix, target prefix) for torch Sequential containers whose
    # parameter-free activation layers (GELU/SiLU) have no MLX counterpart.
    prefix_renames = (
        ("text_embedding.0.", "text_embedding_0."),   # Linear
        ("text_embedding.2.", "text_embedding_1."),   # Linear after GELU
        ("time_embedding.0.", "time_embedding_0."),   # Linear
        ("time_embedding.2.", "time_embedding_1."),   # Linear after SiLU
        ("time_projection.1.", "time_projection."),   # Linear after SiLU
    )

    out = {}
    for name, tensor in weights.items():
        if name == "freqs":
            # Precomputed frequency buffer; the MLX model builds its own.
            continue

        if name == "patch_embedding.weight":
            # Conv3d kernel [dim, in_dim, 1, 2, 2] -> [dim, in_dim*1*2*2]
            out["patch_embedding_proj.weight"] = tensor.reshape(tensor.shape[0], -1)
            continue
        if name == "patch_embedding.bias":
            out["patch_embedding_proj.bias"] = tensor
            continue

        renamed = next(
            (name.replace(old, new) for old, new in prefix_renames if name.startswith(old)),
            None,
        )
        if renamed is None:
            # FFN is Sequential(Linear, GELU, Linear): indices 0/2 -> fc1/fc2.
            renamed = name.replace(".ffn.0.", ".ffn.fc1.").replace(".ffn.2.", ".ffn.fc2.")
        out[renamed] = tensor
    return out
|
|
|
|
|
|
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename Wan2.x umT5 encoder weights to the MLX T5Encoder parameter names.

    The only rename is ``*.ffn.gate.0.*`` -> ``*.ffn.gate_proj.*``. The torch
    gate is a Sequential whose index 0 is the Linear; the GELU after it has no
    parameters. All other keys (token_embedding, pos_embedding, norm1/norm2,
    attn.{q,k,v,o}, ffn.fc1/fc2, final norm) are kept verbatim.

    Args:
        weights: Wan2.x T5 encoder state dict.

    Returns:
        New dict keyed by MLX parameter names.
    """
    return {
        name.replace(".ffn.gate.0.", ".ffn.gate_proj."): tensor
        for name, tensor in weights.items()
    }
|
|
|
|
|
|
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert Wan2.1 VAE conv kernels from PyTorch to MLX channel layout.

    Any key containing ``"weight"`` whose tensor is 5-D (Conv3d) or 4-D
    (Conv2d) has its input-channel axis moved last. Key names are not changed,
    and every other tensor passes through untouched.

    Args:
        weights: Wan2.1 VAE state dict.

    Returns:
        New dict with the same keys and MLX-layout conv kernels.
    """
    # PyTorch kernels are [O, I, ...spatial]; MLX convs expect [O, ...spatial, I].
    kernel_perms = {
        5: (0, 2, 3, 4, 1),  # Conv3d [O, I, D, H, W] -> [O, D, H, W, I]
        4: (0, 2, 3, 1),     # Conv2d [O, I, H, W]    -> [O, H, W, I]
    }

    converted = {}
    for name, tensor in weights.items():
        perm = kernel_perms.get(tensor.ndim) if "weight" in name else None
        converted[name] = tensor if perm is None else mx.transpose(tensor, perm)
    return converted
|
|
|
|
|
|
def _cast_and_save(weights: Dict[str, mx.array], target_dtype, out_path: Path) -> None:
    """Cast every tensor in ``weights`` to ``target_dtype`` and write them to ``out_path``."""
    cast = {k: v.astype(target_dtype) for k, v in weights.items()}
    mx.save_safetensors(str(out_path), cast)
    print(f" Saved {len(cast)} weight tensors to {out_path}")


def _patch_embedding_dim(weights: Dict[str, mx.array]):
    """Return dim 0 of the sanitized patch-embedding projection, or None if it is absent."""
    return next(
        (v.shape[0] for k, v in weights.items() if "patch_embedding_proj.weight" in k),
        None,
    )


def _detect_wan_config(checkpoint_dir: Path, model_version: str, is_dual: bool, transformer_dim):
    """Build the WanModelConfig to write next to the converted weights.

    Checks are tried in this order: the dual-expert layout (Wan2.2 14B
    preset), then the source ``config.json`` in ``checkpoint_dir``, then the
    transformer width measured during this conversion (``transformer_dim``,
    or None if no transformer was converted). A width of at most 2048 selects
    the 1.3B preset; otherwise the 14B preset is used.
    """
    import json

    from mlx_video.models.wan.config import WanModelConfig

    if is_dual:
        return WanModelConfig.wan22_t2v_14b()

    # The source config.json, when present, is the most reliable description.
    src_cfg_path = checkpoint_dir / "config.json"
    src_config = None
    if src_cfg_path.exists():
        with open(src_cfg_path) as f:
            src_config = json.load(f)

    if src_config and "dim" in src_config:
        src_dim = src_config.get("dim", 5120)
        src_in_dim = src_config.get("in_dim", 16)
        src_out_dim = src_config.get("out_dim", 16)
        src_ffn_dim = src_config.get("ffn_dim", 13824)
        src_num_heads = src_config.get("num_heads", 40)
        src_num_layers = src_config.get("num_layers", 40)
        src_model_type = src_config.get("model_type", "t2v")
        src_text_len = src_config.get("text_len", 512)

        print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
              f"heads={src_num_heads}, type={src_model_type}")

        is_22 = model_version == "2.2"
        # Wan2.2 uses a different VAE with z_dim=48 and stride (4, 16, 16).
        vae_z = 48 if is_22 else 16
        vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
        fps = 24 if is_22 else 16

        return WanModelConfig(
            model_type=src_model_type,
            model_version=model_version,
            dim=src_dim,
            ffn_dim=src_ffn_dim,
            in_dim=src_in_dim,
            out_dim=src_out_dim,
            num_heads=src_num_heads,
            num_layers=src_num_layers,
            text_len=src_text_len,
            vae_z_dim=vae_z,
            vae_stride=vae_s,
            dual_model=False,
            boundary=0.0,
            sample_shift=5.0,
            sample_steps=50,
            sample_guide_scale=5.0,
            sample_fps=fps,
        )

    # Fallback: use the width measured during this run. Re-reading
    # output_dir/model.safetensors could pick up a stale file from an
    # earlier conversion.
    if transformer_dim is not None and transformer_dim <= 2048:
        print(f" Auto-detected 1.3B model (dim={transformer_dim})")
        return WanModelConfig.wan21_t2v_1_3b()

    return WanModelConfig.wan21_t2v_14b()


def convert_wan_checkpoint(
    checkpoint_dir: str,
    output_dir: str,
    dtype: str = "bfloat16",
    model_version: str = "auto",
    quantize: bool = False,
    bits: int = 4,
    group_size: int = 64,
):
    """Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.

    Wan2.2 expected structure:
        checkpoint_dir/
            models_t5_umt5-xxl-enc-bf16.pth
            Wan2.1_VAE.pth  (or Wan2.2_VAE.pth)
            low_noise_model/   (safetensors)
            high_noise_model/  (safetensors)

    Wan2.1 expected structure:
        checkpoint_dir/
            models_t5_umt5-xxl-enc-bf16.pth
            Wan2.1_VAE.pth
            diffusion_pytorch_model*.safetensors  (single model)

    Writes ``low_noise_model``/``high_noise_model`` or ``model`` safetensors,
    ``t5_encoder.safetensors``, ``vae.safetensors`` and ``config.json`` into
    ``output_dir``. Components that cannot be found are skipped with a warning.

    Args:
        checkpoint_dir: Path to the Wan checkpoint directory.
        output_dir: Path to the output MLX model directory; created if needed.
        dtype: Target dtype ("float16", "float32" or "bfloat16"). An unknown
            value falls back to bfloat16 with a warning.
        model_version: "2.1", "2.2", or "auto" (detect from the directory).
        quantize: Whether to quantize the transformer weights afterwards.
        bits: Quantization bits (4 or 8).
        group_size: Quantization group size (32, 64, or 128).

    Raises:
        FileNotFoundError: If ``checkpoint_dir`` is not an existing directory.
    """
    import json

    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.is_dir():
        # Fail before creating output_dir; otherwise a config.json with a
        # guessed preset would be written for a non-existent checkpoint.
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    dtype_map = {
        "float16": mx.float16,
        "float32": mx.float32,
        "bfloat16": mx.bfloat16,
    }
    if dtype not in dtype_map:
        print(f" Warning: unknown dtype {dtype!r}, falling back to bfloat16")
    target_dtype = dtype_map.get(dtype, mx.bfloat16)

    # Wan2.2 A14B ships two expert transformers in separate sub-directories.
    is_dual = (checkpoint_dir / "low_noise_model").exists()

    if model_version == "auto":
        if is_dual or (checkpoint_dir / "Wan2.2_VAE.pth").exists():
            model_version = "2.2"
        else:
            model_version = "2.1"
        print(f"Auto-detected Wan{model_version} checkpoint")

    # --- Transformer(s) ---
    transformer_dim = None
    if is_dual:
        for subdir, label in (("low_noise_model", "low-noise"), ("high_noise_model", "high-noise")):
            src = checkpoint_dir / subdir
            if not src.exists():
                continue
            print(f"Converting {label} transformer...")
            weights = sanitize_wan_transformer_weights(load_safetensors_weights(str(src)))
            _cast_and_save(weights, target_dtype, output_dir / f"{subdir}.safetensors")
            del weights  # release this expert before loading the next one
    else:
        # Wan2.1-style single model: safetensors in the checkpoint dir itself.
        print("Converting transformer (single model)...")
        weights = load_safetensors_weights(str(checkpoint_dir))
        if not weights:
            # Fallback: the first .pth that is neither the T5 nor the VAE file.
            for pth in sorted(checkpoint_dir.glob("*.pth")):
                lowered = pth.name.lower()
                if "t5" not in lowered and "vae" not in lowered:
                    print(f" Loading from {pth.name}...")
                    weights = load_torch_weights(str(pth))
                    break
        if weights:
            weights = sanitize_wan_transformer_weights(weights)
            transformer_dim = _patch_embedding_dim(weights)
            _cast_and_save(weights, target_dtype, output_dir / "model.safetensors")
        else:
            print(" Warning: No transformer weights found!")
        del weights

    # --- Config ---
    config = _detect_wan_config(checkpoint_dir, model_version, is_dual, transformer_dim)
    config_path = output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config.to_dict(), f, indent=2)
    print(f" Saved config to {config_path}")

    # --- T5 encoder ---
    t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
    if t5_path.exists():
        print("Converting T5 encoder...")
        weights = sanitize_wan_t5_weights(load_torch_weights(str(t5_path)))
        _cast_and_save(weights, target_dtype, output_dir / "t5_encoder.safetensors")
        del weights
    else:
        print(f" Warning: T5 encoder weights not found at {t5_path}")

    # --- VAE (prefer Wan2.1_VAE.pth, else Wan2.2_VAE.pth) ---
    vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
    is_wan22_vae = not vae_path.exists()
    if is_wan22_vae:
        vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
    if vae_path.exists():
        print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
        weights = load_torch_weights(str(vae_path))
        if is_wan22_vae:
            from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

            weights = sanitize_wan22_vae_weights(weights)
        else:
            weights = sanitize_wan_vae_weights(weights)
        _cast_and_save(weights, target_dtype, output_dir / "vae.safetensors")
        del weights
    else:
        print(" Warning: No VAE weights found (expected Wan2.1_VAE.pth or Wan2.2_VAE.pth)")

    # --- Optional quantization of the saved transformer weights ---
    if quantize:
        print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
        _quantize_saved_model(output_dir, config, is_dual, bits, group_size)

    print(f"\nConversion complete! Output: {output_dir}")
|
|
|
|
|
|
def _quantize_predicate(path: str, module) -> bool:
|
|
"""Return True for layers that should be quantized.
|
|
|
|
Targets heavyweight Linear layers in attention and FFN blocks.
|
|
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
|
|
"""
|
|
if not hasattr(module, "to_quantized"):
|
|
return False
|
|
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
|
quantize_patterns = (
|
|
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
|
|
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
|
|
".ffn.fc1", ".ffn.fc2",
|
|
)
|
|
return any(path.endswith(p) for p in quantize_patterns)
|
|
|
|
|
|
def _quantize_saved_model(
    output_dir: Path,
    config,
    is_dual: bool,
    bits: int,
    group_size: int,
):
    """Quantize the saved transformer weights in place and record it in config.json.

    For each transformer file present in ``output_dir`` (the low/high-noise
    experts when ``is_dual``, otherwise ``model.safetensors``), this builds a
    ``WanModel(config)``, loads the weights, quantizes the layers selected by
    ``_quantize_predicate`` and overwrites the file. The ``quantization`` entry
    in config.json is only written if at least one file was quantized, so the
    metadata never claims quantized weights that do not exist.

    Args:
        output_dir: Directory holding the converted weights and config.json.
        config: Model config passed to ``WanModel``.
        is_dual: Whether the two-expert (low/high noise) layout is used.
        bits: Quantization bits.
        group_size: Quantization group size.
    """
    import json

    import mlx.nn as nn

    from mlx_video.models.wan.model import WanModel

    names = (
        ["low_noise_model.safetensors", "high_noise_model.safetensors"]
        if is_dual
        else ["model.safetensors"]
    )
    model_files = [output_dir / name for name in names if (output_dir / name).exists()]
    if not model_files:
        print(" Warning: no transformer weights found to quantize; config.json left unchanged")
        return

    for model_path in model_files:
        print(f" Quantizing {model_path.name}...")
        model = WanModel(config)
        weights = mx.load(str(model_path))

        # strict=False tolerates key mismatches, which would leave those
        # parameters at their initial values; surface them instead of
        # silently saving them.
        expected = {k for k, _ in mlx.utils.tree_flatten(model.parameters())}
        missing = expected - weights.keys()
        unexpected = weights.keys() - expected
        if missing:
            print(f" Warning: {len(missing)} model parameters missing from {model_path.name}")
        if unexpected:
            print(f" Warning: {len(unexpected)} tensors in {model_path.name} not used by the model")
        model.load_weights(list(weights.items()), strict=False)

        # Apply quantization to the targeted layers only.
        nn.quantize(
            model,
            group_size=group_size,
            bits=bits,
            class_predicate=_quantize_predicate,
        )

        # mx.load may read arrays lazily from model_path. Write to a sibling
        # temp file and only then replace the original, so the source is not
        # truncated while it is still being read.
        weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
        tmp_path = model_path.with_name(model_path.stem + ".tmp.safetensors")
        mx.save_safetensors(str(tmp_path), weights_dict)
        tmp_path.replace(model_path)

        n_quantized = sum(1 for k in weights_dict if ".scales" in k)
        print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
        del model, weights, weights_dict

    # Record quantization settings so loaders can rebuild quantized layers.
    config_path = output_dir / "config.json"
    with open(config_path) as f:
        cfg = json.load(f)
    cfg["quantization"] = {
        "group_size": group_size,
        "bits": bits,
    }
    with open(config_path, "w") as f:
        json.dump(cfg, f, indent=2)
    print(" Updated config.json with quantization metadata")
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse options and run the full conversion.
    cli = argparse.ArgumentParser(description="Convert Wan model to MLX format")
    cli.add_argument(
        "--checkpoint-dir", type=str, required=True,
        help="Path to Wan checkpoint directory",
    )
    cli.add_argument(
        "--output-dir", type=str, default="wan_mlx_model",
        help="Output path for MLX model",
    )
    cli.add_argument(
        "--dtype", type=str, choices=["float16", "float32", "bfloat16"], default="bfloat16",
        help="Target dtype",
    )
    cli.add_argument(
        "--model-version", type=str, choices=["2.1", "2.2", "auto"], default="auto",
        help="Wan model version (auto-detect by default)",
    )
    cli.add_argument(
        "--quantize", action="store_true",
        help="Quantize transformer weights for faster inference",
    )
    cli.add_argument(
        "--bits", type=int, choices=[4, 8], default=4,
        help="Quantization bits (default: 4)",
    )
    cli.add_argument(
        "--group-size", type=int, choices=[32, 64, 128], default=64,
        help="Quantization group size (default: 64)",
    )

    opts = cli.parse_args()
    convert_wan_checkpoint(
        opts.checkpoint_dir,
        opts.output_dir,
        dtype=opts.dtype,
        model_version=opts.model_version,
        quantize=opts.quantize,
        bits=opts.bits,
        group_size=opts.group_size,
    )
|