This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.utils
import numpy as np
logger = logging.getLogger(__name__)
@@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
return weights
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
def sanitize_wan_transformer_weights(
weights: Dict[str, mx.array]
) -> Dict[str, mx.array]:
"""Convert Wan2.2 transformer weight keys to MLX model structure.
Wan2.2 keys follow the pattern:
@@ -246,8 +247,8 @@ def _load_lora_configs(
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.lora import LoRAConfig, load_multiple_loras
from mlx_video.generate_wan import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
@@ -264,7 +265,9 @@ def _load_lora_configs(
module_to_loras = load_multiple_loras(configs)
if not module_to_loras:
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
print(
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
)
return module_to_loras
@@ -279,8 +282,8 @@ def load_and_apply_loras(
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.lora import apply_loras_to_weights
from mlx_video.generate_wan import Colors
from mlx_video.lora import apply_loras_to_weights
if not lora_configs:
return model_weights
@@ -289,12 +292,17 @@ def load_and_apply_loras(
if not module_to_loras:
return model_weights
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
print(
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
)
if verbose:
print(f" Model has {len(model_weights)} weight keys")
modified_weights = apply_loras_to_weights(
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
model_weights,
module_to_loras,
verbose=verbose,
quantization_bits=quantization_bits,
)
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
@@ -435,8 +443,10 @@ def convert_wan_checkpoint(
src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512)
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
print(
f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}"
)
# Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072:
@@ -513,8 +523,11 @@ def convert_wan_checkpoint(
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
weights = sanitize_wan22_vae_weights(
weights, include_encoder=include_encoder
)
else:
weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
@@ -527,7 +540,9 @@ def convert_wan_checkpoint(
# Quantize transformer weights if requested
if quantize:
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
print(
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
)
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}")
@@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool:
return False
# Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = (
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
".ffn.fc1", ".ffn.fc2",
".self_attn.q",
".self_attn.k",
".self_attn.v",
".self_attn.o",
".cross_attn.q",
".cross_attn.k",
".cross_attn.v",
".cross_attn.o",
".ffn.fc1",
".ffn.fc2",
)
return any(path.endswith(p) for p in quantize_patterns)
@@ -684,14 +706,20 @@ def quantize_mlx_model(
# Build model config
from mlx_video.models.wan.config import WanModelConfig
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
}
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**config_dict)
# Copy non-transformer files to output dir (skip large model weights)
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
transformer_files = {
"low_noise_model.safetensors",
"high_noise_model.safetensors",
"model.safetensors",
}
if dst.resolve() != src.resolve():
dst.mkdir(parents=True, exist_ok=True)
for f in src.iterdir():
@@ -763,11 +791,18 @@ if __name__ == "__main__":
if args.quantize_only:
quantize_mlx_model(
args.checkpoint_dir, args.output_dir,
bits=args.bits, group_size=args.group_size,
args.checkpoint_dir,
args.output_dir,
bits=args.bits,
group_size=args.group_size,
)
else:
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
args.checkpoint_dir,
args.output_dir,
args.dtype,
args.model_version,
quantize=args.quantize,
bits=args.bits,
group_size=args.group_size,
)