feat(wan): Add LoRA with improved quantization pipeline

This commit is contained in:
Daniel
2026-02-28 14:11:13 +01:00
parent dbab95ec45
commit 849cc45d84
17 changed files with 1852 additions and 111 deletions

View File

@@ -0,0 +1,25 @@
"""LoRA support for mlx-video."""
from mlx_video.lora.apply import (
LoRALinear,
apply_lora_to_linear,
apply_loras_to_model,
apply_loras_to_weights,
)
from mlx_video.lora.loader import (
load_lora_weights,
load_multiple_loras,
)
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
__all__ = [
"LoRAConfig",
"LoRAWeights",
"AppliedLoRA",
"load_lora_weights",
"load_multiple_loras",
"apply_lora_to_linear",
"apply_loras_to_weights",
"apply_loras_to_model",
"LoRALinear",
]

393
mlx_video/lora/apply.py Normal file
View File

@@ -0,0 +1,393 @@
"""Apply LoRA weights to model layers."""
from typing import Dict, List, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.lora.types import LoRAWeights
def apply_lora_to_linear(
linear_weight: mx.array,
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
) -> mx.array:
"""Apply one or more LoRAs to a linear layer weight.
Args:
linear_weight: Original weight matrix [out_features, in_features]
lora_weights_and_strengths: List of (LoRAWeights, strength) tuples
Returns:
Modified weight with LoRA deltas applied (preserves original dtype)
"""
orig_dtype = linear_weight.dtype
modified_weight = linear_weight
for weights, strength in lora_weights_and_strengths:
scale = weights.scale
# Compute delta in float32 for precision, then cast back to avoid
# promoting model weights (e.g. bfloat16 → float32 causes ~1.5x slowdown)
delta = (weights.lora_B @ weights.lora_A) * (scale * strength)
modified_weight = modified_weight + delta.astype(orig_dtype)
return modified_weight
def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match Wan2.2 MLX model weight keys.
Handles:
- Stripping common prefixes (diffusion_model., model., etc.)
- FFN key mapping: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
- Embedding key mapping: text_embedding.0 → text_embedding_0, etc.
- Time projection: time_projection.1 → time_projection
- Patch embedding: patch_embedding → patch_embedding_proj
Args:
lora_key: Original LoRA module name
model_keys: Set of all model weight keys
Returns:
Normalized key that matches model weights
"""
# Try the key as-is first
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
# Common prefixes to strip
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"base_model.model.",
"model.",
]
candidates = [lora_key]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
candidates.append(lora_key[len(prefix):])
for candidate in candidates:
# Try as-is
if f"{candidate}.weight" in model_keys or candidate in model_keys:
return candidate
# Apply Wan2.2 key transformations
transformed = candidate
# FFN: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
if transformed.endswith(".ffn.0"):
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
if transformed.endswith(".ffn.2"):
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
# Text embedding: text_embedding.0 → text_embedding_0
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
if transformed.endswith("text_embedding.0"):
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
if transformed.endswith("text_embedding.2"):
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
# Time embedding: time_embedding.0 → time_embedding_0
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
if transformed.endswith("time_embedding.0"):
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
if transformed.endswith("time_embedding.2"):
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
# Time projection: time_projection.1 → time_projection
transformed = transformed.replace("time_projection.1.", "time_projection.")
if transformed.endswith("time_projection.1"):
transformed = transformed[:-len("time_projection.1")] + "time_projection"
# Patch embedding: patch_embedding → patch_embedding_proj
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Return best attempt with prefix stripped
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
# Also support LTX-style key normalization
def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match LTX MLX model weight keys."""
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"model.",
]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
normalized = lora_key[len(prefix):]
if f"{normalized}.weight" in model_keys or normalized in model_keys:
return normalized
transformed = normalized
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Try transformations on the original key
transformed = lora_key
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
def _normalize_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match model weight keys.
Auto-detects whether to use Wan2.2 or LTX key normalization based
on the presence of architecture-specific keys in the model.
"""
# Detect model architecture from keys
is_wan = any("self_attn.q.weight" in k for k in model_keys)
if is_wan:
return _normalize_wan_lora_key(lora_key, model_keys)
else:
return _normalize_ltx_lora_key(lora_key, model_keys)
def apply_loras_to_weights(
model_weights: Dict[str, mx.array],
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
verbose: bool = False,
quantization_bits: int = 0,
) -> Dict[str, mx.array]:
"""Apply LoRAs to model weights.
Args:
model_weights: Original model state dictionary
module_to_loras: Dictionary mapping module names to lists of
(LoRAWeights, strength) tuples
verbose: If True, print detailed debug information
quantization_bits: If >0, weights are quantized at this bit width.
Quantized layers are dequantized before LoRA application
and re-quantized after.
Returns:
New state dictionary with LoRA-modified weights
"""
modified_weights = dict(model_weights)
model_keys = set(model_weights.keys())
applied_count = 0
skipped_count = 0
skipped_modules = []
for module_name, loras in module_to_loras.items():
normalized_name = _normalize_lora_key(module_name, model_keys)
weight_key = f"{normalized_name}.weight"
if weight_key not in modified_weights:
if normalized_name not in modified_weights:
skipped_count += 1
skipped_modules.append(module_name)
if verbose and skipped_count <= 5:
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
similar = [
k
for k in list(model_keys)[:1000]
if normalized_name.split(".")[-1] in k
][:3]
if similar:
print(f" Similar keys: {similar}")
continue
weight_key = normalized_name
original_weight = modified_weights[weight_key]
# Handle quantized weights: dequantize → apply delta → re-quantize
scales_key = f"{normalized_name}.scales"
biases_key = f"{normalized_name}.biases"
is_quantized = (
original_weight.dtype == mx.uint32
and scales_key in modified_weights
and biases_key in modified_weights
)
if is_quantized:
scales = modified_weights[scales_key]
biases = modified_weights[biases_key]
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits)
dequantized = mx.dequantize(
original_weight, scales, biases, group_size=group_size, bits=quantization_bits
)
modified = apply_lora_to_linear(dequantized, loras)
# Re-quantize with same parameters
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits)
modified_weights[weight_key] = new_w
modified_weights[scales_key] = new_scales
modified_weights[biases_key] = new_biases
else:
modified_weights[weight_key] = apply_lora_to_linear(original_weight, loras)
applied_count += 1
if applied_count > 0:
print(f" ✓ Applied to {applied_count} modules")
if skipped_count > 0:
print(f" ⚠ Skipped {skipped_count} incompatible modules")
return modified_weights
class LoRALinear(nn.Module):
"""Linear layer with on-the-fly LoRA application.
Wraps nn.Linear or nn.QuantizedLinear, computing LoRA delta at runtime:
output = base_linear(x) + (x @ lora_A.T @ lora_B.T) * scale * strength
"""
def __init__(
self,
linear: nn.Module,
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
):
super().__init__()
self.linear = linear
self.lora_weights_and_strengths = lora_weights_and_strengths
def __call__(self, x: mx.array) -> mx.array:
output = self.linear(x)
for weights, strength in self.lora_weights_and_strengths:
scale = weights.scale
lora_out = x @ weights.lora_A.T @ weights.lora_B.T
output = output + (scale * strength * lora_out)
return output
def apply_loras_to_model(
model: nn.Module,
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
verbose: bool = False,
) -> int:
"""Apply LoRAs to a model by merging into weights.
For QuantizedLinear layers: dequantizes to bf16, merges LoRA delta, and
replaces with a regular nn.Linear (no per-step overhead, no re-quantization
precision loss). Non-LoRA layers stay quantized.
For nn.Linear layers: merges LoRA delta directly into the weight.
Args:
model: The model to apply LoRAs to
module_to_loras: Dictionary mapping module names to (LoRAWeights, strength) lists
verbose: Print debug info
Returns:
Number of modules modified
"""
# Build a set of model module paths for key normalization
module_paths = set()
for name, _ in model.named_modules():
module_paths.add(name)
module_paths.add(f"{name}.weight")
# Map LoRA keys → model module paths
lora_to_module = {}
for lora_key in module_to_loras:
normalized = _normalize_lora_key(lora_key, module_paths)
if normalized.endswith(".weight"):
normalized = normalized[: -len(".weight")]
lora_to_module[lora_key] = normalized
applied_count = 0
dequant_count = 0
skipped = []
for lora_key, loras in module_to_loras.items():
module_path = lora_to_module[lora_key]
parts = module_path.split(".")
# Traverse to the parent module
parent = model
try:
for part in parts[:-1]:
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)]
leaf_name = parts[-1]
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)]
except (AttributeError, IndexError, TypeError):
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{lora_key}' -> '{module_path}' -> module not found")
continue
if isinstance(target, nn.QuantizedLinear):
# Dequantize → merge LoRA → replace with bf16 Linear
weight = mx.dequantize(
target.weight, target.scales, target.biases,
group_size=target.group_size, bits=target.bits,
)
merged = apply_lora_to_linear(weight, loras)
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
new_linear.weight = merged
if "bias" in target:
new_linear.bias = target.bias
if leaf_name.isdigit():
parent[int(leaf_name)] = new_linear
else:
setattr(parent, leaf_name, new_linear)
dequant_count += 1
applied_count += 1
elif isinstance(target, nn.Linear):
# Merge directly into weight
target.weight = apply_lora_to_linear(target.weight, loras)
applied_count += 1
else:
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
continue
if applied_count > 0:
msg = f" ✓ Applied to {applied_count} modules"
if dequant_count > 0:
msg += f" ({dequant_count} dequantized to bf16)"
print(msg)
if skipped:
print(f" ⚠ Skipped {len(skipped)} incompatible modules")
return applied_count

122
mlx_video/lora/loader.py Normal file
View File

@@ -0,0 +1,122 @@
"""LoRA weight loading utilities."""
import re
from pathlib import Path
from typing import Dict, List, Optional
import mlx.core as mx
from mlx_video.lora.types import LoRAConfig, LoRAWeights
def load_lora_weights(lora_path: Path) -> Dict[str, LoRAWeights]:
"""Load LoRA weights from a safetensors file.
Supports both key conventions:
- {module_name}.lora_A.weight / {module_name}.lora_B.weight
- {module_name}.lora_down.weight / {module_name}.lora_up.weight
Args:
lora_path: Path to the LoRA safetensors file
Returns:
Dictionary mapping module names to LoRAWeights objects
Raises:
FileNotFoundError: If the LoRA file doesn't exist
ValueError: If the LoRA file format is invalid
"""
if not lora_path.exists():
raise FileNotFoundError(f"LoRA file not found: {lora_path}")
all_weights = mx.load(str(lora_path))
# Group weights by module name, handling both naming conventions
lora_weights = {}
module_names = set()
for key in all_weights.keys():
# Format 1: {module}.lora_A.weight / {module}.lora_B.weight
match = re.match(r"(.+)\.lora_([AB])\.weight$", key)
if match:
module_names.add(match.group(1))
continue
# Format 2: {module}.lora_down.weight / {module}.lora_up.weight
match = re.match(r"(.+)\.lora_(down|up)\.weight$", key)
if match:
module_names.add(match.group(1))
for module_name in module_names:
# Try both key conventions
key_a = f"{module_name}.lora_A.weight"
key_b = f"{module_name}.lora_B.weight"
if key_a not in all_weights or key_b not in all_weights:
key_a = f"{module_name}.lora_down.weight"
key_b = f"{module_name}.lora_up.weight"
if key_a not in all_weights or key_b not in all_weights:
continue
lora_a = all_weights[key_a]
lora_b = all_weights[key_b]
if lora_a.ndim != 2 or lora_b.ndim != 2:
raise ValueError(
f"Invalid LoRA shape for {module_name}: "
f"lora_A={lora_a.shape}, lora_B={lora_b.shape}"
)
rank = lora_a.shape[0]
if lora_b.shape[1] != rank:
raise ValueError(
f"LoRA rank mismatch for {module_name}: "
f"lora_A rank={rank}, lora_B rank={lora_b.shape[1]}"
)
# Check for per-module alpha stored as a scalar tensor
alpha_key = f"{module_name}.alpha"
if alpha_key in all_weights:
alpha = float(all_weights[alpha_key].item())
else:
alpha = float(rank)
lora_weights[module_name] = LoRAWeights(
lora_A=lora_a,
lora_B=lora_b,
rank=rank,
alpha=alpha,
module_name=module_name,
)
if not lora_weights:
raise ValueError(f"No valid LoRA weights found in {lora_path}")
return lora_weights
def load_multiple_loras(
configs: List[LoRAConfig],
) -> Dict[str, List[tuple]]:
"""Load multiple LoRA configurations.
Args:
configs: List of LoRAConfig objects
Returns:
Dictionary mapping module names to lists of (LoRAWeights, strength) tuples.
"""
module_to_loras: Dict[str, list] = {}
for config in configs:
lora_weights = load_lora_weights(config.path)
for module_name, weights in lora_weights.items():
if config.target_modules is not None:
if module_name not in config.target_modules:
continue
if module_name not in module_to_loras:
module_to_loras[module_name] = []
module_to_loras[module_name].append((weights, config.strength))
return module_to_loras

74
mlx_video/lora/types.py Normal file
View File

@@ -0,0 +1,74 @@
"""Data structures for LoRA support."""
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import mlx.core as mx
@dataclass
class LoRAWeights:
"""Container for LoRA weight matrices.
Attributes:
lora_A: Low-rank matrix A of shape [rank, in_features]
lora_B: Low-rank matrix B of shape [out_features, rank]
rank: Rank of the LoRA decomposition
alpha: LoRA scaling parameter (default: rank)
module_name: Target module name in the model
"""
lora_A: mx.array
lora_B: mx.array
rank: int
alpha: float
module_name: str
@property
def scale(self) -> float:
"""Compute the scale factor: alpha / rank."""
return self.alpha / self.rank
@dataclass
class LoRAConfig:
"""Configuration for a single LoRA.
Attributes:
path: Path to the LoRA safetensors file
strength: Strength/weight to apply this LoRA (typically 0.0-2.0)
target_modules: Optional list of module names to apply LoRA to.
If None, applies to all available modules in the LoRA.
"""
path: Path
strength: float = 1.0
target_modules: Optional[list[str]] = None
def __post_init__(self):
"""Validate and normalize the configuration."""
self.path = Path(self.path)
if not self.path.exists():
raise FileNotFoundError(f"LoRA file not found: {self.path}")
if self.strength < 0:
raise ValueError(f"LoRA strength must be non-negative, got {self.strength}")
@dataclass
class AppliedLoRA:
"""Represents a LoRA applied to a specific module.
Attributes:
weights: The LoRA weight matrices
strength: Application strength for this LoRA
"""
weights: LoRAWeights
strength: float
def compute_delta(self) -> mx.array:
"""Compute the weight delta: strength * scale * (lora_B @ lora_A)."""
scale = self.weights.scale
delta = self.weights.lora_B @ self.weights.lora_A
return scale * self.strength * delta