format
This commit is contained in:
@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.utils
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
|
||||
return weights
|
||||
|
||||
|
||||
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
def sanitize_wan_transformer_weights(
|
||||
weights: Dict[str, mx.array]
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 transformer weight keys to MLX model structure.
|
||||
|
||||
Wan2.2 keys follow the pattern:
|
||||
@@ -246,8 +247,8 @@ def _load_lora_configs(
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
|
||||
@@ -264,7 +265,9 @@ def _load_lora_configs(
|
||||
module_to_loras = load_multiple_loras(configs)
|
||||
|
||||
if not module_to_loras:
|
||||
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
|
||||
)
|
||||
|
||||
return module_to_loras
|
||||
|
||||
@@ -279,8 +282,8 @@ def load_and_apply_loras(
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
|
||||
if not lora_configs:
|
||||
return model_weights
|
||||
@@ -289,12 +292,17 @@ def load_and_apply_loras(
|
||||
if not module_to_loras:
|
||||
return model_weights
|
||||
|
||||
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
|
||||
)
|
||||
if verbose:
|
||||
print(f" Model has {len(model_weights)} weight keys")
|
||||
|
||||
modified_weights = apply_loras_to_weights(
|
||||
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
|
||||
model_weights,
|
||||
module_to_loras,
|
||||
verbose=verbose,
|
||||
quantization_bits=quantization_bits,
|
||||
)
|
||||
|
||||
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
|
||||
@@ -435,8 +443,10 @@ def convert_wan_checkpoint(
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
src_text_len = src_config.get("text_len", 512)
|
||||
|
||||
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}")
|
||||
print(
|
||||
f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}"
|
||||
)
|
||||
|
||||
# Use preset for known TI2V 5B configuration
|
||||
if src_model_type == "ti2v" and src_dim == 3072:
|
||||
@@ -513,8 +523,11 @@ def convert_wan_checkpoint(
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
|
||||
weights = sanitize_wan22_vae_weights(
|
||||
weights, include_encoder=include_encoder
|
||||
)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
@@ -527,7 +540,9 @@ def convert_wan_checkpoint(
|
||||
|
||||
# Quantize transformer weights if requested
|
||||
if quantize:
|
||||
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
|
||||
print(
|
||||
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
|
||||
)
|
||||
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
|
||||
|
||||
print(f"\nConversion complete! Output: {output_dir}")
|
||||
@@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool:
|
||||
return False
|
||||
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
||||
quantize_patterns = (
|
||||
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
|
||||
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
|
||||
".ffn.fc1", ".ffn.fc2",
|
||||
".self_attn.q",
|
||||
".self_attn.k",
|
||||
".self_attn.v",
|
||||
".self_attn.o",
|
||||
".cross_attn.q",
|
||||
".cross_attn.k",
|
||||
".cross_attn.v",
|
||||
".cross_attn.o",
|
||||
".ffn.fc1",
|
||||
".ffn.fc2",
|
||||
)
|
||||
return any(path.endswith(p) for p in quantize_patterns)
|
||||
|
||||
@@ -684,14 +706,20 @@ def quantize_mlx_model(
|
||||
# Build model config
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
|
||||
config_dict = {
|
||||
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**config_dict)
|
||||
|
||||
# Copy non-transformer files to output dir (skip large model weights)
|
||||
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
|
||||
transformer_files = {
|
||||
"low_noise_model.safetensors",
|
||||
"high_noise_model.safetensors",
|
||||
"model.safetensors",
|
||||
}
|
||||
if dst.resolve() != src.resolve():
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for f in src.iterdir():
|
||||
@@ -763,11 +791,18 @@ if __name__ == "__main__":
|
||||
|
||||
if args.quantize_only:
|
||||
quantize_mlx_model(
|
||||
args.checkpoint_dir, args.output_dir,
|
||||
bits=args.bits, group_size=args.group_size,
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
else:
|
||||
convert_wan_checkpoint(
|
||||
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
||||
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
args.dtype,
|
||||
args.model_version,
|
||||
quantize=args.quantize,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user