feat(wan): Add LoRA with improved quantization pipeline

This commit is contained in:
Daniel
2026-02-28 14:11:13 +01:00
parent dbab95ec45
commit 849cc45d84
17 changed files with 1852 additions and 111 deletions

View File

@@ -4,6 +4,15 @@ import mlx.nn as nn
from .rope import rope_apply
def _linear_dtype(layer) -> mx.Dtype:
    """Return the dtype that inputs to ``layer`` should be cast to.

    Args:
        layer: A linear-like module, or a wrapper (e.g. a LoRA wrapper) that
            exposes the wrapped module through a ``linear`` attribute.

    Returns:
        The dtype of ``scales`` when the (unwrapped) layer is an
        ``nn.QuantizedLinear``, otherwise the dtype of its ``weight``.
    """
    # A wrapper is recognised only by its ``linear`` attribute; anything
    # without one is treated as the linear layer itself.
    base = getattr(layer, "linear", layer)
    # NOTE(review): quantized layers presumably store ``weight`` in a packed
    # integer format, which is why ``scales`` is consulted instead -- confirm
    # against the MLX QuantizedLinear documentation.
    source = base.scales if isinstance(base, nn.QuantizedLinear) else base.weight
    return source.dtype
class WanRMSNorm(nn.Module):
"""RMS normalization with learnable scale."""
@@ -73,8 +82,8 @@ class WanSelfAttention(nn.Module):
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
x_w = x.astype(w_dtype)
q = self.q(x_w)
@@ -154,8 +163,8 @@ class WanCrossAttention(nn.Module):
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul
w_dtype = self.k.weight.dtype
# Cast to compute dtype for efficient matmul
w_dtype = _linear_dtype(self.k)
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
@@ -174,8 +183,8 @@ class WanCrossAttention(nn.Module):
b = x.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
q = self.q(x.astype(w_dtype))
if self.norm_q is not None:
q = self.norm_q(q)

View File

@@ -6,14 +6,15 @@ import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
"""Load and initialize WanModel, with optional quantization support.
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
"""Load and initialize WanModel, with optional quantization and LoRA support.
Args:
model_path: Path to model safetensors file
config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading.
loras: Optional list of (lora_path, strength) tuples to apply.
"""
from mlx_video.models.wan.model import WanModel
@@ -30,6 +31,27 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
)
weights = mx.load(str(model_path))
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
if loras:
if quantization:
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
from mlx_video.convert_wan import _load_lora_configs
from mlx_video.lora import apply_loras_to_model
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
module_to_loras = _load_lora_configs(loras)
apply_loras_to_model(model, module_to_loras)
mx.eval(model.parameters())
return model
else:
# Weight merging: fold LoRA into bf16 weights before loading
from mlx_video.convert_wan import load_and_apply_loras
weights = load_and_apply_loras(dict(weights), loras)
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
return model

View File

@@ -4,7 +4,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .attention import WanLayerNorm
from .attention import WanLayerNorm, _linear_dtype
from .config import WanModelConfig
from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock
@@ -54,7 +54,7 @@ class Head(nn.Module):
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
return self.head(x_mod.astype(self.head.weight.dtype))
return self.head(x_mod.astype(_linear_dtype(self.head)))
class WanModel(nn.Module):
@@ -79,7 +79,7 @@ class WanModel(nn.Module):
# Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="precise")
self.text_embedding_act = nn.GELU(approx="tanh")
self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP
@@ -149,7 +149,7 @@ class WanModel(nn.Module):
# Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out)
@@ -186,7 +186,7 @@ class WanModel(nn.Module):
Returns:
Embedded context [B, text_len, dim] in model dtype
"""
model_dtype = self.patch_embedding_proj.weight.dtype
model_dtype = _linear_dtype(self.patch_embedding_proj)
context_padded = []
for ctx in context:
pad_len = self.text_len - ctx.shape[0]
@@ -231,7 +231,7 @@ class WanModel(nn.Module):
Returns:
(cos_f, sin_f) precomputed frequency tensors
"""
w_dtype = self.patch_embedding_proj.weight.dtype
w_dtype = _linear_dtype(self.patch_embedding_proj)
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__(
@@ -348,7 +348,7 @@ class WanModel(nn.Module):
# Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None
w_dtype = self.patch_embedding_proj.weight.dtype
w_dtype = _linear_dtype(self.patch_embedding_proj)
if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list):

View File

@@ -146,7 +146,7 @@ class T5FeedForward(nn.Module):
self.dim = dim
self.dim_ffn = dim_ffn
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = nn.GELU(approx="precise")
self.gate_act = nn.GELU(approx="tanh")
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)

View File

@@ -1,7 +1,7 @@
import mlx.core as mx
import mlx.nn as nn
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
class WanAttentionBlock(nn.Module):
@@ -84,10 +84,10 @@ class WanFFN(nn.Module):
def __init__(self, dim: int, ffn_dim: int):
super().__init__()
self.fc1 = nn.Linear(dim, ffn_dim)
self.act = nn.GELU(approx="precise")
self.act = nn.GELU(approx="tanh")
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(self.fc1.weight.dtype)
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(_linear_dtype(self.fc1))
return self.fc2(self.act(self.fc1(x_w)))

View File

@@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""
import logging
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
logger = logging.getLogger(__name__)
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
@@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
Maps PyTorch nn.Sequential indices to our named layers.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
# Skip encoder and conv1 unless requested
if not include_encoder:
if key.startswith("encoder.") or key.startswith("conv1."):
consumed.add(key)
continue
new_key = key
@@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
value = mx.array(np.array(value).squeeze())
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed))
return sanitized