feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.utils
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_torch_weights(path: str) -> Dict[str, mx.array]:
|
||||
"""Load PyTorch .pth weights and convert to MLX arrays.
|
||||
@@ -88,6 +91,7 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
|
||||
etc.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
@@ -99,36 +103,43 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
|
||||
value = value.reshape(value.shape[0], -1)
|
||||
new_key = "patch_embedding_proj.weight"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key == "patch_embedding.bias":
|
||||
new_key = "patch_embedding_proj.bias"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
|
||||
if key.startswith("text_embedding.0."):
|
||||
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("text_embedding.2."):
|
||||
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
|
||||
if key.startswith("time_embedding.0."):
|
||||
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("time_embedding.2."):
|
||||
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time projection Sequential: 0=SiLU(no params), 1=Linear
|
||||
if key.startswith("time_projection.1."):
|
||||
new_key = key.replace("time_projection.1.", "time_projection.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
|
||||
@@ -137,9 +148,15 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
|
||||
|
||||
# Skip the freqs buffer (we compute it in the model)
|
||||
if key == "freqs":
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
@@ -171,6 +188,7 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
|
||||
norm.weight
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
@@ -179,6 +197,11 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
|
||||
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
@@ -189,6 +212,7 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
||||
Handles Conv3d and Conv2d weight transpositions for MLX format.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
@@ -206,10 +230,78 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
||||
# Need to adapt naming for our simplified structure
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _load_lora_configs(
    lora_configs: List[Tuple[str, float]],
) -> Dict[str, list]:
    """Turn ``(path, strength)`` pairs into LoRA configs and load their weights.

    This helper is shared by the weight-merging path and the
    runtime-wrapping path so both report progress and errors the same way.

    Args:
        lora_configs: ``(lora_path, strength)`` tuples, one per LoRA.

    Returns:
        The mapping produced by ``load_multiple_loras`` for the built
        configs. It may be empty; a warning is printed in that case.

    Raises:
        Exception: Whatever building a ``LoRAConfig`` (or printing its
            summary line) raises is reported on stdout and re-raised as-is.
    """
    from mlx_video.lora import LoRAConfig, load_multiple_loras
    from mlx_video.utils import Colors

    total = len(lora_configs)
    print(f"\n{Colors.CYAN}Loading {total} LoRA(s)...{Colors.RESET}")

    built = []
    for path_str, weight in lora_configs:
        try:
            built.append(LoRAConfig(path=path_str, strength=weight))
            print(f" - {Path(path_str).name} (strength: {weight})")
        except Exception as err:
            # Surface which LoRA failed, then let the caller decide what to do.
            print(f"{Colors.RED}Error loading LoRA {path_str}: {err}{Colors.RESET}")
            raise

    loaded = load_multiple_loras(built)
    if len(loaded) == 0 if hasattr(loaded, "__len__") else not loaded:
        print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
    return loaded
|
||||
|
||||
|
||||
def load_and_apply_loras(
    model_weights: Dict[str, mx.array],
    lora_configs: Optional[List[Tuple[str, float]]] = None,
    verbose: bool = False,
    quantization_bits: int = 0,
) -> Dict[str, mx.array]:
    """Merge LoRA weights into a model weight dictionary.

    Intended for non-quantized (bf16) models; for quantized models use
    ``apply_loras_to_model()`` instead.

    Args:
        model_weights: Weight dictionary to merge the LoRAs into.
        lora_configs: ``(lora_path, strength)`` tuples. ``None`` or an empty
            list means "no LoRAs" and the input weights are returned as-is.
        verbose: Print the number of model weight keys and forward the flag
            to ``apply_loras_to_weights``.
        quantization_bits: Forwarded unchanged to ``apply_loras_to_weights``.

    Returns:
        ``model_weights`` itself when nothing was requested or nothing
        matched, otherwise the result of ``apply_loras_to_weights``.
    """
    from mlx_video.lora import apply_loras_to_weights
    from mlx_video.utils import Colors

    # Nothing requested: hand the weights back untouched.
    if not lora_configs:
        return model_weights

    lora_map = _load_lora_configs(lora_configs)
    # Nothing matched any model layer: also a no-op.
    if not lora_map:
        return model_weights

    print(f"{Colors.GREEN}Applying LoRAs to {len(lora_map)} modules...{Colors.RESET}")
    if verbose:
        print(f" Model has {len(model_weights)} weight keys")

    merged = apply_loras_to_weights(
        model_weights,
        lora_map,
        verbose=verbose,
        quantization_bits=quantization_bits,
    )
    print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
    return merged
|
||||
|
||||
|
||||
def convert_wan_checkpoint(
|
||||
checkpoint_dir: str,
|
||||
output_dir: str,
|
||||
@@ -464,30 +556,45 @@ def _quantize_saved_model(
|
||||
is_dual: bool,
|
||||
bits: int,
|
||||
group_size: int,
|
||||
source_dir: Path = None,
|
||||
):
|
||||
"""Load saved bf16 model, quantize, and re-save."""
|
||||
"""Load saved bf16 model, quantize, and re-save.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to write quantized weights to.
|
||||
config: WanModelConfig for creating the model.
|
||||
is_dual: Whether this is a dual-expert model.
|
||||
bits: Quantization bits.
|
||||
group_size: Quantization group size.
|
||||
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
|
||||
"""
|
||||
import json
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
model_files = []
|
||||
if source_dir is None:
|
||||
source_dir = output_dir
|
||||
|
||||
model_names = []
|
||||
if is_dual:
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
p = output_dir / name
|
||||
if p.exists():
|
||||
model_files.append(p)
|
||||
if (source_dir / name).exists():
|
||||
model_names.append(name)
|
||||
else:
|
||||
p = output_dir / "model.safetensors"
|
||||
if p.exists():
|
||||
model_files.append(p)
|
||||
if (source_dir / "model.safetensors").exists():
|
||||
model_names.append("model.safetensors")
|
||||
|
||||
for model_path in model_files:
|
||||
print(f" Quantizing {model_path.name}...")
|
||||
for name in model_names:
|
||||
print(f" Quantizing {name}...")
|
||||
model = WanModel(config)
|
||||
weights = mx.load(str(model_path))
|
||||
weights = mx.load(str(source_dir / name))
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
del weights
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Apply quantization to targeted layers
|
||||
nn.quantize(
|
||||
@@ -499,10 +606,30 @@ def _quantize_saved_model(
|
||||
|
||||
# Save quantized weights
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
# Validate: check for NaN/Inf in bias tensors (corruption canary)
|
||||
bad_keys = []
|
||||
for k, v in weights_dict.items():
|
||||
if k.endswith(".bias") and not k.endswith(".biases"):
|
||||
mx.eval(v)
|
||||
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
|
||||
bad_keys.append(k)
|
||||
if bad_keys:
|
||||
raise RuntimeError(
|
||||
f"Quantization produced corrupted weights in {model_path.name}: "
|
||||
f"{len(bad_keys)} bias tensors contain NaN/Inf "
|
||||
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
|
||||
)
|
||||
|
||||
mx.save_safetensors(str(output_dir / name), weights_dict)
|
||||
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
|
||||
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
|
||||
|
||||
# Free model before processing next file
|
||||
del model, weights_dict
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Update config.json with quantization metadata
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path) as f:
|
||||
@@ -516,6 +643,68 @@ def _quantize_saved_model(
|
||||
print(f" Updated config.json with quantization metadata")
|
||||
|
||||
|
||||
def quantize_mlx_model(
    mlx_model_dir: str,
    output_dir: str,
    bits: int = 4,
    group_size: int = 64,
):
    """Quantize an already-converted MLX model (skips PyTorch conversion).

    Reads ``config.json`` from ``mlx_model_dir``, copies every
    non-transformer file to ``output_dir`` (when the two directories
    differ), and then delegates to ``_quantize_saved_model`` to quantize the
    transformer weights from the source directory into the output directory.

    Args:
        mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
        output_dir: Path to output quantized model directory.
        bits: Quantization bits (4 or 8).
        group_size: Quantization group size (32, 64, or 128).

    Raises:
        FileNotFoundError: If ``config.json`` is missing, or if the source
            directory contains neither ``model.safetensors`` nor the
            ``low_noise_model.safetensors`` / ``high_noise_model.safetensors``
            pair.
        ValueError: If ``config.json`` already carries quantization metadata.
    """
    import json
    import shutil

    src = Path(mlx_model_dir)
    dst = Path(output_dir)

    config_path = src / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"No config.json found in {src}")

    with open(config_path, encoding="utf-8") as config_file:
        cfg = json.load(config_file)

    existing_quant = cfg.get("quantization")
    if existing_quant:
        # Metadata without a "bits" entry must still produce the intended
        # ValueError rather than an unrelated KeyError/TypeError.
        if isinstance(existing_quant, dict):
            existing_bits = existing_quant.get("bits", "unknown")
        else:
            existing_bits = "unknown"
        raise ValueError(
            f"Model at {src} is already quantized "
            f"({existing_bits}-bit). Use a bf16/fp16 source."
        )

    # Detect dual vs single expert
    is_dual = (src / "low_noise_model.safetensors").exists() and (
        src / "high_noise_model.safetensors"
    ).exists()

    # Fail fast when there is nothing to quantize: otherwise the output
    # directory would be populated and handed to the quantization step
    # without any transformer weights being processed.
    if not is_dual and not (src / "model.safetensors").exists():
        raise FileNotFoundError(
            f"No transformer weights found in {src} (expected model.safetensors, "
            f"or low_noise_model.safetensors and high_noise_model.safetensors)"
        )

    # Build model config
    from mlx_video.models.wan.config import WanModelConfig

    config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
    # JSON has no tuple type; these fields are converted back from lists.
    for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
        if key in config_dict and isinstance(config_dict[key], list):
            config_dict[key] = tuple(config_dict[key])
    config = WanModelConfig(**config_dict)

    # Copy non-transformer files to output dir (skip large model weights)
    transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
    if dst.resolve() != src.resolve():
        dst.mkdir(parents=True, exist_ok=True)
        for entry in src.iterdir():
            if entry.is_file() and entry.name not in transformer_files:
                shutil.copy2(entry, dst / entry.name)
        print(f"Copied non-transformer files from {src} to {dst}")

    print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
    _quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)

    print(f"\nQuantization complete! Output: {dst}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
@@ -551,6 +740,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Quantize transformer weights for faster inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize-only",
|
||||
action="store_true",
|
||||
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bits",
|
||||
type=int,
|
||||
@@ -566,7 +760,14 @@ if __name__ == "__main__":
|
||||
help="Quantization group size (default: 64)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_wan_checkpoint(
|
||||
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
||||
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
||||
)
|
||||
|
||||
if args.quantize_only:
|
||||
quantize_mlx_model(
|
||||
args.checkpoint_dir, args.output_dir,
|
||||
bits=args.bits, group_size=args.group_size,
|
||||
)
|
||||
else:
|
||||
convert_wan_checkpoint(
|
||||
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
||||
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user