Remove deprecated stubs for video conversion and generation; introduce new weight conversion and generation scripts for Wan2.2 models in MLX.

This commit is contained in:
Prince Canuma
2026-03-18 17:20:36 +01:00
parent 7b9d0a5e44
commit 95d7c81b20
4 changed files with 0 additions and 7 deletions

View File

@@ -0,0 +1,773 @@
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
import gc
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.utils
import numpy as np
# Module-level logger named after this module; the sanitize_* helpers below
# use it to warn about weight keys that were not consumed.
logger = logging.getLogger(__name__)
def load_torch_weights(path: str) -> Dict[str, mx.array]:
    """Load a PyTorch ``.pth`` checkpoint and convert its tensors to MLX arrays.

    Each tensor is detached, upcast to float32 and passed through NumPy before
    being wrapped in an ``mx.array`` (presumably so that bf16 checkpoints
    survive the NumPy round-trip -- confirm). Entries of the loaded state dict
    that are not ``torch.Tensor`` instances are dropped.

    Args:
        path: Path to the ``.pth`` file.

    Returns:
        Dictionary mapping the original state-dict keys to float32 MLX arrays.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    try:
        import torch
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "PyTorch is required to load .pth weights: pip install torch"
        ) from err

    # Use the module logger (not the root logger) with lazy %-formatting.
    logger.info("Loading weights from %s", path)

    # weights_only=True restricts unpickling to tensor data.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)

    return {
        key: mx.array(value.detach().float().numpy())
        for key, value in state_dict.items()
        if isinstance(value, torch.Tensor)
    }
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
    """Read safetensors weights from a single file or a directory of shards.

    Args:
        path: Either one safetensors file, or a directory whose
            ``*.safetensors`` files are loaded in sorted order and merged
            (later shards overwrite duplicate keys).

    Returns:
        Dictionary of MLX arrays. Empty when ``path`` is neither a file nor a
        directory, or when the directory holds no ``*.safetensors`` files.
    """
    location = Path(path)

    if location.is_file():
        return mx.load(str(location))

    merged = {}
    if location.is_dir():
        for shard in sorted(location.glob("*.safetensors")):
            merged.update(mx.load(str(shard)))
    return merged
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename Wan2.2 transformer checkpoint keys to the MLX model layout.

    Source (PyTorch) keys look like:
        patch_embedding.weight/bias
        text_embedding.{0,2}.weight/bias
        time_embedding.{0,2}.weight/bias
        time_projection.1.weight/bias
        blocks.{i}.norm1.weight
        blocks.{i}.self_attn.{q,k,v,o}.weight/bias
        blocks.{i}.self_attn.norm_q.weight
        blocks.{i}.self_attn.norm_k.weight
        blocks.{i}.norm3.weight/bias (if cross_attn_norm)
        blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
        blocks.{i}.cross_attn.norm_q.weight
        blocks.{i}.cross_attn.norm_k.weight
        blocks.{i}.norm2.weight
        blocks.{i}.ffn.{0,2}.weight/bias
        blocks.{i}.modulation
        head.norm.weight
        head.head.weight/bias
        head.modulation
        freqs (buffer)

    Target (MLX) keys:
        patch_embedding_proj.weight/bias (weight flattened to 2-D)
        text_embedding_0.*, text_embedding_1.*
        time_embedding_0.*, time_embedding_1.*
        time_projection.*
        blocks.{i}.ffn.fc1.*, blocks.{i}.ffn.fc2.*
        everything else unchanged; ``freqs`` is dropped.

    Args:
        weights: Mapping from source key to tensor.

    Returns:
        New mapping with renamed keys, in the same iteration order as the
        input (minus ``freqs``).
    """
    # Sequential containers whose numeric child index is folded into the name.
    # The parameter-free activations occupy the skipped indices.
    prefix_renames = (
        ("text_embedding.0.", "text_embedding_0."),
        ("text_embedding.2.", "text_embedding_1."),
        ("time_embedding.0.", "time_embedding_0."),
        ("time_embedding.2.", "time_embedding_1."),
        ("time_projection.1.", "time_projection."),
    )

    remapped = {}
    seen = set()
    for name, tensor in weights.items():
        seen.add(name)

        # The rotary frequency buffer is recomputed by the model, so drop it.
        if name == "freqs":
            continue

        # Conv3d kernel [dim, in_dim, 1, 2, 2] -> Linear weight [dim, in_dim*1*2*2].
        if name == "patch_embedding.weight":
            remapped["patch_embedding_proj.weight"] = tensor.reshape(tensor.shape[0], -1)
            continue
        if name == "patch_embedding.bias":
            remapped["patch_embedding_proj.bias"] = tensor
            continue

        rename = next(
            ((old, new) for old, new in prefix_renames if name.startswith(old)),
            None,
        )
        if rename is not None:
            remapped[name.replace(rename[0], rename[1])] = tensor
            continue

        # FFN Sequential(Linear, GELU, Linear): indices 0/2 become fc1/fc2.
        target = name.replace(".ffn.0.", ".ffn.fc1.")
        target = target.replace(".ffn.2.", ".ffn.fc2.")
        remapped[target] = tensor

    leftover = set(weights.keys()) - seen
    if leftover:
        logger.warning("Unconsumed transformer weight keys: %s", sorted(leftover))
    return remapped
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename Wan2.2 T5 encoder checkpoint keys to the MLX T5Encoder layout.

    Source keys:
        token_embedding.weight
        pos_embedding.embedding.weight (if shared_pos)
        blocks.{i}.norm1.weight
        blocks.{i}.attn.{q,k,v,o}.weight
        blocks.{i}.norm2.weight
        blocks.{i}.ffn.gate.0.weight (gate linear)
        blocks.{i}.ffn.fc1.weight
        blocks.{i}.ffn.fc2.weight
        blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
        norm.weight

    The only rename is ``.ffn.gate.0.`` -> ``.ffn.gate_proj.``; every other
    key is copied through unchanged.

    Args:
        weights: Mapping from source key to tensor.

    Returns:
        New mapping with the gate projection keys renamed.
    """
    renamed = {}
    visited = set()
    for src_key in weights:
        # gate.0 is the Linear inside the gate Sequential; the GELU that
        # follows it has no parameters, so only index 0 ever appears.
        dst_key = src_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
        renamed[dst_key] = weights[src_key]
        visited.add(src_key)

    missed = set(weights.keys()) - visited
    if missed:
        logger.warning("Unconsumed T5 weight keys: %s", sorted(missed))
    return renamed
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert Wan VAE checkpoint tensors to the channel-last layout used here.

    Keys are left untouched. For any key containing ``"weight"``:
        * rank-5 tensors (Conv3d) go from [O, I, D, H, W] to [O, D, H, W, I];
        * rank-4 tensors (Conv2d) go from [O, I, H, W] to [O, H, W, I].
    All other tensors are passed through unchanged.

    Args:
        weights: Mapping from checkpoint key to tensor.

    Returns:
        New mapping with the same keys and transposed convolution kernels.
    """
    # Axis permutation that moves the input-channel axis last, keyed by rank.
    channel_last = {
        5: (0, 2, 3, 4, 1),
        4: (0, 2, 3, 1),
    }

    converted = {}
    visited = set()
    for name, tensor in weights.items():
        axes = channel_last.get(tensor.ndim) if "weight" in name else None
        if axes is not None:
            tensor = mx.transpose(tensor, axes)
        converted[name] = tensor
        visited.add(name)

    missed = set(weights.keys()) - visited
    if missed:
        logger.warning("Unconsumed VAE weight keys: %s", sorted(missed))
    return converted
def _load_lora_configs(
    lora_configs: List[Tuple[str, float]],
) -> Dict[str, list]:
    """Build LoRAConfig objects from (path, strength) pairs and load them.

    Shared between the weight-merging and runtime-wrapping paths.

    Args:
        lora_configs: ``(path, strength)`` tuples, one per LoRA file.

    Returns:
        Whatever ``load_multiple_loras`` returns for the parsed configs (a
        module-to-LoRAs mapping); a warning is printed when it is empty.

    Raises:
        Exception: Any error raised while constructing a ``LoRAConfig`` is
            reported on stdout and then re-raised unchanged.
    """
    from mlx_video.lora import LoRAConfig, load_multiple_loras
    from mlx_video.generate_wan import Colors

    print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")

    parsed = []
    for lora_path, strength in lora_configs:
        try:
            parsed.append(LoRAConfig(path=lora_path, strength=strength))
            print(f" - {Path(lora_path).name} (strength: {strength})")
        except Exception as e:
            # Report which file failed, then let the caller see the error.
            print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
            raise

    module_to_loras = load_multiple_loras(parsed)
    if module_to_loras:
        return module_to_loras

    print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
    return module_to_loras
def load_and_apply_loras(
    model_weights: Dict[str, mx.array],
    lora_configs: Optional[List[Tuple[str, float]]] = None,
    verbose: bool = False,
    quantization_bits: int = 0,
) -> Dict[str, mx.array]:
    """Load LoRA weights and merge them into a model weight dictionary.

    Intended for non-quantized (bf16) models. For quantized models, use
    apply_loras_to_model().

    Args:
        model_weights: Model weights keyed by parameter name.
        lora_configs: ``(path, strength)`` tuples; ``None`` or an empty list
            means "no LoRAs" and returns ``model_weights`` untouched.
        verbose: Print extra detail and forward the flag to the merge helper.
        quantization_bits: Forwarded unchanged to ``apply_loras_to_weights``.

    Returns:
        ``model_weights`` itself when there is nothing to apply (no configs,
        or no LoRA matched any module); otherwise the merged weights returned
        by ``apply_loras_to_weights``.
    """
    # Nothing requested: return before touching the project imports below, so
    # the no-op path neither pays for nor depends on the generation module.
    if not lora_configs:
        return model_weights

    from mlx_video.lora import apply_loras_to_weights
    from mlx_video.generate_wan import Colors

    module_to_loras = _load_lora_configs(lora_configs)
    if not module_to_loras:
        return model_weights

    print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
    if verbose:
        print(f" Model has {len(model_weights)} weight keys")

    modified_weights = apply_loras_to_weights(
        model_weights,
        module_to_loras,
        verbose=verbose,
        quantization_bits=quantization_bits,
    )
    print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
    return modified_weights
def _cast_and_save(weights, out_path, dtype, note=""):
    """Cast every tensor to ``dtype``, write one safetensors file, report it."""
    weights = {k: v.astype(dtype) for k, v in weights.items()}
    mx.save_safetensors(str(out_path), weights)
    print(f" Saved {len(weights)} weight tensors to {out_path}{note}")


def convert_wan_checkpoint(
    checkpoint_dir: str,
    output_dir: str,
    dtype: str = "bfloat16",
    model_version: str = "auto",
    quantize: bool = False,
    bits: int = 4,
    group_size: int = 64,
):
    """Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.

    Wan2.2 expected structure:
        checkpoint_dir/
            models_t5_umt5-xxl-enc-bf16.pth
            Wan2.1_VAE.pth
            low_noise_model/ (safetensors)
            high_noise_model/ (safetensors)

    Wan2.1 expected structure:
        checkpoint_dir/
            models_t5_umt5-xxl-enc-bf16.pth
            Wan2.1_VAE.pth
            diffusion_pytorch_model*.safetensors (single model)

    Files written to ``output_dir`` (each only when its source exists):
    ``low_noise_model.safetensors`` / ``high_noise_model.safetensors`` or
    ``model.safetensors``, ``config.json``, ``t5_encoder.safetensors`` and
    ``vae.safetensors`` (always float32).

    Args:
        checkpoint_dir: Path to Wan checkpoint directory.
        output_dir: Path to output MLX model directory (created if missing).
        dtype: Target dtype for transformer and T5 weights: "float16",
            "float32" or "bfloat16". Any other value falls back to bfloat16
            and logs a warning.
        model_version: "2.1", "2.2", or "auto" (detect from directory).
        quantize: Whether to quantize the transformer weights afterwards.
        bits: Quantization bits (4 or 8).
        group_size: Quantization group size (32, 64, or 128).
    """
    import json

    checkpoint_dir = Path(checkpoint_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    dtype_map = {
        "float16": mx.float16,
        "float32": mx.float32,
        "bfloat16": mx.bfloat16,
    }
    if dtype not in dtype_map:
        # Keep the historical fallback, but make the substitution visible.
        logger.warning("Unknown dtype %r; falling back to bfloat16", dtype)
    target_dtype = dtype_map.get(dtype, mx.bfloat16)

    # Auto-detect version
    if model_version == "auto":
        if (checkpoint_dir / "low_noise_model").exists():
            model_version = "2.2"
        elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
            model_version = "2.2"
        else:
            model_version = "2.1"
        print(f"Auto-detected Wan{model_version} checkpoint")

    is_dual = (checkpoint_dir / "low_noise_model").exists()

    if is_dual:
        # Wan2.2: convert each expert directory that is present.
        experts = (
            ("low_noise_model", "low-noise"),
            ("high_noise_model", "high-noise"),
        )
        for expert, label in experts:
            expert_dir = checkpoint_dir / expert
            if not expert_dir.exists():
                continue
            print(f"Converting {label} transformer...")
            _cast_and_save(
                sanitize_wan_transformer_weights(
                    load_safetensors_weights(str(expert_dir))
                ),
                output_dir / f"{expert}.safetensors",
                target_dtype,
            )
    else:
        # Single transformer model: try safetensors in the checkpoint dir itself.
        print("Converting transformer (single model)...")
        weights = load_safetensors_weights(str(checkpoint_dir))
        if not weights:
            # Fallback: first .pth file that is neither the T5 nor the VAE.
            for pth in sorted(checkpoint_dir.glob("*.pth")):
                if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
                    print(f" Loading from {pth.name}...")
                    weights = load_torch_weights(str(pth))
                    break
        if weights:
            _cast_and_save(
                sanitize_wan_transformer_weights(weights),
                output_dir / "model.safetensors",
                target_dtype,
            )
        else:
            print(" Warning: No transformer weights found!")

    # Save config — detect model size from source config.json or transformer weights
    from mlx_video.models.wan.config import WanModelConfig

    def _detect_config():
        """Detect config from source config.json or transformer weight shapes."""
        if is_dual:
            # Check source config.json for model_type (I2V vs T2V)
            src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
            if src_cfg_path.exists():
                with open(src_cfg_path) as f:
                    src_config = json.load(f)
                src_model_type = src_config.get("model_type", "t2v")
                if src_model_type == "i2v" or src_config.get("in_dim") == 36:
                    return WanModelConfig.wan22_i2v_14b()
            return WanModelConfig.wan22_t2v_14b()

        # Try reading source config.json first (most reliable)
        src_cfg_path = checkpoint_dir / "config.json"
        src_config = None
        if src_cfg_path.exists():
            with open(src_cfg_path) as f:
                src_config = json.load(f)

        if src_config and "dim" in src_config:
            src_dim = src_config.get("dim", 5120)
            src_in_dim = src_config.get("in_dim", 16)
            src_out_dim = src_config.get("out_dim", 16)
            src_ffn_dim = src_config.get("ffn_dim", 13824)
            src_num_heads = src_config.get("num_heads", 40)
            src_num_layers = src_config.get("num_layers", 40)
            src_model_type = src_config.get("model_type", "t2v")
            src_text_len = src_config.get("text_len", 512)
            print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
                  f"heads={src_num_heads}, type={src_model_type}")

            # Use preset for known TI2V 5B configuration
            if src_model_type == "ti2v" and src_dim == 3072:
                return WanModelConfig.wan22_ti2v_5b()

            is_22 = model_version == "2.2"
            # Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
            vae_z = 48 if is_22 else 16
            vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
            fps = 24 if is_22 else 16
            return WanModelConfig(
                model_type=src_model_type,
                model_version=model_version,
                dim=src_dim,
                ffn_dim=src_ffn_dim,
                in_dim=src_in_dim,
                out_dim=src_out_dim,
                num_heads=src_num_heads,
                num_layers=src_num_layers,
                text_len=src_text_len,
                vae_z_dim=vae_z,
                vae_stride=vae_s,
                dual_model=False,
                boundary=0.0,
                sample_shift=5.0,
                sample_steps=50,
                sample_guide_scale=5.0,
                sample_fps=fps,
            )

        # Fallback: detect from saved transformer weight shapes
        saved_model = output_dir / "model.safetensors"
        if saved_model.exists():
            det_weights = mx.load(str(saved_model))
            dim = None
            for k, v in det_weights.items():
                if "patch_embedding_proj.weight" in k:
                    dim = v.shape[0]
                    break
            del det_weights
            if dim is not None and dim <= 2048:
                print(f" Auto-detected 1.3B model (dim={dim})")
                return WanModelConfig.wan21_t2v_1_3b()
        return WanModelConfig.wan21_t2v_14b()

    config = _detect_config()
    config_path = output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config.to_dict(), f, indent=2)
    print(f" Saved config to {config_path}")

    # Convert T5 encoder
    t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
    if t5_path.exists():
        print("Converting T5 encoder...")
        _cast_and_save(
            sanitize_wan_t5_weights(load_torch_weights(str(t5_path))),
            output_dir / "t5_encoder.safetensors",
            target_dtype,
        )

    # Convert VAE (check both naming conventions)
    vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
    is_wan22_vae = False
    if not vae_path.exists():
        vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
        is_wan22_vae = True
    if vae_path.exists():
        print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
        weights = load_torch_weights(str(vae_path))
        if is_wan22_vae:
            from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

            include_encoder = config.model_type in ("ti2v", "i2v")
            weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
        else:
            weights = sanitize_wan_vae_weights(weights)
        # Always save VAE in float32 — official Wan2.2 runs VAE decode in
        # float32 (dtype=torch.float). Saving in bfloat16 loses precision
        # that cannot be recovered by upcasting at load time.
        _cast_and_save(
            weights, output_dir / "vae.safetensors", mx.float32, note=" (float32)"
        )

    # Quantize transformer weights if requested
    if quantize:
        print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
        _quantize_saved_model(output_dir, config, is_dual, bits, group_size)

    print(f"\nConversion complete! Output: {output_dir}")
def _quantize_predicate(path: str, module) -> bool:
    """Decide whether the module at ``path`` should be quantized.

    Only modules that expose ``to_quantized`` and whose path ends in one of
    the attention projections (q/k/v/o of self- or cross-attention) or the
    FFN layers (fc1/fc2) qualify. Everything else -- embeddings, norms, head,
    modulation -- is left in full precision because it is small and
    precision-sensitive.

    Args:
        path: Dotted module path inside the model.
        module: The module object found at that path.

    Returns:
        True when the module should be quantized.
    """
    quantizable_suffixes = (
        ".self_attn.q",
        ".self_attn.k",
        ".self_attn.v",
        ".self_attn.o",
        ".cross_attn.q",
        ".cross_attn.k",
        ".cross_attn.v",
        ".cross_attn.o",
        ".ffn.fc1",
        ".ffn.fc2",
    )
    # str.endswith accepts a tuple, so a single call covers every suffix.
    return hasattr(module, "to_quantized") and path.endswith(quantizable_suffixes)
def _quantize_saved_model(
    output_dir: Path,
    config,
    is_dual: bool,
    bits: int,
    group_size: int,
    source_dir: Optional[Path] = None,
):
    """Load saved full-precision transformer weights, quantize, and re-save.

    For every transformer file found in ``source_dir`` a fresh ``WanModel`` is
    built, loaded, quantized in place with ``nn.quantize`` (restricted by
    ``_quantize_predicate``) and written to ``output_dir`` under the same file
    name. Afterwards ``output_dir/config.json`` gains a ``"quantization"``
    entry recording ``group_size`` and ``bits``.

    Args:
        output_dir: Directory to write quantized weights to. Must already
            contain ``config.json``.
        config: WanModelConfig for creating the model.
        is_dual: Whether this is a dual-expert model (low/high noise files)
            rather than a single ``model.safetensors``.
        bits: Quantization bits.
        group_size: Quantization group size.
        source_dir: Directory to read bf16 weights from. Defaults to output_dir.

    Raises:
        RuntimeError: If any ``.bias`` tensor contains NaN/Inf after
            quantization (used as a corruption canary).
    """
    import json

    import mlx.nn as nn

    from mlx_video.models.wan.model import WanModel

    if source_dir is None:
        source_dir = output_dir

    if is_dual:
        candidates = ["low_noise_model.safetensors", "high_noise_model.safetensors"]
    else:
        candidates = ["model.safetensors"]
    model_names = [name for name in candidates if (source_dir / name).exists()]

    for name in model_names:
        print(f" Quantizing {name}...")
        model = WanModel(config)
        weights = mx.load(str(source_dir / name))
        model.load_weights(list(weights.items()), strict=False)
        # Materialize the parameters before dropping the loaded dict so the
        # model no longer depends on it.
        mx.eval(model.parameters())
        del weights
        gc.collect()
        mx.clear_cache()

        # Apply quantization to targeted layers
        nn.quantize(
            model,
            group_size=group_size,
            bits=bits,
            class_predicate=_quantize_predicate,
        )

        # Save quantized weights
        weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))

        # Validate: check for NaN/Inf in bias tensors (corruption canary)
        bad_keys = []
        for k, v in weights_dict.items():
            if k.endswith(".bias"):
                mx.eval(v)
                if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
                    bad_keys.append(k)
        if bad_keys:
            # Report the file being processed; the previous code referenced an
            # undefined name here and raised NameError instead.
            raise RuntimeError(
                f"Quantization produced corrupted weights in {name}: "
                f"{len(bad_keys)} bias tensors contain NaN/Inf "
                f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
            )

        mx.save_safetensors(str(output_dir / name), weights_dict)
        n_quantized = sum(1 for k in weights_dict if ".scales" in k)
        print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")

        # Free model before processing next file
        del model, weights_dict
        gc.collect()
        mx.clear_cache()

    # Update config.json with quantization metadata
    config_path = output_dir / "config.json"
    with open(config_path) as f:
        cfg = json.load(f)
    cfg["quantization"] = {
        "group_size": group_size,
        "bits": bits,
    }
    with open(config_path, "w") as f:
        json.dump(cfg, f, indent=2)
    print(" Updated config.json with quantization metadata")
def quantize_mlx_model(
    mlx_model_dir: str,
    output_dir: str,
    bits: int = 4,
    group_size: int = 64,
):
    """Quantize an already-converted MLX model (skips PyTorch conversion).

    Reads ``config.json`` from the source directory, rebuilds the model
    config from it, copies every non-transformer file to the output directory
    (when it differs from the source) and then quantizes the transformer
    weights into the output directory.

    Args:
        mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
        output_dir: Path to output quantized model directory.
        bits: Quantization bits (4 or 8).
        group_size: Quantization group size (32, 64, or 128).

    Raises:
        FileNotFoundError: If the source directory has no ``config.json``.
        ValueError: If the source config already carries quantization metadata.
    """
    import json
    import shutil

    src = Path(mlx_model_dir)
    dst = Path(output_dir)

    cfg_file = src / "config.json"
    if not cfg_file.exists():
        raise FileNotFoundError(f"No config.json found in {src}")
    with open(cfg_file) as handle:
        cfg = json.load(handle)

    if cfg.get("quantization"):
        raise ValueError(
            f"Model at {src} is already quantized "
            f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source."
        )

    # Dual-expert layout requires both expert files to be present.
    expert_files = ("low_noise_model.safetensors", "high_noise_model.safetensors")
    is_dual = all((src / fname).exists() for fname in expert_files)

    # Rebuild the model config, keeping only fields the dataclass declares.
    from mlx_video.models.wan.config import WanModelConfig

    # JSON stores these tuple-valued fields as lists; convert them back.
    tuple_fields = ("patch_size", "vae_stride", "window_size", "sample_guide_scale")
    model_kwargs = {}
    for field_name, field_value in cfg.items():
        if field_name not in WanModelConfig.__dataclass_fields__:
            continue
        if field_name in tuple_fields and isinstance(field_value, list):
            field_value = tuple(field_value)
        model_kwargs[field_name] = field_value
    config = WanModelConfig(**model_kwargs)

    # Copy everything except the (large) transformer weights to the output dir.
    transformer_files = {
        "low_noise_model.safetensors",
        "high_noise_model.safetensors",
        "model.safetensors",
    }
    if dst.resolve() != src.resolve():
        dst.mkdir(parents=True, exist_ok=True)
        for entry in src.iterdir():
            if not entry.is_file() or entry.name in transformer_files:
                continue
            shutil.copy2(entry, dst / entry.name)
        print(f"Copied non-transformer files from {src} to {dst}")

    print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
    _quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
    print(f"\nQuantization complete! Output: {dst}")
if __name__ == "__main__":
    # Command-line entry point: convert a Wan checkpoint to MLX, or (with
    # --quantize-only) quantize a directory that was already converted.
    import argparse

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        (
            "--checkpoint-dir",
            dict(type=str, required=True, help="Path to Wan checkpoint directory"),
        ),
        (
            "--output-dir",
            dict(type=str, default="wan_mlx_model", help="Output path for MLX model"),
        ),
        (
            "--dtype",
            dict(
                type=str,
                choices=["float16", "float32", "bfloat16"],
                default="bfloat16",
                help="Target dtype",
            ),
        ),
        (
            "--model-version",
            dict(
                type=str,
                choices=["2.1", "2.2", "auto"],
                default="auto",
                help="Wan model version (auto-detect by default)",
            ),
        ),
        (
            "--quantize",
            dict(
                action="store_true",
                help="Quantize transformer weights for faster inference",
            ),
        ),
        (
            "--quantize-only",
            dict(
                action="store_true",
                help="Quantize an already-converted MLX model (skips PyTorch conversion)",
            ),
        ),
        (
            "--bits",
            dict(type=int, choices=[4, 8], default=4, help="Quantization bits (default: 4)"),
        ),
        (
            "--group-size",
            dict(
                type=int,
                choices=[32, 64, 128],
                default=64,
                help="Quantization group size (default: 64)",
            ),
        ),
    )

    cli = argparse.ArgumentParser(description="Convert Wan model to MLX format")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    if not opts.quantize_only:
        convert_wan_checkpoint(
            opts.checkpoint_dir,
            opts.output_dir,
            opts.dtype,
            opts.model_version,
            quantize=opts.quantize,
            bits=opts.bits,
            group_size=opts.group_size,
        )
    else:
        # In this mode --checkpoint-dir is passed as the MLX model directory.
        quantize_mlx_model(
            opts.checkpoint_dir,
            opts.output_dir,
            bits=opts.bits,
            group_size=opts.group_size,
        )

View File

@@ -0,0 +1,828 @@
"""Wan2.2 Text-to-Video generation pipeline for MLX."""
import argparse
import gc
import math
import random
import sys
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import (
_clean_text,
encode_text,
load_t5_encoder,
load_vae_decoder,
load_vae_encoder,
load_wan_model,
)
from mlx_video.models.wan.postprocess import save_video
class Colors:
    """ANSI color codes for terminal output.

    Plain string constants meant to be interpolated into printed messages;
    follow any styled text with ``RESET`` so the style does not leak into
    later output.
    """
    # Foreground colours (escape codes 91-96).
    CYAN = "\033[96m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    MAGENTA = "\033[95m"
    # Text attributes.
    BOLD = "\033[1m"
    DIM = "\033[2m"
    # Code 0: turns all of the above off again.
    RESET = "\033[0m"
# Backward-compat alias: build_i2v_mask is imported from
# mlx_video.models.wan.i2v_utils, but tests and external code may still look
# it up on this module under the old private name.
_build_i2v_mask = build_i2v_mask
def _best_output_size(w, h, dw, dh, max_area):
    """Pick an aligned output resolution of at most ``max_area`` pixels.

    Two candidates are built from the ideal (unaligned) size that has the
    input aspect ratio and exactly ``max_area`` pixels: one snaps the width
    down to a multiple of ``dw`` first and derives the height from it, the
    other snaps the height down to a multiple of ``dh`` first. The candidate
    whose aspect ratio is closer to the input's wins; ties go to the
    height-first candidate.

    Matches the reference implementation's best_output_size().

    Args:
        w: Input width.
        h: Input height.
        dw: Required width alignment (output width is a multiple of this).
        dh: Required height alignment (output height is a multiple of this).
        max_area: Maximum number of output pixels.

    Returns:
        ``(width, height)`` of the chosen candidate.
    """

    def snap_down(length, step):
        # Largest multiple of ``step`` not exceeding ``length``.
        return int(length // step * step)

    ratio = w / h

    def distortion(candidate_ratio):
        # Multiplicative distance from the input ratio; 1.0 means identical.
        return max(ratio / candidate_ratio, candidate_ratio / ratio)

    ideal_w = (max_area * ratio) ** 0.5
    ideal_h = max_area / ideal_w

    # Candidate A: fix the width, then fit the height into the remaining area.
    width_first_w = snap_down(ideal_w, dw)
    width_first_h = snap_down(max_area / width_first_w, dh)

    # Candidate B: fix the height, then fit the width into the remaining area.
    height_first_h = snap_down(ideal_h, dh)
    height_first_w = snap_down(max_area / height_first_h, dw)

    if distortion(width_first_w / width_first_h) < distortion(height_first_w / height_first_h):
        return width_first_w, width_first_h
    return height_first_w, height_first_h
def generate_video(
model_dir: str,
prompt: str,
negative_prompt: str | None = None,
image: str | None = None,
width: int = 1280,
height: int = 704,
num_frames: int = 81,
steps: int = None,
guide_scale: str | float | tuple = None,
shift: float = None,
seed: int = -1,
output_path: str = "output.mp4",
scheduler: str = "unipc",
loras: list | None = None,
loras_high: list | None = None,
loras_low: list | None = None,
tiling: str = "auto",
no_compile: bool = False,
trim_first_frames: int = 0,
debug_latents: bool = False,
):
"""Generate video using Wan pipeline (supports T2V and I2V).
Args:
model_dir: Path to converted MLX model directory
prompt: Text prompt
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
image: Path to input image for I2V (None = T2V mode)
width: Video width
height: Video height
num_frames: Number of frames (must be 4n+1)
steps: Number of diffusion steps (None = use config default)
guide_scale: Guidance scale: float for single, (low,high) for dual (None = config default)
shift: Noise schedule shift (None = use config default)
seed: Random seed (-1 for random)
output_path: Output video path
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
loras: Optional list of (path, strength) tuples applied to all models
loras_high: Optional list of (path, strength) tuples for high-noise model only
loras_low: Optional list of (path, strength) tuples for low-noise model only
tiling: Tiling mode for VAE decoding. Options:
- "auto": Automatically determine tiling based on video size (default)
- "none": Disable tiling
- "default", "aggressive", "conservative": Preset tiling configs
- "spatial": Spatial tiling only
- "temporal": Temporal tiling only
no_compile: If True, skip mx.compile on models (useful for debugging)
trim_first_frames: Number of temporal latent positions to generate extra
and discard from the start. Each position = 4 pixel frames. Use 1
to fix first-frame artifacts on 14B models (generates 4 extra frames,
discards first 4). Use 2 for more aggressive trimming. Default: 0.
debug_latents: If True, print per-temporal-position latent statistics
after denoising for diagnosing first-frame artifacts.
"""
import json
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
model_dir = Path(model_dir)
# Load config from model dir if available, otherwise auto-detect
config_path = model_dir / "config.json"
quantization = None
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Extract quantization config (not a model config field)
quantization = config_dict.pop("quantization", None)
# Handle tuple fields stored as lists in JSON
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**{
k: v for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
})
else:
# Auto-detect: dual model files → 2.2, single model → 2.1
if (model_dir / "low_noise_model.safetensors").exists():
config = WanModelConfig.wan22_t2v_14b()
else:
# Detect 1.3B vs 14B from weight shapes
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
if dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
else:
config = WanModelConfig.wan21_t2v_14b()
del probe
else:
config = WanModelConfig.wan21_t2v_14b()
is_dual = config.dual_model
is_i2v = image is not None
# Validate config against actual weights (handles mismatched config.json)
if not is_dual:
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
actual_dim = v.shape[0]
if actual_dim != config.dim:
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
if actual_dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
del probe
# Auto-correct Wan2.2 VAE params from stale configs
if config.in_dim == 48 and config.vae_z_dim != 48:
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
config = WanModelConfig(**{
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
})
# Apply defaults from config if not overridden
if steps is None:
steps = config.sample_steps
if shift is None:
shift = config.sample_shift
if guide_scale is None:
guide_scale = config.sample_guide_scale
# Normalize guide_scale
if isinstance(guide_scale, (int, float)):
guide_scale = float(guide_scale)
elif isinstance(guide_scale, str):
parts = [float(x) for x in guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
if isinstance(guide_scale, tuple):
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
else:
cfg_disabled = guide_scale <= 1.0
# Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
gen_frames = num_frames
if trim_first_frames > 0:
gen_frames = num_frames + trim_first_frames * 4
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
# Resolve negative prompt: explicit user value > config default
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
# that prevents oversaturation, artifacts, and comic look. We use it by default.
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
if negative_prompt is None:
neg_prompt_resolved = config.sample_neg_prompt
else:
neg_prompt_resolved = negative_prompt
print(f"{Colors.CYAN}{'='*60}")
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
print(f"{'='*60}{Colors.RESET}")
print(f"{Colors.DIM} Prompt: {prompt}")
if is_i2v:
print(f" Image: {image}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}")
# Seed
if seed < 0:
seed = random.randint(0, 2**32 - 1)
mx.random.seed(seed)
np.random.seed(seed)
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
# Align dimensions to patch_size * vae_stride (required for patchify)
vae_stride = config.vae_stride
patch_size = config.patch_size
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
align_w = patch_size[2] * vae_stride[2]
if height % align_h != 0 or width % align_w != 0:
old_h, old_w = height, width
height = (height // align_h) * align_h
width = (width // align_w) * align_w
if height == 0:
height = align_h
if width == 0:
width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
# Enforce max_area constraint (model-specific resolution limit)
if config.max_area > 0 and height * width > config.max_area:
old_h, old_w = height, width
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
print(
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
)
# Compute target latent shape
z_dim = config.vae_z_dim
t_latent = (gen_frames - 1) // vae_stride[0] + 1
h_latent = height // vae_stride[1]
w_latent = width // vae_stride[2]
target_shape = (z_dim, t_latent, h_latent, w_latent)
# Sequence length for transformer
seq_len = math.ceil(
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
)
print(f"{Colors.DIM} Latent shape: {target_shape}")
print(f" Sequence length: {seq_len}{Colors.RESET}")
# Load T5 encoder
t1 = time.time()
print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
t5_path = model_dir / "t5_encoder.safetensors"
t5_encoder = load_t5_encoder(t5_path, config)
# Load tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
# Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
if cfg_disabled:
context_null = None
mx.eval(context)
else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
mx.eval(context, context_null)
# Free T5 from memory
del t5_encoder
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# I2V: encode image to latent space
z_img = None
i2v_mask = None
i2v_mask_tokens = None
y_i2v = None
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
if is_i2v:
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
t_img = time.time()
vae_path = model_dir / "vae.safetensors"
if is_i2v_channel_concat:
# I2V-14B: encode full video (first frame = image, rest = zeros)
# and construct y tensor with mask + encoded latents
from PIL import Image
img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
], axis=1)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config)
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
mx.eval(z_video)
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
], axis=1)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
y_i2v = mx.concatenate([msk, z_video], axis=0)
mx.eval(y_i2v)
del vae_enc, img_arr, img_chw, video, z_video, msk
else:
# TI2V-5B: encode single image, blend with noise via mask
img_tensor = preprocess_image(image, width, height)
mx.eval(img_tensor)
vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img)
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
del vae_enc, img_tensor
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
# Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization:
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
t2 = time.time()
# Merge per-model LoRAs with shared LoRAs
_loras_low = (loras or []) + (loras_low or []) or None
_loras_high = (loras or []) + (loras_high or []) or None
_loras_single = loras
if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
# Each model has its own text_embedding weights, so dual models need separate embeddings
if cfg_disabled:
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
if is_dual:
context_emb_low = low_noise_model.embed_text([context])
context_emb_high = high_noise_model.embed_text([context])
mx.eval(context_emb_low, context_emb_high)
context_cond_low = context_emb_low[0:1]
context_cond_high = context_emb_high[0:1]
else:
context_emb = single_model.embed_text([context])
mx.eval(context_emb)
context_cond = context_emb[0:1]
else:
if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
# Precompute cross-attention K/V caches (constant across all steps)
if cfg_disabled:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cond)
mx.eval(cross_kv)
else:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv)
# Precompute RoPE frequencies (grid sizes are constant across all steps)
f_grid = t_latent // patch_size[0]
h_grid = h_latent // patch_size[1]
w_grid = w_latent // patch_size[2]
if cfg_disabled:
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
else:
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if is_dual:
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else:
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin)
# Setup scheduler
_schedulers = {
"euler": FlowMatchEulerScheduler,
"dpm++": FlowDPMPP2MScheduler,
"unipc": FlowUniPCScheduler,
}
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
sched.set_timesteps(steps, shift=shift)
# Generate initial noise
noise = mx.random.normal(target_shape)
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
else:
latents = noise
# Boundary for model switching (dual model only)
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
# Diffusion loop
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time()
# Compile model forward for faster denoising
if not no_compile:
models_to_compile = (
[high_noise_model, low_noise_model] if is_dual else [single_model]
)
for m in models_to_compile:
m._compiled = mx.compile(m)
# Pre-convert timesteps to Python list to avoid .item() sync each step
timestep_list = sched.timesteps.tolist()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = timestep_list[i]
# Select model, cached K/V, and precomputed RoPE
if is_dual:
if timestep_val >= boundary:
model = high_noise_model
kv = cross_kv_high
rcs = rope_cos_sin_high
else:
model = low_noise_model
kv = cross_kv_low
rcs = rope_cos_sin_low
else:
model = single_model
kv = cross_kv
rcs = rope_cos_sin
# Use compiled forward when available (faster after first trace)
_call = getattr(model, '_compiled', model)
if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = t_tokens # [1, L]
else:
t_batch = mx.array([timestep_val])
y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
else:
ctx = context_cond
preds = _call(
[latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred = preds[0]
del preds
else:
# CFG: batch cond + uncond into single B=2 forward pass
if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
else:
t_batch = mx.array([timestep_val, timestep_val])
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = _call(
[latents, latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond, preds
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# TI2V-5B: re-apply mask to keep first frame frozen
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution
del noise_pred
mx.eval(latents)
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
# Diagnostic: per-temporal-position latent statistics
if debug_latents:
lat_np = np.array(latents) # [C, T, H, W]
n_t = lat_np.shape[1]
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
for t_pos in range(min(n_t, 8)):
frame = lat_np[:, t_pos, :, :]
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
if n_t > 8:
interior = lat_np[:, 4:, :, :]
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
print()
# Free transformer models and text embeddings
if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
if cfg_disabled:
del context_cond_low, context_cond_high
else:
del context_cfg_low, context_cfg_high
else:
del single_model, cross_kv
if cfg_disabled:
del context_cond
else:
del context_cfg
del model, kv, context
if context_null is not None:
del context_null
gc.collect(); mx.clear_cache()
# Load VAE and decode
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
t4 = time.time()
vae_path = model_dir / "vae.safetensors"
vae = load_vae_decoder(vae_path, config)
is_wan22_vae = config.vae_z_dim == 48
# Temporal extend: prepend reflected latent frames to the VAE input so that
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
# This gives the first real frame a full temporal receptive field of real data.
# Select tiling configuration
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None]
z = denormalize_latents(z)
if tiling_config is not None:
video = vae.decode_tiled(z, tiling_config)
else:
video = vae(z)
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else:
if tiling_config is not None:
video = vae.decode_tiled(latents[None], tiling_config)
else:
video = vae.decode(latents[None])
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [3, T', H, W]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
# Trim first N temporal chunks if requested (avoids first-frame artifacts)
if trim_first_frames > 0:
trim_pixels = trim_first_frames * 4
video = video[trim_pixels:]
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
save_video(video, output_path, fps=config.sample_fps)
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
def main(argv=None):
    """Command-line entry point: parse arguments and call ``generate_video``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the process
            command line, exactly as a bare ``main()`` call always did.

    Malformed ``--guide-scale`` values and non-numeric LoRA strengths are
    reported through ``parser.error`` (usage message on stderr, exit status 2)
    rather than escaping as an uncaught ``ValueError`` traceback.
    """
    parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
    parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
    parser.add_argument("--image", type=str, default=None,
                        help="Path to input image for I2V (omit for T2V mode)")
    parser.add_argument("--negative-prompt", type=str, default=None,
                        help="Negative prompt for CFG (default: official Chinese prompt from config)")
    parser.add_argument("--no-negative-prompt", action="store_true",
                        help="Disable negative prompt (use empty string instead of config default)")
    parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
    parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
    parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
    parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
    parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
    parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
    parser.add_argument("--seed", type=int, default=-1, help="Random seed")
    parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
    parser.add_argument(
        "--scheduler", type=str, default="unipc",
        choices=["euler", "dpm++", "unipc"],
        help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
    )
    parser.add_argument(
        "--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
        help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
    )
    parser.add_argument(
        "--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
        help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
    )
    parser.add_argument(
        "--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
        help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
    )
    parser.add_argument(
        "--tiling",
        type=str,
        default="auto",
        choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
        help="VAE tiling mode to reduce memory during decoding (default: auto)",
    )
    parser.add_argument(
        "--no-compile", action="store_true",
        help="Disable mx.compile on models (for debugging)",
    )
    parser.add_argument(
        "--trim-first-frames", type=int, default=0, metavar="N",
        help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
             "Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
             "Default: 0 (disabled)",
    )
    parser.add_argument(
        "--debug-latents", action="store_true",
        help="Print per-temporal-position latent statistics after denoising (diagnostic)",
    )
    args = parser.parse_args(argv)

    # Parse guide scale: "3.5" -> 3.5, "3.0,4.0" -> (3.0, 4.0); None when omitted.
    guide_scale = None
    if args.guide_scale is not None:
        try:
            parts = [float(x) for x in args.guide_scale.split(",")]
        except ValueError:
            # parser.error() prints usage and raises SystemExit(2), so `parts`
            # is always bound when execution continues past this block.
            parser.error(
                "argument --guide-scale: expected a float or a comma-separated "
                f"low,high pair, got {args.guide_scale!r}"
            )
        guide_scale = tuple(parts) if len(parts) > 1 else parts[0]

    # Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
    neg_prompt = args.negative_prompt
    if args.no_negative_prompt:
        neg_prompt = ""

    # Parse LoRA configs: convert [path, strength_str] → (path, float)
    def _parse_lora_args(lora_list, flag):
        # An omitted flag (None) stays None rather than becoming an empty list.
        if not lora_list:
            return None
        parsed = []
        for path, strength in lora_list:
            try:
                parsed.append((path, float(strength)))
            except ValueError:
                parser.error(f"argument {flag}: STRENGTH must be a number, got {strength!r}")
        return parsed

    # Validate every LoRA option before any generation work starts.
    loras = _parse_lora_args(args.lora, "--lora")
    loras_high = _parse_lora_args(args.lora_high, "--lora-high")
    loras_low = _parse_lora_args(args.lora_low, "--lora-low")

    generate_video(
        model_dir=args.model_dir,
        prompt=args.prompt,
        negative_prompt=neg_prompt,
        image=args.image,
        width=args.width,
        height=args.height,
        num_frames=args.num_frames,
        steps=args.steps,
        guide_scale=guide_scale,
        shift=args.shift,
        seed=args.seed,
        output_path=args.output_path,
        scheduler=args.scheduler,
        loras=loras,
        loras_high=loras_high,
        loras_low=loras_low,
        tiling=args.tiling,
        no_compile=args.no_compile,
        trim_first_frames=args.trim_first_frames,
        debug_latents=args.debug_latents,
    )
# Run the command-line interface only when this file is executed as a script;
# importing the module does not call main().
if __name__ == "__main__":
    main()