Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -0,0 +1,808 @@
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
import gc
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.utils
logger = logging.getLogger(__name__)
def load_torch_weights(path: str) -> Dict[str, mx.array]:
"""Load PyTorch .pth weights and convert to MLX arrays.
Args:
path: Path to .pth file
Returns:
Dictionary of MLX arrays
"""
try:
import torch
except ImportError:
raise ImportError("PyTorch is required to load .pth weights: pip install torch")
logging.info(f"Loading weights from {path}")
state_dict = torch.load(path, map_location="cpu", weights_only=True)
weights = {}
for key, value in state_dict.items():
if isinstance(value, torch.Tensor):
np_val = value.detach().float().numpy()
weights[key] = mx.array(np_val)
return weights
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
"""Load safetensors weights as MLX arrays.
Args:
path: Path to directory with safetensors files or single file
Returns:
Dictionary of MLX arrays
"""
path = Path(path)
weights = {}
if path.is_file():
weights = mx.load(str(path))
elif path.is_dir():
for sf in sorted(path.glob("*.safetensors")):
weights.update(mx.load(str(sf)))
return weights
def sanitize_wan_transformer_weights(
weights: Dict[str, mx.array]
) -> Dict[str, mx.array]:
"""Convert Wan2.2 transformer weight keys to MLX model structure.
Wan2.2 keys follow the pattern:
patch_embedding.weight/bias
text_embedding.{0,2}.weight/bias
time_embedding.{0,2}.weight/bias
time_projection.1.weight/bias
blocks.{i}.norm1.weight
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
blocks.{i}.self_attn.norm_q.weight
blocks.{i}.self_attn.norm_k.weight
blocks.{i}.norm3.weight/bias (if cross_attn_norm)
blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
blocks.{i}.cross_attn.norm_q.weight
blocks.{i}.cross_attn.norm_k.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.{0,2}.weight/bias
blocks.{i}.modulation
head.norm.weight
head.head.weight/bias
head.modulation
freqs (buffer)
MLX model uses:
patch_embedding_proj.weight/bias (after patchify reshape)
text_embedding_0.weight/bias, text_embedding_1.weight/bias
time_embedding_0.weight/bias, time_embedding_1.weight/bias
time_projection.weight/bias
blocks.{i}.norm1.weight
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
etc.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
# Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
# MLX Linear expects [O, I*D*H*W] after we flatten in patchify
if key == "patch_embedding.weight":
# Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
value = value.reshape(value.shape[0], -1)
new_key = "patch_embedding_proj.weight"
sanitized[new_key] = value
consumed.add(key)
continue
if key == "patch_embedding.bias":
new_key = "patch_embedding_proj.bias"
sanitized[new_key] = value
consumed.add(key)
continue
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
if key.startswith("text_embedding.0."):
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
sanitized[new_key] = value
consumed.add(key)
continue
if key.startswith("text_embedding.2."):
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
sanitized[new_key] = value
consumed.add(key)
continue
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
if key.startswith("time_embedding.0."):
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
sanitized[new_key] = value
consumed.add(key)
continue
if key.startswith("time_embedding.2."):
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
sanitized[new_key] = value
consumed.add(key)
continue
# Time projection Sequential: 0=SiLU(no params), 1=Linear
if key.startswith("time_projection.1."):
new_key = key.replace("time_projection.1.", "time_projection.")
sanitized[new_key] = value
consumed.add(key)
continue
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")
# Skip the freqs buffer (we compute it in the model)
if key == "freqs":
consumed.add(key)
continue
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
return sanitized
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.
Wan2.2 T5 keys:
token_embedding.weight
pos_embedding.embedding.weight (if shared_pos)
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate.0.weight (gate linear)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
norm.weight
MLX T5Encoder structure:
token_embedding.weight
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight
norm.weight
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
# Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
return sanitized
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 VAE weight keys to MLX WanVAE structure.
Handles Conv3d and Conv2d weight transpositions for MLX format.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
# Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
if "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
if "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Map decoder keys to MLX decoder structure
# Wan2.2 uses encoder/decoder with downsamples/upsamples
# Need to adapt naming for our simplified structure
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
return sanitized
def _load_lora_configs(
lora_configs: List[Tuple[str, float]],
) -> Dict[str, list]:
"""Load LoRA weights from config tuples, returning module_to_loras dict.
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
configs = []
for lora_path, strength in lora_configs:
try:
config = LoRAConfig(path=lora_path, strength=strength)
configs.append(config)
print(f" - {Path(lora_path).name} (strength: {strength})")
except Exception as e:
print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
raise
module_to_loras = load_multiple_loras(configs)
if not module_to_loras:
print(
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
)
return module_to_loras
def load_and_apply_loras(
model_weights: Dict[str, mx.array],
lora_configs: Optional[List[Tuple[str, float]]] = None,
verbose: bool = False,
quantization_bits: int = 0,
) -> Dict[str, mx.array]:
"""Load and apply LoRA weights to model weights by merging into weight dict.
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import apply_loras_to_weights
if not lora_configs:
return model_weights
module_to_loras = _load_lora_configs(lora_configs)
if not module_to_loras:
return model_weights
print(
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
)
if verbose:
print(f" Model has {len(model_weights)} weight keys")
modified_weights = apply_loras_to_weights(
model_weights,
module_to_loras,
verbose=verbose,
quantization_bits=quantization_bits,
)
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
return modified_weights
def convert_wan_checkpoint(
checkpoint_dir: str,
output_dir: str,
dtype: str = "bfloat16",
model_version: str = "auto",
quantize: bool = False,
bits: int = 4,
group_size: int = 64,
):
"""Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.
Wan2.2 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
low_noise_model/ (safetensors)
high_noise_model/ (safetensors)
Wan2.1 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
diffusion_pytorch_model*.safetensors (single model)
Args:
checkpoint_dir: Path to Wan checkpoint directory
output_dir: Path to output MLX model directory
dtype: Target dtype
model_version: "2.1", "2.2", or "auto" (detect from directory)
quantize: Whether to quantize the transformer weights
bits: Quantization bits (4 or 8)
group_size: Quantization group size (32, 64, or 128)
"""
import json
checkpoint_dir = Path(checkpoint_dir)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
dtype_map = {
"float16": mx.float16,
"float32": mx.float32,
"bfloat16": mx.bfloat16,
}
target_dtype = dtype_map.get(dtype, mx.bfloat16)
# Auto-detect version
if model_version == "auto":
if (checkpoint_dir / "low_noise_model").exists():
model_version = "2.2"
elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
model_version = "2.2"
else:
model_version = "2.1"
print(f"Auto-detected Wan{model_version} checkpoint")
is_dual = (checkpoint_dir / "low_noise_model").exists()
if is_dual:
# Wan2.2: Convert dual transformer models
low_noise_path = checkpoint_dir / "low_noise_model"
if low_noise_path.exists():
print("Converting low-noise transformer...")
weights = load_safetensors_weights(str(low_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "low_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
high_noise_path = checkpoint_dir / "high_noise_model"
if high_noise_path.exists():
print("Converting high-noise transformer...")
weights = load_safetensors_weights(str(high_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "high_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
# Wan2.1: Convert single transformer model
# Try safetensors in the checkpoint dir itself
print("Converting transformer (single model)...")
weights = load_safetensors_weights(str(checkpoint_dir))
if not weights:
# Fallback: look for .pth files
for pth in sorted(checkpoint_dir.glob("*.pth")):
if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
print(f" Loading from {pth.name}...")
weights = load_torch_weights(str(pth))
break
if weights:
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan_2.config import WanModelConfig
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
if is_dual:
# Check source config.json for model_type (I2V vs T2V)
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
src_model_type = src_config.get("model_type", "t2v")
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
return WanModelConfig.wan22_i2v_14b()
return WanModelConfig.wan22_t2v_14b()
# Try reading source config.json first (most reliable)
src_cfg_path = checkpoint_dir / "config.json"
src_config = None
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
if src_config and "dim" in src_config:
src_dim = src_config.get("dim", 5120)
src_in_dim = src_config.get("in_dim", 16)
src_out_dim = src_config.get("out_dim", 16)
src_ffn_dim = src_config.get("ffn_dim", 13824)
src_num_heads = src_config.get("num_heads", 40)
src_num_layers = src_config.get("num_layers", 40)
src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512)
print(
f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}"
)
# Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072:
return WanModelConfig.wan22_ti2v_5b()
is_22 = model_version == "2.2"
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
vae_z = 48 if is_22 else 16
vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
fps = 24 if is_22 else 16
return WanModelConfig(
model_type=src_model_type,
model_version=model_version,
dim=src_dim,
ffn_dim=src_ffn_dim,
in_dim=src_in_dim,
out_dim=src_out_dim,
num_heads=src_num_heads,
num_layers=src_num_layers,
text_len=src_text_len,
vae_z_dim=vae_z,
vae_stride=vae_s,
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
sample_fps=fps,
)
# Fallback: detect from saved transformer weight shapes
saved_model = output_dir / "model.safetensors"
if saved_model.exists():
det_weights = mx.load(str(saved_model))
dim = None
for k, v in det_weights.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
break
del det_weights
if dim is not None and dim <= 2048:
print(f" Auto-detected 1.3B model (dim={dim})")
return WanModelConfig.wan21_t2v_1_3b()
return WanModelConfig.wan21_t2v_14b()
config = _detect_config()
config_path = output_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config.to_dict(), f, indent=2)
print(f" Saved config to {config_path}")
# Convert T5 encoder
t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
if t5_path.exists():
print("Converting T5 encoder...")
weights = load_torch_weights(str(t5_path))
weights = sanitize_wan_t5_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "t5_encoder.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
# Convert VAE (check both naming conventions)
vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
is_wan22_vae = False
if not vae_path.exists():
vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
is_wan22_vae = True
if vae_path.exists():
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(
weights, include_encoder=include_encoder
)
else:
weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
# that cannot be recovered by upcasting at load time.
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
out_path = output_dir / "vae.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
# Quantize transformer weights if requested
if quantize:
print(
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
)
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}")
def _quantize_predicate(path: str, module) -> bool:
"""Return True for layers that should be quantized.
Targets heavyweight Linear layers in attention and FFN blocks.
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
"""
if not hasattr(module, "to_quantized"):
return False
# Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = (
".self_attn.q",
".self_attn.k",
".self_attn.v",
".self_attn.o",
".cross_attn.q",
".cross_attn.k",
".cross_attn.v",
".cross_attn.o",
".ffn.fc1",
".ffn.fc2",
)
return any(path.endswith(p) for p in quantize_patterns)
def _quantize_saved_model(
output_dir: Path,
config,
is_dual: bool,
bits: int,
group_size: int,
source_dir: Path = None,
):
"""Load saved bf16 model, quantize, and re-save.
Args:
output_dir: Directory to write quantized weights to.
config: WanModelConfig for creating the model.
is_dual: Whether this is a dual-expert model.
bits: Quantization bits.
group_size: Quantization group size.
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
"""
import json
import mlx.nn as nn
from mlx_video.models.wan_2.wan_2 import WanModel
if source_dir is None:
source_dir = output_dir
model_names = []
if is_dual:
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
if (source_dir / name).exists():
model_names.append(name)
else:
if (source_dir / "model.safetensors").exists():
model_names.append("model.safetensors")
for name in model_names:
print(f" Quantizing {name}...")
model = WanModel(config)
weights = mx.load(str(source_dir / name))
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
del weights
gc.collect()
mx.clear_cache()
# Apply quantization to targeted layers
nn.quantize(
model,
group_size=group_size,
bits=bits,
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
# Save quantized weights
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
# Validate: check for NaN/Inf in bias tensors (corruption canary)
bad_keys = []
for k, v in weights_dict.items():
if k.endswith(".bias") and not k.endswith(".biases"):
mx.eval(v)
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
bad_keys.append(k)
if bad_keys:
raise RuntimeError(
f"Quantization produced corrupted weights in {model_path.name}: "
f"{len(bad_keys)} bias tensors contain NaN/Inf "
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
)
mx.save_safetensors(str(output_dir / name), weights_dict)
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
# Free model before processing next file
del model, weights_dict
gc.collect()
mx.clear_cache()
# Update config.json with quantization metadata
config_path = output_dir / "config.json"
with open(config_path) as f:
cfg = json.load(f)
cfg["quantization"] = {
"group_size": group_size,
"bits": bits,
}
with open(config_path, "w") as f:
json.dump(cfg, f, indent=2)
print(f" Updated config.json with quantization metadata")
def quantize_mlx_model(
mlx_model_dir: str,
output_dir: str,
bits: int = 4,
group_size: int = 64,
):
"""Quantize an already-converted MLX model (skips PyTorch conversion).
Args:
mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
output_dir: Path to output quantized model directory.
bits: Quantization bits (4 or 8).
group_size: Quantization group size (32, 64, or 128).
"""
import json
import shutil
src = Path(mlx_model_dir)
dst = Path(output_dir)
config_path = src / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"No config.json found in {src}")
with open(config_path) as f:
cfg = json.load(f)
if cfg.get("quantization"):
raise ValueError(
f"Model at {src} is already quantized "
f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source."
)
# Detect dual vs single expert
is_dual = (src / "low_noise_model.safetensors").exists() and (
src / "high_noise_model.safetensors"
).exists()
# Build model config
from mlx_video.models.wan_2.config import WanModelConfig
config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
}
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**config_dict)
# Copy non-transformer files to output dir (skip large model weights)
transformer_files = {
"low_noise_model.safetensors",
"high_noise_model.safetensors",
"model.safetensors",
}
if dst.resolve() != src.resolve():
dst.mkdir(parents=True, exist_ok=True)
for f in src.iterdir():
if f.is_file() and f.name not in transformer_files:
shutil.copy2(f, dst / f.name)
print(f"Copied non-transformer files from {src} to {dst}")
print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
_quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
print(f"\nQuantization complete! Output: {dst}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
parser.add_argument(
"--checkpoint-dir",
type=str,
required=True,
help="Path to Wan checkpoint directory",
)
parser.add_argument(
"--output-dir",
type=str,
default="wan_mlx_model",
help="Output path for MLX model",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float16", "float32", "bfloat16"],
default="bfloat16",
help="Target dtype",
)
parser.add_argument(
"--model-version",
type=str,
choices=["2.1", "2.2", "auto"],
default="auto",
help="Wan model version (auto-detect by default)",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize transformer weights for faster inference",
)
parser.add_argument(
"--quantize-only",
action="store_true",
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
)
parser.add_argument(
"--bits",
type=int,
choices=[4, 8],
default=4,
help="Quantization bits (default: 4)",
)
parser.add_argument(
"--group-size",
type=int,
choices=[32, 64, 128],
default=64,
help="Quantization group size (default: 64)",
)
args = parser.parse_args()
if args.quantize_only:
quantize_mlx_model(
args.checkpoint_dir,
args.output_dir,
bits=args.bits,
group_size=args.group_size,
)
else:
convert_wan_checkpoint(
args.checkpoint_dir,
args.output_dir,
args.dtype,
args.model_version,
quantize=args.quantize,
bits=args.bits,
group_size=args.group_size,
)