feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
18
README.md
18
README.md
@@ -82,15 +82,15 @@ python -m mlx_video.generate \
|
|||||||
|
|
||||||
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline:
|
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline:
|
||||||
|
|
||||||
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B |
|
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|
||||||
|---|--------|--------|--------|
|
|---|--------|--------|--------|--------|
|
||||||
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video |
|
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
|
||||||
| **Pipeline** | Single model | Dual model | Dual model |
|
| **Pipeline** | Single model | Dual model | Dual model | Single model |
|
||||||
| **Sizes** | 1.3B, 14B | 14B | 14B |
|
| **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
|
||||||
| **Steps** | 50 | 40 | 40 |
|
| **Steps** | 50 | 40 | 40 | 40 |
|
||||||
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 |
|
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
|
||||||
| **Shift** | 5.0 | 12.0 | 5.0 |
|
| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
|
||||||
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder |
|
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
|
||||||
|
|
||||||
### Step 1: Download Weights
|
### Step 1: Download Weights
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
|
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
|
||||||
|
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.utils
|
import mlx.utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_torch_weights(path: str) -> Dict[str, mx.array]:
|
def load_torch_weights(path: str) -> Dict[str, mx.array]:
|
||||||
"""Load PyTorch .pth weights and convert to MLX arrays.
|
"""Load PyTorch .pth weights and convert to MLX arrays.
|
||||||
@@ -88,6 +91,7 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
|
|||||||
etc.
|
etc.
|
||||||
"""
|
"""
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
consumed = set()
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
@@ -99,36 +103,43 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
|
|||||||
value = value.reshape(value.shape[0], -1)
|
value = value.reshape(value.shape[0], -1)
|
||||||
new_key = "patch_embedding_proj.weight"
|
new_key = "patch_embedding_proj.weight"
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
if key == "patch_embedding.bias":
|
if key == "patch_embedding.bias":
|
||||||
new_key = "patch_embedding_proj.bias"
|
new_key = "patch_embedding_proj.bias"
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
|
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
|
||||||
if key.startswith("text_embedding.0."):
|
if key.startswith("text_embedding.0."):
|
||||||
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
|
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
if key.startswith("text_embedding.2."):
|
if key.startswith("text_embedding.2."):
|
||||||
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
|
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
|
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
|
||||||
if key.startswith("time_embedding.0."):
|
if key.startswith("time_embedding.0."):
|
||||||
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
|
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
if key.startswith("time_embedding.2."):
|
if key.startswith("time_embedding.2."):
|
||||||
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
|
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Time projection Sequential: 0=SiLU(no params), 1=Linear
|
# Time projection Sequential: 0=SiLU(no params), 1=Linear
|
||||||
if key.startswith("time_projection.1."):
|
if key.startswith("time_projection.1."):
|
||||||
new_key = key.replace("time_projection.1.", "time_projection.")
|
new_key = key.replace("time_projection.1.", "time_projection.")
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
|
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
|
||||||
@@ -137,9 +148,15 @@ def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str,
|
|||||||
|
|
||||||
# Skip the freqs buffer (we compute it in the model)
|
# Skip the freqs buffer (we compute it in the model)
|
||||||
if key == "freqs":
|
if key == "freqs":
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
|
|
||||||
|
unconsumed = set(weights.keys()) - consumed
|
||||||
|
if unconsumed:
|
||||||
|
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@@ -171,6 +188,7 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
|
|||||||
norm.weight
|
norm.weight
|
||||||
"""
|
"""
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
consumed = set()
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
@@ -179,6 +197,11 @@ def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]
|
|||||||
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
|
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
|
|
||||||
|
unconsumed = set(weights.keys()) - consumed
|
||||||
|
if unconsumed:
|
||||||
|
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@@ -189,6 +212,7 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
|||||||
Handles Conv3d and Conv2d weight transpositions for MLX format.
|
Handles Conv3d and Conv2d weight transpositions for MLX format.
|
||||||
"""
|
"""
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
consumed = set()
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
@@ -206,10 +230,78 @@ def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
|||||||
# Need to adapt naming for our simplified structure
|
# Need to adapt naming for our simplified structure
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
|
|
||||||
|
unconsumed = set(weights.keys()) - consumed
|
||||||
|
if unconsumed:
|
||||||
|
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lora_configs(
    lora_configs: List[Tuple[str, float]],
) -> Dict[str, list]:
    """Build LoRA configs from ``(path, strength)`` pairs and load them.

    Used by both the weight-merging path and the runtime-wrapping path.

    Args:
        lora_configs: ``(path, strength)`` tuples, one per LoRA to load.

    Returns:
        The mapping produced by ``load_multiple_loras`` for the built
        configs. A warning is printed (and the empty mapping is still
        returned) when it is empty.

    Raises:
        Exception: Any error raised while building or announcing a config is
            reported on stdout and then re-raised unchanged.
    """
    from mlx_video.lora import LoRAConfig, load_multiple_loras
    from mlx_video.utils import Colors

    print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")

    parsed = []
    for lora_path, strength in lora_configs:
        try:
            parsed.append(LoRAConfig(path=lora_path, strength=strength))
            file_label = Path(lora_path).name
            print(f" - {file_label} (strength: {strength})")
        except Exception as err:
            # Surface which LoRA failed, then let the caller see the real error.
            print(f"{Colors.RED}Error loading LoRA {lora_path}: {err}{Colors.RESET}")
            raise

    loaded = load_multiple_loras(parsed)
    if not loaded:
        print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_apply_loras(
    model_weights: Dict[str, mx.array],
    lora_configs: Optional[List[Tuple[str, float]]] = None,
    verbose: bool = False,
    quantization_bits: int = 0,
) -> Dict[str, mx.array]:
    """Merge LoRA weights into a model weight dictionary.

    Meant for non-quantized (bf16) models; for quantized models use
    ``apply_loras_to_model()`` instead.

    Args:
        model_weights: Weight dictionary to merge the LoRAs into.
        lora_configs: ``(path, strength)`` tuples. ``None`` or an empty list
            means "no LoRAs".
        verbose: Print the number of model weight keys and forward the flag
            to ``apply_loras_to_weights``.
        quantization_bits: Forwarded unchanged to ``apply_loras_to_weights``.

    Returns:
        ``model_weights`` itself when no LoRA was requested or none matched;
        otherwise the result of ``apply_loras_to_weights``.
    """
    from mlx_video.lora import apply_loras_to_weights
    from mlx_video.utils import Colors

    # Nothing requested: hand the weights back untouched.
    if not lora_configs:
        return model_weights

    lora_map = _load_lora_configs(lora_configs)
    # Nothing matched (the loader has already printed a warning).
    if not lora_map:
        return model_weights

    print(f"{Colors.GREEN}Applying LoRAs to {len(lora_map)} modules...{Colors.RESET}")
    if verbose:
        print(f" Model has {len(model_weights)} weight keys")

    merged = apply_loras_to_weights(
        model_weights,
        lora_map,
        verbose=verbose,
        quantization_bits=quantization_bits,
    )
    print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
    return merged
|
||||||
|
|
||||||
|
|
||||||
def convert_wan_checkpoint(
|
def convert_wan_checkpoint(
|
||||||
checkpoint_dir: str,
|
checkpoint_dir: str,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
@@ -464,30 +556,45 @@ def _quantize_saved_model(
|
|||||||
is_dual: bool,
|
is_dual: bool,
|
||||||
bits: int,
|
bits: int,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
|
source_dir: Path = None,
|
||||||
):
|
):
|
||||||
"""Load saved bf16 model, quantize, and re-save."""
|
"""Load saved bf16 model, quantize, and re-save.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Directory to write quantized weights to.
|
||||||
|
config: WanModelConfig for creating the model.
|
||||||
|
is_dual: Whether this is a dual-expert model.
|
||||||
|
bits: Quantization bits.
|
||||||
|
group_size: Quantization group size.
|
||||||
|
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
|
||||||
|
"""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
model_files = []
|
if source_dir is None:
|
||||||
|
source_dir = output_dir
|
||||||
|
|
||||||
|
model_names = []
|
||||||
if is_dual:
|
if is_dual:
|
||||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||||
p = output_dir / name
|
if (source_dir / name).exists():
|
||||||
if p.exists():
|
model_names.append(name)
|
||||||
model_files.append(p)
|
|
||||||
else:
|
else:
|
||||||
p = output_dir / "model.safetensors"
|
if (source_dir / "model.safetensors").exists():
|
||||||
if p.exists():
|
model_names.append("model.safetensors")
|
||||||
model_files.append(p)
|
|
||||||
|
|
||||||
for model_path in model_files:
|
for name in model_names:
|
||||||
print(f" Quantizing {model_path.name}...")
|
print(f" Quantizing {name}...")
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
weights = mx.load(str(model_path))
|
weights = mx.load(str(source_dir / name))
|
||||||
model.load_weights(list(weights.items()), strict=False)
|
model.load_weights(list(weights.items()), strict=False)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
del weights
|
||||||
|
gc.collect()
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
# Apply quantization to targeted layers
|
# Apply quantization to targeted layers
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
@@ -499,10 +606,30 @@ def _quantize_saved_model(
|
|||||||
|
|
||||||
# Save quantized weights
|
# Save quantized weights
|
||||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||||
mx.save_safetensors(str(model_path), weights_dict)
|
|
||||||
|
# Validate before saving: NaN/Inf in a bias tensor is used as a corruption
# canary, so a bad run never overwrites or produces a weight file.
bad_keys = []
for k, v in weights_dict.items():
    if k.endswith(".bias") and not k.endswith(".biases"):
        mx.eval(v)
        if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
            bad_keys.append(k)
if bad_keys:
    # `name` is the loop variable of the enclosing `for name in model_names`
    # loop; the earlier `model_path` variable no longer exists, so
    # referencing it here would raise NameError instead of this error.
    raise RuntimeError(
        f"Quantization produced corrupted weights in {name}: "
        f"{len(bad_keys)} bias tensors contain NaN/Inf "
        f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
    )
|
||||||
|
|
||||||
|
mx.save_safetensors(str(output_dir / name), weights_dict)
|
||||||
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
|
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
|
||||||
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
|
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
|
||||||
|
|
||||||
|
# Free model before processing next file
|
||||||
|
del model, weights_dict
|
||||||
|
gc.collect()
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
# Update config.json with quantization metadata
|
# Update config.json with quantization metadata
|
||||||
config_path = output_dir / "config.json"
|
config_path = output_dir / "config.json"
|
||||||
with open(config_path) as f:
|
with open(config_path) as f:
|
||||||
@@ -516,6 +643,68 @@ def _quantize_saved_model(
|
|||||||
print(f" Updated config.json with quantization metadata")
|
print(f" Updated config.json with quantization metadata")
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_mlx_model(
    mlx_model_dir: str,
    output_dir: str,
    bits: int = 4,
    group_size: int = 64,
):
    """Quantize an already-converted MLX model (skips PyTorch conversion).

    Non-transformer files are copied from the source directory to the output
    directory (when the two differ), then the transformer weights are read
    from the source, quantized, and written to the output directory by
    ``_quantize_saved_model``.

    Args:
        mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
        output_dir: Path to output quantized model directory.
        bits: Quantization bits (4 or 8).
        group_size: Quantization group size (32, 64, or 128).

    Raises:
        FileNotFoundError: If ``config.json`` is missing, or if the source
            directory holds neither ``model.safetensors`` nor both
            dual-expert weight files.
        ValueError: If ``config.json`` already carries quantization metadata.
    """
    import json
    import shutil

    src = Path(mlx_model_dir)
    dst = Path(output_dir)

    config_path = src / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"No config.json found in {src}")

    with open(config_path) as f:
        cfg = json.load(f)

    existing_quant = cfg.get("quantization")
    if existing_quant:
        # Tolerate metadata without a "bits" entry so the caller still gets
        # the intended ValueError rather than a KeyError.
        if isinstance(existing_quant, dict):
            existing_bits = existing_quant.get("bits", "unknown")
        else:
            existing_bits = "unknown"
        raise ValueError(
            f"Model at {src} is already quantized "
            f"({existing_bits}-bit). Use a bf16/fp16 source."
        )

    # Detect dual vs single expert
    is_dual = (src / "low_noise_model.safetensors").exists() and (
        src / "high_noise_model.safetensors"
    ).exists()

    # Fail before touching the output directory: without any weights the
    # quantization step would do nothing, yet the copied config.json would
    # still be stamped as quantized.
    if not is_dual and not (src / "model.safetensors").exists():
        raise FileNotFoundError(
            f"No transformer weights found in {src}: expected model.safetensors, "
            "or both low_noise_model.safetensors and high_noise_model.safetensors"
        )

    # Build model config
    from mlx_video.models.wan.config import WanModelConfig

    config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
    # JSON stores these as lists; convert any list value back to a tuple.
    for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
        if key in config_dict and isinstance(config_dict[key], list):
            config_dict[key] = tuple(config_dict[key])
    config = WanModelConfig(**config_dict)

    # Copy non-transformer files to output dir (skip large model weights)
    transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
    if dst.resolve() != src.resolve():
        dst.mkdir(parents=True, exist_ok=True)
        for entry in src.iterdir():
            if entry.is_file() and entry.name not in transformer_files:
                shutil.copy2(entry, dst / entry.name)
        print(f"Copied non-transformer files from {src} to {dst}")

    print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
    _quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)

    print(f"\nQuantization complete! Output: {dst}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
@@ -551,6 +740,11 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Quantize transformer weights for faster inference",
|
help="Quantize transformer weights for faster inference",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize-only",
|
||||||
|
action="store_true",
|
||||||
|
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bits",
|
"--bits",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -566,6 +760,13 @@ if __name__ == "__main__":
|
|||||||
help="Quantization group size (default: 64)",
|
help="Quantization group size (default: 64)",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.quantize_only:
|
||||||
|
quantize_mlx_model(
|
||||||
|
args.checkpoint_dir, args.output_dir,
|
||||||
|
bits=args.bits, group_size=args.group_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
convert_wan_checkpoint(
|
convert_wan_checkpoint(
|
||||||
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
||||||
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
||||||
|
|||||||
@@ -43,6 +43,10 @@ def generate_video(
|
|||||||
seed: int = -1,
|
seed: int = -1,
|
||||||
output_path: str = "output.mp4",
|
output_path: str = "output.mp4",
|
||||||
scheduler: str = "unipc",
|
scheduler: str = "unipc",
|
||||||
|
loras: list | None = None,
|
||||||
|
loras_high: list | None = None,
|
||||||
|
loras_low: list | None = None,
|
||||||
|
|
||||||
):
|
):
|
||||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||||
|
|
||||||
@@ -60,6 +64,10 @@ def generate_video(
|
|||||||
seed: Random seed (-1 for random)
|
seed: Random seed (-1 for random)
|
||||||
output_path: Output video path
|
output_path: Output video path
|
||||||
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
||||||
|
loras: Optional list of (path, strength) tuples applied to all models
|
||||||
|
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
||||||
|
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -156,6 +164,12 @@ def generate_video(
|
|||||||
parts = [float(x) for x in guide_scale.split(",")]
|
parts = [float(x) for x in guide_scale.split(",")]
|
||||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||||
|
|
||||||
|
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
|
||||||
|
if isinstance(guide_scale, tuple):
|
||||||
|
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
|
||||||
|
else:
|
||||||
|
cfg_disabled = guide_scale <= 1.0
|
||||||
|
|
||||||
# Validate frame count
|
# Validate frame count
|
||||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||||
|
|
||||||
@@ -181,6 +195,8 @@ def generate_video(
|
|||||||
print(f" Neg prompt: {neg_display}")
|
print(f" Neg prompt: {neg_display}")
|
||||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
||||||
|
if cfg_disabled:
|
||||||
|
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||||
print(f"{Colors.RESET}")
|
print(f"{Colors.RESET}")
|
||||||
|
|
||||||
# Seed
|
# Seed
|
||||||
@@ -233,6 +249,10 @@ def generate_video(
|
|||||||
# Encode prompts
|
# Encode prompts
|
||||||
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
||||||
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
||||||
|
if cfg_disabled:
|
||||||
|
context_null = None
|
||||||
|
mx.eval(context)
|
||||||
|
else:
|
||||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||||
mx.eval(context, context_null)
|
mx.eval(context, context_null)
|
||||||
|
|
||||||
@@ -319,17 +339,35 @@ def generate_video(
|
|||||||
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
|
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
|
|
||||||
|
# Merge per-model LoRAs with shared LoRAs
|
||||||
|
_loras_low = (loras or []) + (loras_low or []) or None
|
||||||
|
_loras_high = (loras or []) + (loras_high or []) or None
|
||||||
|
_loras_single = loras
|
||||||
|
|
||||||
if is_dual:
|
if is_dual:
|
||||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||||
low_noise_model = load_wan_model(low_noise_path, config, quantization)
|
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
|
||||||
high_noise_model = load_wan_model(high_noise_path, config, quantization)
|
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
|
||||||
else:
|
else:
|
||||||
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization)
|
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
|
||||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||||
# Each model has its own text_embedding weights, so dual models need separate embeddings
|
# Each model has its own text_embedding weights, so dual models need separate embeddings
|
||||||
|
if cfg_disabled:
|
||||||
|
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
|
||||||
|
if is_dual:
|
||||||
|
context_emb_low = low_noise_model.embed_text([context])
|
||||||
|
context_emb_high = high_noise_model.embed_text([context])
|
||||||
|
mx.eval(context_emb_low, context_emb_high)
|
||||||
|
context_cond_low = context_emb_low[0:1]
|
||||||
|
context_cond_high = context_emb_high[0:1]
|
||||||
|
else:
|
||||||
|
context_emb = single_model.embed_text([context])
|
||||||
|
mx.eval(context_emb)
|
||||||
|
context_cond = context_emb[0:1]
|
||||||
|
else:
|
||||||
if is_dual:
|
if is_dual:
|
||||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||||
@@ -342,6 +380,15 @@ def generate_video(
|
|||||||
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
||||||
|
|
||||||
# Precompute cross-attention K/V caches (constant across all steps)
|
# Precompute cross-attention K/V caches (constant across all steps)
|
||||||
|
if cfg_disabled:
|
||||||
|
if is_dual:
|
||||||
|
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
|
||||||
|
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
|
||||||
|
mx.eval(cross_kv_low, cross_kv_high)
|
||||||
|
else:
|
||||||
|
cross_kv = single_model.prepare_cross_kv(context_cond)
|
||||||
|
mx.eval(cross_kv)
|
||||||
|
else:
|
||||||
if is_dual:
|
if is_dual:
|
||||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
||||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
||||||
@@ -354,13 +401,16 @@ def generate_video(
|
|||||||
f_grid = t_latent // patch_size[0]
|
f_grid = t_latent // patch_size[0]
|
||||||
h_grid = h_latent // patch_size[1]
|
h_grid = h_latent // patch_size[1]
|
||||||
w_grid = w_latent // patch_size[2]
|
w_grid = w_latent // patch_size[2]
|
||||||
cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
if cfg_disabled:
|
||||||
|
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
|
||||||
|
else:
|
||||||
|
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
||||||
if is_dual:
|
if is_dual:
|
||||||
rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes)
|
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
|
||||||
rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes)
|
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
||||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||||
else:
|
else:
|
||||||
rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes)
|
rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes)
|
||||||
mx.eval(rope_cos_sin)
|
mx.eval(rope_cos_sin)
|
||||||
|
|
||||||
# Setup scheduler
|
# Setup scheduler
|
||||||
@@ -395,42 +445,71 @@ def generate_video(
|
|||||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||||
timestep_val = timestep_list[i]
|
timestep_val = timestep_list[i]
|
||||||
|
|
||||||
# Select model, guide scale, cached K/V, and precomputed RoPE
|
# Select model, cached K/V, and precomputed RoPE
|
||||||
if is_dual:
|
if is_dual:
|
||||||
if timestep_val >= boundary:
|
if timestep_val >= boundary:
|
||||||
model = high_noise_model
|
model = high_noise_model
|
||||||
gs = guide_scale[1]
|
|
||||||
kv = cross_kv_high
|
kv = cross_kv_high
|
||||||
rcs = rope_cos_sin_high
|
rcs = rope_cos_sin_high
|
||||||
else:
|
else:
|
||||||
model = low_noise_model
|
model = low_noise_model
|
||||||
gs = guide_scale[0]
|
|
||||||
kv = cross_kv_low
|
kv = cross_kv_low
|
||||||
rcs = rope_cos_sin_low
|
rcs = rope_cos_sin_low
|
||||||
else:
|
else:
|
||||||
model = single_model
|
model = single_model
|
||||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
|
||||||
kv = cross_kv
|
kv = cross_kv
|
||||||
rcs = rope_cos_sin
|
rcs = rope_cos_sin
|
||||||
|
|
||||||
# Build per-token timesteps for TI2V-5B (first-frame patches get t=0)
|
if cfg_disabled:
|
||||||
|
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||||
if is_i2v_mask_blend:
|
if is_i2v_mask_blend:
|
||||||
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
|
t_tokens = i2v_mask_tokens * timestep_val
|
||||||
# Pad to seq_len if needed
|
|
||||||
pad_len = seq_len - t_tokens.shape[1]
|
pad_len = seq_len - t_tokens.shape[1]
|
||||||
if pad_len > 0:
|
if pad_len > 0:
|
||||||
t_tokens = mx.concatenate(
|
t_tokens = mx.concatenate(
|
||||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||||
)
|
)
|
||||||
# Batch for CFG: both cond and uncond get same timesteps
|
t_batch = t_tokens # [1, L]
|
||||||
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L]
|
else:
|
||||||
|
t_batch = mx.array([timestep_val])
|
||||||
|
|
||||||
|
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||||
|
|
||||||
|
if is_dual:
|
||||||
|
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
||||||
|
else:
|
||||||
|
ctx = context_cond
|
||||||
|
preds = model(
|
||||||
|
[latents],
|
||||||
|
t=t_batch,
|
||||||
|
context=ctx,
|
||||||
|
seq_len=seq_len,
|
||||||
|
cross_kv_caches=kv,
|
||||||
|
y=y_arg,
|
||||||
|
rope_cos_sin=rcs,
|
||||||
|
)
|
||||||
|
noise_pred = preds[0]
|
||||||
|
del preds
|
||||||
|
else:
|
||||||
|
# CFG: batch cond + uncond into single B=2 forward pass
|
||||||
|
if is_dual:
|
||||||
|
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||||
|
else:
|
||||||
|
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||||
|
|
||||||
|
if is_i2v_mask_blend:
|
||||||
|
t_tokens = i2v_mask_tokens * timestep_val
|
||||||
|
pad_len = seq_len - t_tokens.shape[1]
|
||||||
|
if pad_len > 0:
|
||||||
|
t_tokens = mx.concatenate(
|
||||||
|
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||||
|
)
|
||||||
|
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
|
||||||
else:
|
else:
|
||||||
t_batch = mx.array([timestep_val, timestep_val])
|
t_batch = mx.array([timestep_val, timestep_val])
|
||||||
|
|
||||||
# I2V-14B: pass y conditioning to model (same y for cond and uncond)
|
|
||||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||||
|
|
||||||
# CFG: batch cond + uncond into single B=2 forward pass
|
|
||||||
ctx = context_cfg if not is_dual else (
|
ctx = context_cfg if not is_dual else (
|
||||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||||
)
|
)
|
||||||
@@ -444,9 +523,8 @@ def generate_video(
|
|||||||
rope_cos_sin=rcs,
|
rope_cos_sin=rcs,
|
||||||
)
|
)
|
||||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||||
|
|
||||||
# Classifier-free guidance + scheduler step
|
|
||||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||||
|
del noise_pred_cond, noise_pred_uncond, preds
|
||||||
|
|
||||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||||
|
|
||||||
@@ -455,7 +533,7 @@ def generate_video(
|
|||||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||||
|
|
||||||
# Release temporaries before eval to free memory for graph execution
|
# Release temporaries before eval to free memory for graph execution
|
||||||
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
del noise_pred
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
|
||||||
@@ -463,11 +541,19 @@ def generate_video(
|
|||||||
# Free transformer models and text embeddings
|
# Free transformer models and text embeddings
|
||||||
if is_dual:
|
if is_dual:
|
||||||
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
||||||
|
if cfg_disabled:
|
||||||
|
del context_cond_low, context_cond_high
|
||||||
|
else:
|
||||||
del context_cfg_low, context_cfg_high
|
del context_cfg_low, context_cfg_high
|
||||||
else:
|
else:
|
||||||
del single_model, cross_kv
|
del single_model, cross_kv
|
||||||
|
if cfg_disabled:
|
||||||
|
del context_cond
|
||||||
|
else:
|
||||||
del context_cfg
|
del context_cfg
|
||||||
del model, kv, context, context_null
|
del model, kv, context
|
||||||
|
if context_null is not None:
|
||||||
|
del context_null
|
||||||
gc.collect(); mx.clear_cache()
|
gc.collect(); mx.clear_cache()
|
||||||
|
|
||||||
# Load VAE and decode
|
# Load VAE and decode
|
||||||
@@ -478,25 +564,36 @@ def generate_video(
|
|||||||
|
|
||||||
is_wan22_vae = config.vae_z_dim == 48
|
is_wan22_vae = config.vae_z_dim == 48
|
||||||
|
|
||||||
|
# Warm-up: prepend a copy of the first latent frame to provide temporal
|
||||||
|
# context for the real first frame. Causal convolutions in the VAE decoder
|
||||||
|
# pad with zeros on the left, so the first few output frames have degraded
|
||||||
|
# quality (no temporal context). By duplicating the first latent, the real
|
||||||
|
# first frame sees its own features as left context instead of zeros.
|
||||||
|
# We trim the extra output frames after decoding.
|
||||||
|
warmup_trim = vae_stride[0] # 4 frames per latent temporal position
|
||||||
|
latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1)
|
||||||
|
|
||||||
if is_wan22_vae:
|
if is_wan22_vae:
|
||||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
|
|
||||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||||
z = latents.transpose(1, 2, 3, 0)[None] # [1, T, H, W, C]
|
z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C]
|
||||||
z = denormalize_latents(z)
|
z = denormalize_latents(z)
|
||||||
video = vae(z) # [1, T', H', W', 3]
|
video = vae(z) # [1, T', H', W', 3]
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
video = np.array(video[0]) # [T', H', W', 3]
|
video = np.array(video[0]) # [T', H', W', 3]
|
||||||
|
video = video[warmup_trim:] # Trim warm-up frames
|
||||||
video = (video + 1.0) / 2.0
|
video = (video + 1.0) / 2.0
|
||||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||||
else:
|
else:
|
||||||
video = vae.decode(latents[None]) # [1, 3, T, H, W]
|
video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W]
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
video = np.array(video[0]) # [3, T, H, W]
|
video = np.array(video[0]) # [3, T', H, W]
|
||||||
|
video = video[:, warmup_trim:] # Trim warm-up frames (channels-first)
|
||||||
video = (video + 1.0) / 2.0
|
video = (video + 1.0) / 2.0
|
||||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||||
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
||||||
@@ -529,6 +626,19 @@ def main():
|
|||||||
choices=["euler", "dpm++", "unipc"],
|
choices=["euler", "dpm++", "unipc"],
|
||||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||||
|
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||||
|
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||||
|
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Parse guide scale
|
# Parse guide scale
|
||||||
@@ -542,6 +652,12 @@ def main():
|
|||||||
if args.no_negative_prompt:
|
if args.no_negative_prompt:
|
||||||
neg_prompt = ""
|
neg_prompt = ""
|
||||||
|
|
||||||
|
# Parse LoRA configs: convert [path, strength_str] → (path, float)
|
||||||
|
def _parse_lora_args(lora_list):
|
||||||
|
if not lora_list:
|
||||||
|
return None
|
||||||
|
return [(path, float(strength)) for path, strength in lora_list]
|
||||||
|
|
||||||
generate_video(
|
generate_video(
|
||||||
model_dir=args.model_dir,
|
model_dir=args.model_dir,
|
||||||
prompt=args.prompt,
|
prompt=args.prompt,
|
||||||
@@ -556,6 +672,10 @@ def main():
|
|||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
output_path=args.output_path,
|
output_path=args.output_path,
|
||||||
scheduler=args.scheduler,
|
scheduler=args.scheduler,
|
||||||
|
loras=_parse_lora_args(args.lora),
|
||||||
|
loras_high=_parse_lora_args(args.lora_high),
|
||||||
|
loras_low=_parse_lora_args(args.lora_low),
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
25
mlx_video/lora/__init__.py
Normal file
25
mlx_video/lora/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""LoRA support for mlx-video."""
|
||||||
|
|
||||||
|
from mlx_video.lora.apply import (
|
||||||
|
LoRALinear,
|
||||||
|
apply_lora_to_linear,
|
||||||
|
apply_loras_to_model,
|
||||||
|
apply_loras_to_weights,
|
||||||
|
)
|
||||||
|
from mlx_video.lora.loader import (
|
||||||
|
load_lora_weights,
|
||||||
|
load_multiple_loras,
|
||||||
|
)
|
||||||
|
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LoRAConfig",
|
||||||
|
"LoRAWeights",
|
||||||
|
"AppliedLoRA",
|
||||||
|
"load_lora_weights",
|
||||||
|
"load_multiple_loras",
|
||||||
|
"apply_lora_to_linear",
|
||||||
|
"apply_loras_to_weights",
|
||||||
|
"apply_loras_to_model",
|
||||||
|
"LoRALinear",
|
||||||
|
]
|
||||||
393
mlx_video/lora/apply.py
Normal file
393
mlx_video/lora/apply.py
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
"""Apply LoRA weights to model layers."""
|
||||||
|
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.lora.types import LoRAWeights
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lora_to_linear(
    linear_weight: mx.array,
    lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
) -> mx.array:
    """Apply one or more LoRAs to a linear layer weight.

    Each LoRA contributes ``(lora_B @ lora_A) * scale * strength``. The
    contributions are added to the weight one after another.

    Args:
        linear_weight: Original weight matrix [out_features, in_features]
        lora_weights_and_strengths: List of (LoRAWeights, strength) tuples

    Returns:
        Modified weight with LoRA deltas applied (preserves original dtype)
    """
    orig_dtype = linear_weight.dtype
    modified_weight = linear_weight

    for weights, strength in lora_weights_and_strengths:
        # Form the low-rank product in float32 so the rank-sized reduction is
        # not rounded in the LoRA's (possibly half-precision) storage dtype.
        lora_b = weights.lora_B.astype(mx.float32)
        lora_a = weights.lora_A.astype(mx.float32)
        delta = (lora_b @ lora_a) * (weights.scale * strength)
        # Cast back before adding so the model weight keeps its own dtype
        # (e.g. bfloat16 is not promoted to float32).
        modified_weight = modified_weight + delta.astype(orig_dtype)

    return modified_weight
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
||||||
|
"""Normalize LoRA module name to match Wan2.2 MLX model weight keys.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Stripping common prefixes (diffusion_model., model., etc.)
|
||||||
|
- FFN key mapping: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
|
||||||
|
- Embedding key mapping: text_embedding.0 → text_embedding_0, etc.
|
||||||
|
- Time projection: time_projection.1 → time_projection
|
||||||
|
- Patch embedding: patch_embedding → patch_embedding_proj
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lora_key: Original LoRA module name
|
||||||
|
model_keys: Set of all model weight keys
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized key that matches model weights
|
||||||
|
"""
|
||||||
|
# Try the key as-is first
|
||||||
|
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
|
||||||
|
return lora_key
|
||||||
|
|
||||||
|
# Common prefixes to strip
|
||||||
|
prefixes_to_strip = [
|
||||||
|
"model.diffusion_model.",
|
||||||
|
"diffusion_model.",
|
||||||
|
"base_model.model.",
|
||||||
|
"model.",
|
||||||
|
]
|
||||||
|
|
||||||
|
candidates = [lora_key]
|
||||||
|
for prefix in prefixes_to_strip:
|
||||||
|
if lora_key.startswith(prefix):
|
||||||
|
candidates.append(lora_key[len(prefix):])
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
# Try as-is
|
||||||
|
if f"{candidate}.weight" in model_keys or candidate in model_keys:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# Apply Wan2.2 key transformations
|
||||||
|
transformed = candidate
|
||||||
|
|
||||||
|
# FFN: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
|
||||||
|
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
|
||||||
|
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
|
||||||
|
if transformed.endswith(".ffn.0"):
|
||||||
|
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
|
||||||
|
if transformed.endswith(".ffn.2"):
|
||||||
|
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
|
||||||
|
|
||||||
|
# Text embedding: text_embedding.0 → text_embedding_0
|
||||||
|
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
|
||||||
|
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
|
||||||
|
if transformed.endswith("text_embedding.0"):
|
||||||
|
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
|
||||||
|
if transformed.endswith("text_embedding.2"):
|
||||||
|
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
|
||||||
|
|
||||||
|
# Time embedding: time_embedding.0 → time_embedding_0
|
||||||
|
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
|
||||||
|
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
|
||||||
|
if transformed.endswith("time_embedding.0"):
|
||||||
|
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
|
||||||
|
if transformed.endswith("time_embedding.2"):
|
||||||
|
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
|
||||||
|
|
||||||
|
# Time projection: time_projection.1 → time_projection
|
||||||
|
transformed = transformed.replace("time_projection.1.", "time_projection.")
|
||||||
|
if transformed.endswith("time_projection.1"):
|
||||||
|
transformed = transformed[:-len("time_projection.1")] + "time_projection"
|
||||||
|
|
||||||
|
# Patch embedding: patch_embedding → patch_embedding_proj
|
||||||
|
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
|
||||||
|
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
|
||||||
|
|
||||||
|
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
# Return best attempt with prefix stripped
|
||||||
|
for prefix in prefixes_to_strip:
|
||||||
|
if lora_key.startswith(prefix):
|
||||||
|
return lora_key[len(prefix):]
|
||||||
|
|
||||||
|
return lora_key
|
||||||
|
|
||||||
|
|
||||||
|
# Also support LTX-style key normalization
|
||||||
|
def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
||||||
|
"""Normalize LoRA module name to match LTX MLX model weight keys."""
|
||||||
|
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
|
||||||
|
return lora_key
|
||||||
|
|
||||||
|
prefixes_to_strip = [
|
||||||
|
"model.diffusion_model.",
|
||||||
|
"diffusion_model.",
|
||||||
|
"model.",
|
||||||
|
]
|
||||||
|
|
||||||
|
for prefix in prefixes_to_strip:
|
||||||
|
if lora_key.startswith(prefix):
|
||||||
|
normalized = lora_key[len(prefix):]
|
||||||
|
|
||||||
|
if f"{normalized}.weight" in model_keys or normalized in model_keys:
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
transformed = normalized
|
||||||
|
if transformed.endswith(".to_out.0"):
|
||||||
|
transformed = transformed[:-len(".to_out.0")] + ".to_out"
|
||||||
|
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||||
|
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
|
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||||
|
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
|
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
|
||||||
|
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||||
|
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
|
||||||
|
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||||
|
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
|
||||||
|
|
||||||
|
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
# Try transformations on the original key
|
||||||
|
transformed = lora_key
|
||||||
|
if transformed.endswith(".to_out.0"):
|
||||||
|
transformed = transformed[:-len(".to_out.0")] + ".to_out"
|
||||||
|
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||||
|
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
|
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||||
|
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
|
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
|
||||||
|
|
||||||
|
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
for prefix in prefixes_to_strip:
|
||||||
|
if lora_key.startswith(prefix):
|
||||||
|
return lora_key[len(prefix):]
|
||||||
|
|
||||||
|
return lora_key
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_lora_key(lora_key: str, model_keys: set) -> str:
    """Resolve a LoRA module name against the model's weight keys.

    The model family is guessed from ``model_keys``: if any key contains
    ``"self_attn.q.weight"`` the Wan normalizer is used, otherwise the LTX
    normalizer.

    Args:
        lora_key: Original LoRA module name
        model_keys: Set of all model weight keys

    Returns:
        The name produced by the selected normalizer.
    """
    looks_like_wan = False
    for key in model_keys:
        if "self_attn.q.weight" in key:
            looks_like_wan = True
            break

    normalizer = _normalize_wan_lora_key if looks_like_wan else _normalize_ltx_lora_key
    return normalizer(lora_key, model_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_loras_to_weights(
    model_weights: Dict[str, mx.array],
    module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
    verbose: bool = False,
    quantization_bits: int = 0,
) -> Dict[str, mx.array]:
    """Apply LoRAs to model weights.

    Args:
        model_weights: Original model state dictionary
        module_to_loras: Dictionary mapping module names to lists of
            (LoRAWeights, strength) tuples
        verbose: If True, print detailed debug information
        quantization_bits: If >0, weights are quantized at this bit width.
            Quantized layers are dequantized before LoRA application
            and re-quantized after.

    Returns:
        New state dictionary with LoRA-modified weights

    Raises:
        ValueError: If a targeted layer is quantized (uint32 weight with
            matching ``.scales`` / ``.biases`` entries) but
            ``quantization_bits`` is not positive.
    """
    modified_weights = dict(model_weights)
    model_keys = set(model_weights.keys())

    applied_count = 0
    skipped_count = 0

    for module_name, loras in module_to_loras.items():
        normalized_name = _normalize_lora_key(module_name, model_keys)
        weight_key = f"{normalized_name}.weight"

        if weight_key not in modified_weights:
            if normalized_name not in modified_weights:
                skipped_count += 1
                # Only the first few misses are reported in detail.
                if verbose and skipped_count <= 5:
                    print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
                    leaf = normalized_name.split(".")[-1]
                    similar = [k for k in list(model_keys)[:1000] if leaf in k][:3]
                    if similar:
                        print(f" Similar keys: {similar}")
                continue
            # The normalized name is itself a key of the state dict.
            weight_key = normalized_name

        original_weight = modified_weights[weight_key]

        # Derive the companion keys from the key actually resolved, so a
        # name that already ends in ".weight" still finds its scales/biases.
        if weight_key.endswith(".weight"):
            base_name = weight_key[: -len(".weight")]
        else:
            base_name = weight_key
        scales_key = f"{base_name}.scales"
        biases_key = f"{base_name}.biases"
        is_quantized = (
            original_weight.dtype == mx.uint32
            and scales_key in modified_weights
            and biases_key in modified_weights
        )

        if is_quantized:
            if quantization_bits <= 0:
                raise ValueError(
                    f"'{weight_key}' is quantized; quantization_bits must be "
                    f"a positive bit width, got {quantization_bits}"
                )
            # Handle quantized weights: dequantize → apply delta → re-quantize
            scales = modified_weights[scales_key]
            biases = modified_weights[biases_key]
            # Recover the group size from the packed-weight and scales widths.
            group_size = (original_weight.shape[-1] * 32) // (
                scales.shape[-1] * quantization_bits
            )
            dequantized = mx.dequantize(
                original_weight,
                scales,
                biases,
                group_size=group_size,
                bits=quantization_bits,
            )
            modified = apply_lora_to_linear(dequantized, loras)
            # Re-quantize with same parameters
            new_w, new_scales, new_biases = mx.quantize(
                modified, group_size=group_size, bits=quantization_bits
            )
            modified_weights[weight_key] = new_w
            modified_weights[scales_key] = new_scales
            modified_weights[biases_key] = new_biases
        else:
            modified_weights[weight_key] = apply_lora_to_linear(original_weight, loras)

        applied_count += 1

    if applied_count > 0:
        print(f" ✓ Applied to {applied_count} modules")
    if skipped_count > 0:
        print(f" ⚠ Skipped {skipped_count} incompatible modules")

    return modified_weights
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
    """Wrapper that adds LoRA contributions to a linear layer at call time.

    The wrapped layer's weights are left untouched. For every attached LoRA
    the forward pass adds ``(x @ lora_A.T @ lora_B.T) * scale * strength``
    to the wrapped layer's output.
    """

    def __init__(
        self,
        linear: nn.Module,
        lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
    ):
        """Store the wrapped layer and the (LoRAWeights, strength) pairs."""
        super().__init__()
        self.linear = linear
        self.lora_weights_and_strengths = lora_weights_and_strengths

    def __call__(self, x: mx.array) -> mx.array:
        """Return the wrapped layer's output plus every LoRA contribution."""
        result = self.linear(x)
        for lora, strength in self.lora_weights_and_strengths:
            factor = lora.scale * strength
            # Project down to the LoRA rank first, then back up.
            down = x @ lora.lora_A.T
            up = down @ lora.lora_B.T
            result = result + (factor * up)
        return result
|
||||||
|
|
||||||
|
|
||||||
|
def apply_loras_to_model(
    model: nn.Module,
    module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
    verbose: bool = False,
) -> int:
    """Apply LoRAs to a model by merging into weights.

    For QuantizedLinear layers: dequantizes, merges the LoRA delta, and
    replaces the layer with a regular nn.Linear (no per-step overhead, no
    re-quantization). The replacement has a bias only if the quantized
    layer had one. Layers not targeted by a LoRA are left as they are.

    For nn.Linear layers: merges LoRA delta directly into the weight.

    Args:
        model: The model to apply LoRAs to
        module_to_loras: Dictionary mapping module names to (LoRAWeights, strength) lists
        verbose: Print debug info

    Returns:
        Number of modules modified
    """

    def _child(container, name: str):
        # Numeric path parts index into a sequence; the rest are attributes.
        return container[int(name)] if name.isdigit() else getattr(container, name)

    # Build a set of model module paths for key normalization
    module_paths = set()
    for name, _ in model.named_modules():
        module_paths.add(name)
        module_paths.add(f"{name}.weight")

    applied_count = 0
    dequant_count = 0
    skipped = []

    for lora_key, loras in module_to_loras.items():
        # Map the LoRA key to a module path (not a parameter path).
        module_path = _normalize_lora_key(lora_key, module_paths)
        if module_path.endswith(".weight"):
            module_path = module_path[: -len(".weight")]

        *parent_parts, leaf_name = module_path.split(".")

        # Traverse to the parent module, then to the target itself.
        try:
            parent = model
            for part in parent_parts:
                parent = _child(parent, part)
            target = _child(parent, leaf_name)
        except (AttributeError, IndexError, KeyError, TypeError):
            skipped.append(lora_key)
            if verbose:
                print(f" DEBUG: '{lora_key}' -> '{module_path}' -> module not found")
            continue

        if isinstance(target, nn.QuantizedLinear):
            # Dequantize → merge LoRA → replace with a plain Linear
            weight = mx.dequantize(
                target.weight,
                target.scales,
                target.biases,
                group_size=target.group_size,
                bits=target.bits,
            )
            merged = apply_lora_to_linear(weight, loras)
            has_bias = "bias" in target
            # Create the bias only when the source layer has one; otherwise
            # the new layer would keep a freshly initialised bias.
            new_linear = nn.Linear(merged.shape[1], merged.shape[0], bias=has_bias)
            new_linear.weight = merged
            if has_bias:
                new_linear.bias = target.bias
            if leaf_name.isdigit():
                parent[int(leaf_name)] = new_linear
            else:
                setattr(parent, leaf_name, new_linear)
            dequant_count += 1
            applied_count += 1
        elif isinstance(target, nn.Linear):
            # Merge directly into weight
            target.weight = apply_lora_to_linear(target.weight, loras)
            applied_count += 1
        else:
            skipped.append(lora_key)
            if verbose:
                print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
            continue

    if applied_count > 0:
        msg = f" ✓ Applied to {applied_count} modules"
        if dequant_count > 0:
            msg += f" ({dequant_count} dequantized to bf16)"
        print(msg)
    if skipped:
        print(f" ⚠ Skipped {len(skipped)} incompatible modules")

    return applied_count
|
||||||
122
mlx_video/lora/loader.py
Normal file
122
mlx_video/lora/loader.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""LoRA weight loading utilities."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from mlx_video.lora.types import LoRAConfig, LoRAWeights
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora_weights(lora_path: Path) -> Dict[str, LoRAWeights]:
    """Load LoRA weights from a safetensors file.

    Supports both key conventions:
        - {module_name}.lora_A.weight / {module_name}.lora_B.weight
        - {module_name}.lora_down.weight / {module_name}.lora_up.weight

    A module is loaded only when both halves of one convention are present.
    An optional ``{module_name}.alpha`` scalar sets the LoRA alpha; when it
    is absent, alpha defaults to the rank.

    Args:
        lora_path: Path to the LoRA safetensors file (``str`` is accepted
            and converted to ``Path``)

    Returns:
        Dictionary mapping module names to LoRAWeights objects, in sorted
        module-name order

    Raises:
        FileNotFoundError: If the LoRA file doesn't exist
        ValueError: If the LoRA file format is invalid
    """
    # Accept plain strings too (e.g. paths coming straight from the CLI).
    lora_path = Path(lora_path)
    if not lora_path.exists():
        raise FileNotFoundError(f"LoRA file not found: {lora_path}")

    all_weights = mx.load(str(lora_path))

    # Collect every module name that has at least one LoRA factor.
    factor_key = re.compile(r"(.+)\.lora_(?:A|B|down|up)\.weight$")
    module_names = set()
    for key in all_weights:
        match = factor_key.match(key)
        if match:
            module_names.add(match.group(1))

    lora_weights = {}
    # Sorted so the resulting dict order does not depend on set iteration.
    for module_name in sorted(module_names):
        # Try both key conventions
        key_a = f"{module_name}.lora_A.weight"
        key_b = f"{module_name}.lora_B.weight"
        if key_a not in all_weights or key_b not in all_weights:
            key_a = f"{module_name}.lora_down.weight"
            key_b = f"{module_name}.lora_up.weight"
            if key_a not in all_weights or key_b not in all_weights:
                continue

        lora_a = all_weights[key_a]
        lora_b = all_weights[key_b]

        if lora_a.ndim != 2 or lora_b.ndim != 2:
            raise ValueError(
                f"Invalid LoRA shape for {module_name}: "
                f"lora_A={lora_a.shape}, lora_B={lora_b.shape}"
            )

        # The rank is A's row count and must equal B's column count.
        rank = lora_a.shape[0]
        if lora_b.shape[1] != rank:
            raise ValueError(
                f"LoRA rank mismatch for {module_name}: "
                f"lora_A rank={rank}, lora_B rank={lora_b.shape[1]}"
            )

        # Check for per-module alpha stored as a scalar tensor
        alpha_key = f"{module_name}.alpha"
        if alpha_key in all_weights:
            alpha = float(all_weights[alpha_key].item())
        else:
            alpha = float(rank)

        lora_weights[module_name] = LoRAWeights(
            lora_A=lora_a,
            lora_B=lora_b,
            rank=rank,
            alpha=alpha,
            module_name=module_name,
        )

    if not lora_weights:
        raise ValueError(f"No valid LoRA weights found in {lora_path}")

    return lora_weights
|
||||||
|
|
||||||
|
|
||||||
|
def load_multiple_loras(
    configs: List[LoRAConfig],
) -> Dict[str, List[tuple]]:
    """Load several LoRA files and group their weights by module name.

    Files are loaded in the order given. When a config has
    ``target_modules`` set, only the modules named there are kept from
    that file.

    Args:
        configs: List of LoRAConfig objects

    Returns:
        Dictionary mapping module names to lists of (LoRAWeights, strength) tuples.
    """
    grouped: Dict[str, list] = {}

    for cfg in configs:
        allowed = cfg.target_modules
        for name, lora in load_lora_weights(cfg.path).items():
            if allowed is not None and name not in allowed:
                continue
            grouped.setdefault(name, []).append((lora, cfg.strength))

    return grouped
|
||||||
74
mlx_video/lora/types.py
Normal file
74
mlx_video/lora/types.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Data structures for LoRA support."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoRAWeights:
    """The two low-rank factors of one LoRA module plus its scaling metadata.

    Attributes:
        lora_A: Low-rank matrix A of shape [rank, in_features]
        lora_B: Low-rank matrix B of shape [out_features, rank]
        rank: Rank of the LoRA decomposition
        alpha: LoRA scaling parameter
        module_name: Target module name in the model
    """

    lora_A: mx.array
    lora_B: mx.array
    rank: int
    alpha: float
    module_name: str

    @property
    def scale(self) -> float:
        """Scale factor applied to the low-rank product: alpha / rank."""
        numerator, denominator = self.alpha, self.rank
        return numerator / denominator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoRAConfig:
    """Settings describing one LoRA file to apply.

    Attributes:
        path: Path to the LoRA safetensors file
        strength: Strength/weight to apply this LoRA (must be non-negative)
        target_modules: Optional list of module names to apply LoRA to.
            If None, applies to all available modules in the LoRA.
    """

    path: Path
    strength: float = 1.0
    target_modules: Optional[list[str]] = None

    def __post_init__(self):
        """Coerce ``path`` to ``Path`` and reject missing files or negative strength."""
        self.path = Path(self.path)

        file_missing = not self.path.exists()
        if file_missing:
            raise FileNotFoundError(f"LoRA file not found: {self.path}")

        negative_strength = self.strength < 0
        if negative_strength:
            raise ValueError(f"LoRA strength must be non-negative, got {self.strength}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AppliedLoRA:
    """Pairs a set of LoRA weights with the strength to apply them at.

    Attributes:
        weights: The LoRA weight matrices
        strength: Application strength for this LoRA
    """

    weights: LoRAWeights
    strength: float

    def compute_delta(self) -> mx.array:
        """Return the weight delta: scale * strength * (lora_B @ lora_A)."""
        lora = self.weights
        factor = lora.scale * self.strength
        low_rank_product = lora.lora_B @ lora.lora_A
        return factor * low_rank_product
|
||||||
@@ -4,6 +4,15 @@ import mlx.nn as nn
|
|||||||
from .rope import rope_apply
|
from .rope import rope_apply
|
||||||
|
|
||||||
|
|
||||||
|
def _linear_dtype(layer) -> mx.Dtype:
    """Return the dtype activations should be cast to before calling ``layer``.

    A wrapper exposing a ``linear`` attribute (e.g. a LoRA wrapper) is
    unwrapped one level. For ``nn.QuantizedLinear`` the dtype of ``scales``
    is reported; for any other layer the dtype of ``weight`` is reported.
    """
    base = getattr(layer, "linear", layer)
    # Quantized layers report the dtype of their scales rather than of
    # ``weight`` -- presumably because the packed weight is not stored in the
    # compute dtype (TODO confirm against the MLX QuantizedLinear docs).
    source = base.scales if isinstance(base, nn.QuantizedLinear) else base.weight
    return source.dtype
|
||||||
|
|
||||||
|
|
||||||
class WanRMSNorm(nn.Module):
|
class WanRMSNorm(nn.Module):
|
||||||
"""RMS normalization with learnable scale."""
|
"""RMS normalization with learnable scale."""
|
||||||
|
|
||||||
@@ -73,8 +82,8 @@ class WanSelfAttention(nn.Module):
|
|||||||
b, s, _ = x.shape
|
b, s, _ = x.shape
|
||||||
n, d = self.num_heads, self.head_dim
|
n, d = self.num_heads, self.head_dim
|
||||||
|
|
||||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||||
w_dtype = self.q.weight.dtype
|
w_dtype = _linear_dtype(self.q)
|
||||||
x_w = x.astype(w_dtype)
|
x_w = x.astype(w_dtype)
|
||||||
|
|
||||||
q = self.q(x_w)
|
q = self.q(x_w)
|
||||||
@@ -154,8 +163,8 @@ class WanCrossAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
b = context.shape[0]
|
b = context.shape[0]
|
||||||
n, d = self.num_heads, self.head_dim
|
n, d = self.num_heads, self.head_dim
|
||||||
# Cast to weight dtype for efficient matmul
|
# Cast to compute dtype for efficient matmul
|
||||||
w_dtype = self.k.weight.dtype
|
w_dtype = _linear_dtype(self.k)
|
||||||
ctx = context.astype(w_dtype)
|
ctx = context.astype(w_dtype)
|
||||||
k = self.k(ctx)
|
k = self.k(ctx)
|
||||||
if self.norm_k is not None:
|
if self.norm_k is not None:
|
||||||
@@ -174,8 +183,8 @@ class WanCrossAttention(nn.Module):
|
|||||||
b = x.shape[0]
|
b = x.shape[0]
|
||||||
n, d = self.num_heads, self.head_dim
|
n, d = self.num_heads, self.head_dim
|
||||||
|
|
||||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||||
w_dtype = self.q.weight.dtype
|
w_dtype = _linear_dtype(self.q)
|
||||||
q = self.q(x.astype(w_dtype))
|
q = self.q(x.astype(w_dtype))
|
||||||
if self.norm_q is not None:
|
if self.norm_q is not None:
|
||||||
q = self.norm_q(q)
|
q = self.norm_q(q)
|
||||||
|
|||||||
@@ -6,14 +6,15 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
|
||||||
"""Load and initialize WanModel, with optional quantization support.
|
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: Path to model safetensors file
|
model_path: Path to model safetensors file
|
||||||
config: WanModelConfig
|
config: WanModelConfig
|
||||||
quantization: Optional dict with 'bits' and 'group_size' keys.
|
quantization: Optional dict with 'bits' and 'group_size' keys.
|
||||||
If provided, creates QuantizedLinear stubs before loading.
|
If provided, creates QuantizedLinear stubs before loading.
|
||||||
|
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.wan.model import WanModel
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
@@ -30,6 +31,27 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
weights = mx.load(str(model_path))
|
weights = mx.load(str(model_path))
|
||||||
|
|
||||||
|
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
|
||||||
|
if loras:
|
||||||
|
if quantization:
|
||||||
|
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
|
||||||
|
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
|
||||||
|
from mlx_video.convert_wan import _load_lora_configs
|
||||||
|
from mlx_video.lora import apply_loras_to_model
|
||||||
|
|
||||||
|
model.load_weights(list(weights.items()), strict=False)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
module_to_loras = _load_lora_configs(loras)
|
||||||
|
apply_loras_to_model(model, module_to_loras)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
# Weight merging: fold LoRA into bf16 weights before loading
|
||||||
|
from mlx_video.convert_wan import load_and_apply_loras
|
||||||
|
|
||||||
|
weights = load_and_apply_loras(dict(weights), loras)
|
||||||
|
|
||||||
model.load_weights(list(weights.items()), strict=False)
|
model.load_weights(list(weights.items()), strict=False)
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .attention import WanLayerNorm
|
from .attention import WanLayerNorm, _linear_dtype
|
||||||
from .config import WanModelConfig
|
from .config import WanModelConfig
|
||||||
from .rope import rope_params, rope_precompute_cos_sin
|
from .rope import rope_params, rope_precompute_cos_sin
|
||||||
from .transformer import WanAttentionBlock
|
from .transformer import WanAttentionBlock
|
||||||
@@ -54,7 +54,7 @@ class Head(nn.Module):
|
|||||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||||
x_norm = self.norm(x)
|
x_norm = self.norm(x)
|
||||||
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
||||||
return self.head(x_mod.astype(self.head.weight.dtype))
|
return self.head(x_mod.astype(_linear_dtype(self.head)))
|
||||||
|
|
||||||
|
|
||||||
class WanModel(nn.Module):
|
class WanModel(nn.Module):
|
||||||
@@ -79,7 +79,7 @@ class WanModel(nn.Module):
|
|||||||
|
|
||||||
# Text embedding MLP
|
# Text embedding MLP
|
||||||
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
|
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
|
||||||
self.text_embedding_act = nn.GELU(approx="precise")
|
self.text_embedding_act = nn.GELU(approx="tanh")
|
||||||
self.text_embedding_1 = nn.Linear(dim, dim)
|
self.text_embedding_1 = nn.Linear(dim, dim)
|
||||||
|
|
||||||
# Time embedding MLP
|
# Time embedding MLP
|
||||||
@@ -149,7 +149,7 @@ class WanModel(nn.Module):
|
|||||||
|
|
||||||
# Project and cast to model dtype to prevent float32 cascade from input latents
|
# Project and cast to model dtype to prevent float32 cascade from input latents
|
||||||
patches = self.patch_embedding_proj(x) # [L, dim]
|
patches = self.patch_embedding_proj(x) # [L, dim]
|
||||||
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
|
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
|
||||||
patches = patches[None, :, :] # [1, L, dim]
|
patches = patches[None, :, :] # [1, L, dim]
|
||||||
|
|
||||||
return patches, (f_out, h_out, w_out)
|
return patches, (f_out, h_out, w_out)
|
||||||
@@ -186,7 +186,7 @@ class WanModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Embedded context [B, text_len, dim] in model dtype
|
Embedded context [B, text_len, dim] in model dtype
|
||||||
"""
|
"""
|
||||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
model_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||||
context_padded = []
|
context_padded = []
|
||||||
for ctx in context:
|
for ctx in context:
|
||||||
pad_len = self.text_len - ctx.shape[0]
|
pad_len = self.text_len - ctx.shape[0]
|
||||||
@@ -231,7 +231,7 @@ class WanModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
(cos_f, sin_f) precomputed frequency tensors
|
(cos_f, sin_f) precomputed frequency tensors
|
||||||
"""
|
"""
|
||||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||||
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
|
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -348,7 +348,7 @@ class WanModel(nn.Module):
|
|||||||
|
|
||||||
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
||||||
attn_mask = None
|
attn_mask = None
|
||||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||||
if any(sl < seq_len for sl in seq_lens_list):
|
if any(sl < seq_len for sl in seq_lens_list):
|
||||||
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
||||||
for i, sl in enumerate(seq_lens_list):
|
for i, sl in enumerate(seq_lens_list):
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ class T5FeedForward(nn.Module):
|
|||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.dim_ffn = dim_ffn
|
self.dim_ffn = dim_ffn
|
||||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||||
self.gate_act = nn.GELU(approx="precise")
|
self.gate_act = nn.GELU(approx="tanh")
|
||||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
|
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
|
||||||
|
|
||||||
|
|
||||||
class WanAttentionBlock(nn.Module):
|
class WanAttentionBlock(nn.Module):
|
||||||
@@ -84,10 +84,10 @@ class WanFFN(nn.Module):
|
|||||||
def __init__(self, dim: int, ffn_dim: int):
|
def __init__(self, dim: int, ffn_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc1 = nn.Linear(dim, ffn_dim)
|
self.fc1 = nn.Linear(dim, ffn_dim)
|
||||||
self.act = nn.GELU(approx="precise")
|
self.act = nn.GELU(approx="tanh")
|
||||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||||
x_w = x.astype(self.fc1.weight.dtype)
|
x_w = x.astype(_linear_dtype(self.fc1))
|
||||||
return self.fc2(self.act(self.fc1(x_w)))
|
return self.fc2(self.act(self.fc1(x_w)))
|
||||||
|
|||||||
@@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
|
|||||||
conversion (channels-first → channels-last) is needed.
|
conversion (channels-first → channels-last) is needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CACHE_T = 2
|
CACHE_T = 2
|
||||||
|
|
||||||
# Per-channel normalization for z_dim=48 latent space
|
# Per-channel normalization for z_dim=48 latent space
|
||||||
@@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
|
|||||||
Maps PyTorch nn.Sequential indices to our named layers.
|
Maps PyTorch nn.Sequential indices to our named layers.
|
||||||
"""
|
"""
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
consumed = set()
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
# Skip encoder and conv1 unless requested
|
# Skip encoder and conv1 unless requested
|
||||||
if not include_encoder:
|
if not include_encoder:
|
||||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||||
|
consumed.add(key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_key = key
|
new_key = key
|
||||||
@@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
|
|||||||
value = mx.array(np.array(value).squeeze())
|
value = mx.array(np.array(value).squeeze())
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
consumed.add(key)
|
||||||
|
|
||||||
|
unconsumed = set(weights.keys()) - consumed
|
||||||
|
if unconsumed:
|
||||||
|
logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed))
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Tests for Wan weight conversion utilities."""
|
"""Tests for Wan weight conversion utilities."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -94,6 +96,27 @@ class TestSanitizeTransformerWeights:
|
|||||||
for key in weights:
|
for key in weights:
|
||||||
assert key in out
|
assert key in out
|
||||||
|
|
||||||
|
def test_no_unconsumed_keys(self, caplog):
    """Sanitizing a representative transformer checkpoint logs no "Unconsumed" warning."""
    from mlx_video.convert_wan import sanitize_wan_transformer_weights

    # Zero-filled entries, listed in the order they are inserted into the dict.
    zero_shapes = {
        "text_embedding.0.weight": (64, 32),
        "text_embedding.2.weight": (64, 64),
        "time_embedding.0.weight": (64, 32),
        "time_embedding.2.weight": (64, 64),
        "time_projection.1.weight": (384, 64),
        "blocks.0.ffn.0.weight": (128, 64),
        "blocks.0.ffn.2.weight": (64, 128),
        "blocks.0.self_attn.q.weight": (64, 64),
        "blocks.0.modulation": (1, 6, 64),
        "head.head.weight": (64, 64),
        "freqs": (1024, 64, 2),
    }
    checkpoint = {
        "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
        "patch_embedding.bias": mx.random.normal((5120,)),
    }
    for name, shape in zero_shapes.items():
        checkpoint[name] = mx.zeros(shape)

    with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
        sanitize_wan_transformer_weights(checkpoint)

    assert "Unconsumed" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
class TestSanitizeT5Weights:
|
class TestSanitizeT5Weights:
|
||||||
def test_gate_rename(self):
|
def test_gate_rename(self):
|
||||||
@@ -119,6 +142,19 @@ class TestSanitizeT5Weights:
|
|||||||
for key in weights:
|
for key in weights:
|
||||||
assert key in out
|
assert key in out
|
||||||
|
|
||||||
|
def test_no_unconsumed_keys(self, caplog):
    """Sanitizing a representative T5 checkpoint logs no "Unconsumed" warning."""
    from mlx_video.convert_wan import sanitize_wan_t5_weights

    shapes = {
        "token_embedding.weight": (100, 64),
        "blocks.0.ffn.gate.0.weight": (128, 64),
        "blocks.0.ffn.fc1.weight": (128, 64),
        "blocks.0.ffn.fc2.weight": (64, 128),
        "norm.weight": (64,),
    }
    checkpoint = {name: mx.zeros(shape) for name, shape in shapes.items()}

    with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
        sanitize_wan_t5_weights(checkpoint)

    assert "Unconsumed" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
class TestSanitizeVAEWeights:
|
class TestSanitizeVAEWeights:
|
||||||
def test_conv3d_transpose(self):
|
def test_conv3d_transpose(self):
|
||||||
@@ -161,6 +197,18 @@ class TestSanitizeVAEWeights:
|
|||||||
assert out["linear.weight"].shape == (8, 4)
|
assert out["linear.weight"].shape == (8, 4)
|
||||||
assert out["norm.weight"].shape == (8,)
|
assert out["norm.weight"].shape == (8,)
|
||||||
|
|
||||||
|
def test_no_unconsumed_keys(self, caplog):
    """Sanitizing a representative VAE checkpoint logs no "Unconsumed" warning."""
    from mlx_video.convert_wan import sanitize_wan_vae_weights

    shapes = {
        "decoder.conv1.weight": (8, 4, 3, 3, 3),
        "decoder.proj.weight": (16, 8, 3, 3),
        "decoder.norm.weight": (64,),
        "decoder.bias": (16,),
    }
    checkpoint = {name: mx.zeros(shape) for name, shape in shapes.items()}

    with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
        sanitize_wan_vae_weights(checkpoint)

    assert "Unconsumed" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Wan2.1 Conversion Tests
|
# Wan2.1 Conversion Tests
|
||||||
@@ -233,3 +281,27 @@ class TestSanitizeEncoderWeights:
|
|||||||
assert "encoder.conv1.weight" in out
|
assert "encoder.conv1.weight" in out
|
||||||
assert "conv1.weight" in out
|
assert "conv1.weight" in out
|
||||||
assert "conv2.weight" in out
|
assert "conv2.weight" in out
|
||||||
|
|
||||||
|
def test_no_unconsumed_keys(self, caplog):
    """With the encoder included, the Wan2.2 VAE sanitizer logs no "Unconsumed" warning."""
    from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

    shapes = {
        "encoder.conv1.weight": (8, 1, 3, 3, 3),
        "conv1.weight": (8, 1, 1, 1, 8),
        "conv2.weight": (8, 1, 1, 1, 8),
    }
    checkpoint = {name: mx.zeros(shape) for name, shape in shapes.items()}

    with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
        sanitize_wan22_vae_weights(checkpoint, include_encoder=True)

    assert "Unconsumed" not in caplog.text
|
||||||
|
|
||||||
|
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
    """With the encoder excluded, the Wan2.2 VAE sanitizer logs no "Unconsumed" warning."""
    from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

    shapes = {
        "encoder.conv1.weight": (8, 1, 3, 3, 3),
        "conv1.weight": (8, 1, 1, 1, 8),
        "conv2.weight": (8, 1, 1, 1, 8),
    }
    checkpoint = {name: mx.zeros(shape) for name, shape in shapes.items()}

    with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
        sanitize_wan22_vae_weights(checkpoint, include_encoder=False)

    assert "Unconsumed" not in caplog.text
|
||||||
|
|||||||
@@ -291,3 +291,282 @@ class TestI2VMaskConstruction:
|
|||||||
encoded = mx.zeros((16, 5, 10, 18))
|
encoded = mx.zeros((16, 5, 10, 18))
|
||||||
y = mx.concatenate([mask, encoded], axis=0)
|
y = mx.concatenate([mask, encoded], axis=0)
|
||||||
assert y.shape == (20, 5, 10, 18)
|
assert y.shape == (20, 5, 10, 18)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration: I2V end-to-end pipeline
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestI2VEndToEndPipeline:
    """Full I2V pipeline: image -> VAE encode -> y tensor -> denoise -> VAE decode."""

    def test_full_i2v_pipeline(self):
        """Run a synthetic image through encode, two denoise steps and decode."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
        from mlx_video.models.wan.vae import WanVAE

        mx.random.seed(0)

        # Tiny I2V model; z_dim=16 so it matches the VAE normalization stats.
        cfg = _make_tiny_i2v_config()
        cfg.vae_z_dim = 16
        cfg.out_dim = 16  # must match the VAE z_dim for decode
        cfg.in_dim = 16 + 4 + 16  # noise(out_dim) + mask(4) + image(z_dim) = 36
        dit = WanModel(cfg)

        # VAE built with its encoder enabled.
        vae = WanVAE(z_dim=cfg.vae_z_dim, encoder=True)

        # Synthetic conditioning image [1, 3, 1, H, W] with values in [-1, 1].
        height, width = 32, 32
        num_frames = 5  # small temporal extent
        first_frame = mx.random.uniform(-1, 1, (1, 3, 1, height, width))

        # Video = the image followed by zero frames -> [1, 3, F, H, W].
        blank_frames = mx.zeros((1, 3, num_frames - 1, height, width))
        video = mx.concatenate([first_frame, blank_frames], axis=2)

        # Encode the video into latent space.
        encoded = vae.encode(video)
        mx.eval(encoded)
        assert encoded.ndim == 5
        assert encoded.shape[1] == cfg.vae_z_dim

        encoded = encoded[0]  # drop the batch axis
        t_latent = encoded.shape[1]
        h_latent = encoded.shape[2]
        w_latent = encoded.shape[3]

        # Four-channel I2V mask: ones for the first frame, zeros afterwards;
        # the first frame is repeated 4x before frames are folded in groups of 4.
        mask = mx.ones((1, num_frames, h_latent, w_latent))
        trailing_zeros = mx.zeros((1, num_frames - 1, h_latent, w_latent))
        mask = mx.concatenate([mask[:, :1], trailing_zeros], axis=1)
        repeated_first = mx.repeat(mask[:, :1], 4, axis=1)
        mask = mx.concatenate([repeated_first, mask[:, 1:]], axis=1)
        mask = mask.reshape(1, mask.shape[1] // 4, 4, h_latent, w_latent)
        mask = mask.transpose(0, 2, 1, 3, 4)[0]

        # Conditioning tensor y: mask channels stacked on the encoded channels.
        y_cond = mx.concatenate([mask, encoded], axis=0)
        mx.eval(y_cond)
        assert y_cond.shape[0] == 4 + cfg.vae_z_dim

        # Two-step denoising loop.
        noise_channels = cfg.out_dim
        pt, ph, pw = cfg.patch_size
        seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw)

        scheduler = FlowMatchEulerScheduler()
        num_steps = 2
        scheduler.set_timesteps(num_steps, shift=cfg.sample_shift)

        latents = mx.random.normal((noise_channels, t_latent, h_latent, w_latent))
        context = mx.random.normal((4, cfg.text_dim))

        for step in range(num_steps):
            t_val = scheduler.timesteps[step].item()
            outputs = dit(
                [latents],
                mx.array([t_val]),
                [context],
                seq_len,
                y=[y_cond],
            )
            pred = outputs[0]
            stepped = scheduler.step(pred[None], t_val, latents[None])
            latents = stepped.squeeze(0)
            mx.eval(latents)

        assert latents.shape == (noise_channels, t_latent, h_latent, w_latent)
        assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents"
        assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents"

        # Decode the denoised latents back to pixels.
        decoded = vae.decode(latents[None])
        mx.eval(decoded)
        assert decoded.ndim == 5
        assert decoded.shape[0] == 1
        assert decoded.shape[1] == 3  # RGB output
        assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video"
        assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video"
        # The decoded video is expected to stay within [-1, 1].
        assert float(decoded.max()) <= 1.0
        assert float(decoded.min()) >= -1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDualModelSwitching:
    """Test dual-model selection logic: high_noise vs low_noise based on boundary."""

    def test_model_selection_by_timestep(self):
        """Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        mx.random.seed(1)
        config = _make_tiny_i2v_config()
        assert config.dual_model is True

        high_noise_model = WanModel(config)
        low_noise_model = WanModel(config)

        # Timestep threshold at which the pipeline switches models.
        boundary = config.boundary * config.num_train_timesteps

        C_noise = config.out_dim  # noise channels
        C_y = config.in_dim - config.out_dim  # conditioning (y) channels
        F, H, W = 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        sched = FlowMatchEulerScheduler()
        num_steps = 5
        sched.set_timesteps(num_steps, shift=config.sample_shift)

        guide_scale = config.sample_guide_scale  # (low_gs, high_gs)
        assert isinstance(guide_scale, tuple) and len(guide_scale) == 2

        latents = mx.random.normal((C_noise, F, H, W))
        y_i2v = mx.random.normal((C_y, F, H, W))
        context = mx.random.normal((4, config.text_dim))

        high_used_steps = []
        low_used_steps = []

        timestep_list = sched.timesteps.tolist()
        for i in range(num_steps):
            timestep_val = timestep_list[i]

            if timestep_val >= boundary:
                model = high_noise_model
                gs = guide_scale[1]
                high_used_steps.append(i)
            else:
                model = low_noise_model
                gs = guide_scale[0]
                low_used_steps.append(i)

            # CFG pass: cond + uncond
            preds = model(
                [latents, latents],
                mx.array([timestep_val, timestep_val]),
                [context, context],
                seq_len,
                y=[y_i2v, y_i2v],
            )
            noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
            noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)

            latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
            mx.eval(latents)

        # Both models must have been exercised at least once.
        assert len(high_used_steps) > 0, "High-noise model was never selected"
        assert len(low_used_steps) > 0, "Low-noise model was never selected"
        # Timesteps decrease over the loop, so every high-noise step must come
        # before every low-noise step. (The previous form OR-ed in a second
        # comparison that made the check pass for almost any interleaving.)
        assert max(high_used_steps) < min(low_used_steps), \
            "High-noise steps should all precede low-noise steps"

        assert latents.shape == (C_noise, F, H, W)
        assert not mx.any(mx.isnan(latents)).item()

    def test_guide_scale_tuple_applied_per_model(self):
        """Verify (low_gs, high_gs) tuple applies different scales per model."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        mx.random.seed(2)
        config = _make_tiny_i2v_config()
        config.sample_guide_scale = (2.0, 5.0)  # distinct values

        model = WanModel(config)
        boundary = config.boundary * config.num_train_timesteps

        C_noise = config.out_dim
        F, H, W = 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        sched = FlowMatchEulerScheduler()
        sched.set_timesteps(5, shift=config.sample_shift)

        latents = mx.random.normal((C_noise, F, H, W))
        context = mx.random.normal((4, config.text_dim))
        guide_scale = config.sample_guide_scale
        C_y = config.in_dim - config.out_dim  # y channels
        y_i2v = mx.random.normal((C_y, F, H, W))

        # Track which guide scale was used at each step
        gs_per_step = []

        timestep_list = sched.timesteps.tolist()
        for i in range(5):
            timestep_val = timestep_list[i]

            if timestep_val >= boundary:
                gs = guide_scale[1]  # high_gs = 5.0
            else:
                gs = guide_scale[0]  # low_gs = 2.0
            gs_per_step.append(gs)

            pred = model(
                [latents, latents],
                mx.array([timestep_val, timestep_val]),
                [context, context],
                seq_len,
                y=[y_i2v, y_i2v],
            )
            noise_pred = pred[1] + gs * (pred[0] - pred[1])
            latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
            mx.eval(latents)

        # Verify both guide scales were used
        assert 5.0 in gs_per_step, "High guide scale (5.0) was never used"
        assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used"
        # Every high-gs step must precede every low-gs step: compare the LAST
        # high index with the FIRST low index (first-high vs last-low would
        # also accept interleaved orderings).
        last_high = len(gs_per_step) - 1 - gs_per_step[::-1].index(5.0)
        first_low = gs_per_step.index(2.0)
        assert last_high < first_low, "High gs steps should precede low gs steps"

    def test_single_model_fallback_with_tuple_guide_scale(self):
        """When dual_model=False, guide_scale tuple should use first element."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        mx.random.seed(3)
        config = _make_tiny_config()
        config.dual_model = False
        config.sample_guide_scale = (3.0, 5.0)

        model = WanModel(config)
        guide_scale = config.sample_guide_scale

        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        sched = FlowMatchEulerScheduler()
        sched.set_timesteps(3, shift=3.0)

        latents = mx.random.normal((C, F, H, W))
        context = mx.random.normal((4, config.text_dim))

        # Mimic generate_wan.py single-model logic: a scalar guide scale is
        # used as-is, a tuple contributes its first element.
        gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
        assert gs == 3.0, "Single model should use first element of guide_scale tuple"

        for i in range(3):
            t_val = sched.timesteps[i].item()
            pred = model(
                [latents, latents],
                mx.array([t_val, t_val]),
                [context, context],
                seq_len,
            )
            noise_pred = pred[1] + gs * (pred[0] - pred[1])
            latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0)
            mx.eval(latents)

        assert latents.shape == (C, F, H, W)
        assert not mx.any(mx.isnan(latents)).item()
|
||||||
|
|||||||
334
tests/test_wan_lora.py
Normal file
334
tests/test_wan_lora.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
"""Tests for LoRA loading and application."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoRATypes:
    """Unit tests for the LoRA dataclasses."""

    def test_lora_weights_scale(self):
        """``scale`` is ``alpha / rank`` (32 / 16 == 2)."""
        from mlx_video.lora.types import LoRAWeights

        factors = LoRAWeights(
            lora_A=mx.zeros((16, 64)),
            lora_B=mx.zeros((128, 16)),
            rank=16,
            alpha=32.0,
            module_name="test",
        )
        assert factors.scale == 2.0

    def test_lora_weights_scale_default(self):
        """``scale`` is 1 when ``alpha`` equals ``rank``."""
        from mlx_video.lora.types import LoRAWeights

        factors = LoRAWeights(
            lora_A=mx.zeros((16, 64)),
            lora_B=mx.zeros((128, 16)),
            rank=16,
            alpha=16.0,
            module_name="test",
        )
        assert factors.scale == 1.0

    def test_applied_lora_delta(self):
        """``compute_delta`` returns ``strength * scale * (B @ A)``."""
        from mlx_video.lora.types import AppliedLoRA, LoRAWeights

        down = mx.ones((2, 4))
        up = mx.ones((8, 2))
        factors = LoRAWeights(lora_A=down, lora_B=up, rank=2, alpha=2.0, module_name="test")
        bound = AppliedLoRA(weights=factors, strength=0.5)
        delta = bound.compute_delta()
        # scale = 1.0 and strength = 0.5; each entry of B @ A sums two ones,
        # so the product is 2 everywhere and the delta is 1 everywhere.
        expected = 0.5 * mx.ones((8, 4)) * 2.0
        assert mx.allclose(delta, expected).item()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoRALoader:
    """Checks ``load_lora_weights`` against small generated safetensors files."""

    def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
        """Write a mock LoRA safetensors file into ``tmp_dir`` and return its path.

        For every entry of ``module_names`` two random matrices are stored:
        one of shape ``(rank, in_dim)`` and one of shape ``(out_dim, rank)``.
        ``key_format="AB"`` names them ``lora_A`` / ``lora_B``; any other value
        names them ``lora_down`` / ``lora_up``.
        """
        if key_format == "AB":
            first_tag, second_tag = "lora_A", "lora_B"
        else:
            first_tag, second_tag = "lora_down", "lora_up"

        tensors = {}
        for module in module_names:
            tensors[f"{module}.{first_tag}.weight"] = mx.random.normal((rank, in_dim))
            tensors[f"{module}.{second_tag}.weight"] = mx.random.normal((out_dim, rank))

        target = Path(tmp_dir) / "test_lora.safetensors"
        mx.save_safetensors(str(target), tensors)
        return target

    def test_load_lora_a_b_format(self):
        """Keys in ``lora_A`` / ``lora_B`` form load with the expected metadata."""
        from mlx_video.lora.loader import load_lora_weights

        module = "blocks.0.self_attn.q"
        with tempfile.TemporaryDirectory() as workdir:
            lora_file = self._make_lora_file(workdir, [module], key_format="AB")
            loaded = load_lora_weights(lora_file)
            assert module in loaded

            entry = loaded[module]
            assert entry.rank == 4
            assert entry.alpha == 4.0  # default: alpha == rank
            assert entry.lora_A.shape == (4, 64)
            assert entry.lora_B.shape == (128, 4)

    def test_load_lora_down_up_format(self):
        """Keys in ``lora_down`` / ``lora_up`` form are recognised as well."""
        from mlx_video.lora.loader import load_lora_weights

        module = "blocks.0.self_attn.q"
        with tempfile.TemporaryDirectory() as workdir:
            lora_file = self._make_lora_file(workdir, [module], key_format="down_up")
            loaded = load_lora_weights(lora_file)
            assert module in loaded

    def test_load_multiple_modules(self):
        """Every module stored in the file shows up exactly once in the result."""
        from mlx_video.lora.loader import load_lora_weights

        modules = [
            "blocks.0.self_attn.q",
            "blocks.0.self_attn.k",
            "blocks.0.ffn.fc1",
        ]
        with tempfile.TemporaryDirectory() as workdir:
            lora_file = self._make_lora_file(workdir, modules)
            loaded = load_lora_weights(lora_file)
            assert len(loaded) == 3
            for module in modules:
                assert module in loaded

    def test_load_with_alpha(self):
        """An explicit ``<module>.alpha`` tensor overrides the default alpha."""
        from mlx_video.lora.loader import load_lora_weights

        with tempfile.TemporaryDirectory() as workdir:
            tensors = {
                "test.lora_A.weight": mx.random.normal((8, 64)),
                "test.lora_B.weight": mx.random.normal((128, 8)),
                "test.alpha": mx.array(16.0),
            }
            lora_file = Path(workdir) / "lora.safetensors"
            mx.save_safetensors(str(lora_file), tensors)

            loaded = load_lora_weights(lora_file)
            assert loaded["test"].alpha == 16.0
            assert loaded["test"].rank == 8
            # 16.0 / 8 -> 2.0
            assert loaded["test"].scale == 2.0

    def test_file_not_found(self):
        """A missing path must raise ``FileNotFoundError``."""
        from mlx_video.lora.loader import load_lora_weights

        missing = Path("/nonexistent/lora.safetensors")
        with pytest.raises(FileNotFoundError):
            load_lora_weights(missing)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWanKeyNormalization:
    """Checks ``_normalize_wan_lora_key`` against a simulated Wan2.2 key set."""

    def _wan_model_keys(self):
        """Return a set of weight keys shaped like those of a Wan2.2 MLX model.

        Two transformer blocks are simulated, each with self/cross attention
        projections (q, k, v, o) and two FFN layers, plus the embedding,
        projection and patch-embedding weights.
        """
        attn_layers = (
            "self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
            "cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o",
        )
        keys = set()
        for block_idx in range(2):
            keys.update(f"blocks.{block_idx}.{layer}.weight" for layer in attn_layers)
            keys.add(f"blocks.{block_idx}.ffn.fc1.weight")
            keys.add(f"blocks.{block_idx}.ffn.fc2.weight")
        keys.update(
            (
                "text_embedding_0.weight",
                "text_embedding_1.weight",
                "time_embedding_0.weight",
                "time_embedding_1.weight",
                "time_projection.weight",
                "patch_embedding_proj.weight",
            )
        )
        return keys

    def test_direct_match(self):
        """A key that already matches the model is returned unchanged."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        normalized = _normalize_wan_lora_key("blocks.0.self_attn.q", model_keys)
        assert normalized == "blocks.0.self_attn.q"

    def test_strip_diffusion_model_prefix(self):
        """A leading ``diffusion_model.`` is removed."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        normalized = _normalize_wan_lora_key(
            "diffusion_model.blocks.0.self_attn.q", model_keys
        )
        assert normalized == "blocks.0.self_attn.q"

    def test_strip_model_prefix(self):
        """A leading ``model.diffusion_model.`` is removed."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        normalized = _normalize_wan_lora_key(
            "model.diffusion_model.blocks.0.self_attn.k", model_keys
        )
        assert normalized == "blocks.0.self_attn.k"

    def test_ffn_key_mapping(self):
        """Indexed FFN layers ``ffn.0`` / ``ffn.2`` map to ``fc1`` / ``fc2``."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        cases = (
            ("blocks.0.ffn.0", "blocks.0.ffn.fc1"),
            ("blocks.0.ffn.2", "blocks.0.ffn.fc2"),
        )
        for raw, wanted in cases:
            assert _normalize_wan_lora_key(raw, model_keys) == wanted

    def test_text_embedding_mapping(self):
        """``text_embedding.0`` / ``.2`` map to ``text_embedding_0`` / ``_1``."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        cases = (
            ("text_embedding.0", "text_embedding_0"),
            ("text_embedding.2", "text_embedding_1"),
        )
        for raw, wanted in cases:
            assert _normalize_wan_lora_key(raw, model_keys) == wanted

    def test_time_embedding_mapping(self):
        """``time_embedding.0`` / ``.2`` map to ``time_embedding_0`` / ``_1``."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        cases = (
            ("time_embedding.0", "time_embedding_0"),
            ("time_embedding.2", "time_embedding_1"),
        )
        for raw, wanted in cases:
            assert _normalize_wan_lora_key(raw, model_keys) == wanted

    def test_time_projection_mapping(self):
        """``time_projection.1`` maps to ``time_projection``."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        normalized = _normalize_wan_lora_key("time_projection.1", model_keys)
        assert normalized == "time_projection"

    def test_patch_embedding_mapping(self):
        """``patch_embedding`` maps to ``patch_embedding_proj``."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        normalized = _normalize_wan_lora_key("patch_embedding", model_keys)
        assert normalized == "patch_embedding_proj"

    def test_combined_prefix_and_ffn(self):
        """Prefix stripping and FFN renaming apply to the same key."""
        from mlx_video.lora.apply import _normalize_wan_lora_key

        model_keys = self._wan_model_keys()
        normalized = _normalize_wan_lora_key(
            "diffusion_model.blocks.1.ffn.0", model_keys
        )
        assert normalized == "blocks.1.ffn.fc1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestApplyLoRA:
    """Test LoRA delta application to weights."""

    def test_preserves_bfloat16_dtype(self):
        """LoRA delta must not promote bfloat16 weights to float32."""
        from mlx_video.lora.apply import apply_lora_to_linear
        from mlx_video.lora.types import LoRAWeights

        original = mx.ones((8, 4), dtype=mx.bfloat16)
        # LoRA weights in float32 (typical when loaded from safetensors)
        lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
        lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
        w = LoRAWeights(
            lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
        )
        result = apply_lora_to_linear(original, [(w, 1.0)])
        assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"

    def test_preserves_float16_dtype(self):
        """LoRA delta must not promote float16 weights to float32."""
        from mlx_video.lora.apply import apply_lora_to_linear
        from mlx_video.lora.types import LoRAWeights

        original = mx.ones((8, 4), dtype=mx.float16)
        lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
        lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
        w = LoRAWeights(
            lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
        )
        result = apply_lora_to_linear(original, [(w, 1.0)])
        assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"

    def test_apply_single_lora(self):
        """A single LoRA at strength 1.0 adds exactly its B @ A product."""
        from mlx_video.lora.apply import apply_lora_to_linear
        from mlx_video.lora.types import LoRAWeights

        original = mx.ones((8, 4))
        lora_a = mx.ones((2, 4)) * 0.1
        lora_b = mx.ones((8, 2)) * 0.1
        w = LoRAWeights(
            lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
        )
        result = apply_lora_to_linear(original, [(w, 1.0)])
        # delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
        expected = original + 0.02 * mx.ones((8, 4))
        assert mx.allclose(result, expected, atol=1e-6).item()

    def test_apply_multiple_loras(self):
        """Deltas of several LoRAs accumulate, each with its own scale and strength."""
        from mlx_video.lora.apply import apply_lora_to_linear
        from mlx_video.lora.types import LoRAWeights

        original = mx.zeros((8, 4))
        w1 = LoRAWeights(
            lora_A=mx.ones((2, 4)),
            lora_B=mx.ones((8, 2)),
            rank=2,
            alpha=2.0,
            module_name="a",
        )
        w2 = LoRAWeights(
            lora_A=mx.ones((2, 4)) * 2,
            lora_B=mx.ones((8, 2)) * 2,
            rank=2,
            alpha=4.0,
            module_name="b",
        )
        result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
        # w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
        # w2 delta: 2.0 * 0.5 * (2*ones(8,2) @ 2*ones(2,4)) = 1.0 * 8*ones(8,4) = 8
        delta1 = mx.ones((8, 4)) * 2.0
        delta2 = mx.ones((8, 4)) * 8.0
        expected = delta1 + delta2
        assert mx.allclose(result, expected, atol=1e-5).item()

    def test_apply_loras_to_weights_dict(self):
        """Only the targeted module's weight changes; all others pass through."""
        from mlx_video.lora.apply import apply_loras_to_weights
        from mlx_video.lora.types import LoRAWeights

        model_weights = {
            "blocks.0.self_attn.q.weight": mx.ones((128, 64)),
            "blocks.0.self_attn.k.weight": mx.ones((128, 64)),
            "blocks.0.ffn.fc1.weight": mx.ones((256, 64)),
        }
        w = LoRAWeights(
            lora_A=mx.ones((4, 64)) * 0.01,
            lora_B=mx.ones((128, 4)) * 0.01,
            rank=4,
            alpha=4.0,
            module_name="blocks.0.self_attn.q",
        )
        module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
        result = apply_loras_to_weights(model_weights, module_to_loras)

        # Only q should be modified
        assert not mx.array_equal(
            result["blocks.0.self_attn.q.weight"],
            model_weights["blocks.0.self_attn.q.weight"],
        ).item()
        # Every untargeted entry must come back unchanged. The ffn.fc1 weight
        # is checked too so the "only q" claim covers the whole dict.
        for untouched in ("blocks.0.self_attn.k.weight", "blocks.0.ffn.fc1.weight"):
            assert mx.array_equal(
                result[untouched],
                model_weights[untouched],
            ).item()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEnd:
    """Full path: LoRA file on disk -> ``load_and_apply_loras`` -> merged weights."""

    def test_load_and_apply_loras(self):
        """A LoRA targeting ``q`` changes the ``q`` weight and leaves ``k`` alone."""
        from mlx_video.convert_wan import load_and_apply_loras

        q_key = "blocks.0.self_attn.q.weight"
        k_key = "blocks.0.self_attn.k.weight"

        with tempfile.TemporaryDirectory() as workdir:
            # Mock LoRA file: rank-4 factors for the q projection only.
            rank = 4
            lora_tensors = {
                "blocks.0.self_attn.q.lora_A.weight": mx.random.normal((rank, 64)),
                "blocks.0.self_attn.q.lora_B.weight": mx.random.normal((128, rank)),
            }
            lora_file = Path(workdir) / "test.safetensors"
            mx.save_safetensors(str(lora_file), lora_tensors)

            # Mock model weights: all-ones q and k projections.
            base_weights = {
                q_key: mx.ones((128, 64)),
                k_key: mx.ones((128, 64)),
            }

            merged = load_and_apply_loras(base_weights, [(str(lora_file), 1.0)])

            # q is the LoRA target, so it must differ from the base weight ...
            q_same = mx.array_equal(merged[q_key], base_weights[q_key]).item()
            assert not q_same
            # ... while k has no LoRA entry and must be unchanged.
            k_same = mx.array_equal(merged[k_key], base_weights[k_key]).item()
            assert k_same
|
||||||
@@ -868,4 +868,84 @@ class TestVAEEncoderTemporalOrder:
|
|||||||
assert out_wrong.shape[1] == 2
|
assert out_wrong.shape[1] == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# VAE Encode → Decode Round-Trip Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAE21RoundTrip:
    """Encode -> decode round trip through the Wan 2.1 VAE (channels-first)."""

    def test_encode_decode_shape_and_values(self):
        """Encoder3d -> Decoder3d restores the input shape with finite values."""
        from mlx_video.models.wan.vae import Decoder3d, Encoder3d

        latent_channels = 4
        base_dim = 8
        # Temporal resampling is switched off on both sides to keep the test simple.
        encoder = Encoder3d(
            dim=base_dim,
            z_dim=latent_channels,
            temporal_downsample=[False, False, False],
        )
        decoder = Decoder3d(
            dim=base_dim,
            z_dim=latent_channels,
            temporal_upsample=[False, False, False],
        )
        mx.eval(encoder.parameters(), decoder.parameters())

        # Input layout: [B=1, C=3, T=1, H=8, W=8]
        video = mx.random.normal((1, 3, 1, 8, 8)) * 0.5

        latent = encoder(video)
        mx.eval(latent)
        # Three spatial downsamples (/8) reduce H and W from 8 to 1.
        assert latent.shape == (1, latent_channels, 1, 1, 1)

        recon = decoder(latent)
        mx.eval(recon)
        # Three spatial upsamples (x8) should bring back the input shape.
        assert recon.shape == video.shape

        recon_np = np.array(recon)
        assert np.all(np.isfinite(recon_np))
        assert np.abs(recon_np).max() < 1000
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAE22RoundTrip:
    """Encode -> decode round trip through the Wan 2.2 VAE (channels-last)."""

    def test_encode_decode_shape_and_values(self):
        """Wan22VAEEncoder -> Wan22VAEDecoder keeps shapes consistent and values in [-1, 1]."""
        from mlx_video.models.wan.vae22 import (
            Wan22VAEDecoder,
            Wan22VAEEncoder,
            denormalize_latents,
        )

        encoder = Wan22VAEEncoder(z_dim=48, dim=16)
        decoder = Wan22VAEDecoder(z_dim=48, dec_dim=8)
        mx.eval(encoder.parameters(), decoder.parameters())

        # Input layout: [B=1, T=1, H=32, W=32, C=3]
        frame = mx.random.normal((1, 1, 32, 32, 3)) * 0.5

        latent_norm = encoder(frame)
        mx.eval(latent_norm)
        # patchify (/2) followed by three spatial downsamples (/8) gives /16.
        assert latent_norm.shape == (1, 1, 2, 2, 48)

        decoded = decoder(denormalize_latents(latent_norm))
        mx.eval(decoded)

        # Three spatial upsamples (x8) plus unpatchify (x2) give x16.
        assert decoded.shape[0] == 1  # batch
        assert decoded.shape[2] == 32  # H recovered
        assert decoded.shape[3] == 32  # W recovered
        assert decoded.shape[-1] == 3  # RGB

        decoded_np = np.array(decoded)
        assert np.all(np.isfinite(decoded_np))
        # Output must stay within [-1, 1], allowing a tiny float tolerance.
        assert decoded_np.min() >= -1.0 - 1e-6
        assert decoded_np.max() <= 1.0 + 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user