feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -4,6 +4,15 @@ import mlx.nn as nn
|
||||
from .rope import rope_apply
|
||||
|
||||
|
||||
def _linear_dtype(layer) -> mx.Dtype:
|
||||
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
|
||||
# Unwrap LoRA wrapper to get the underlying linear layer
|
||||
inner = getattr(layer, "linear", layer)
|
||||
if isinstance(inner, nn.QuantizedLinear):
|
||||
return inner.scales.dtype
|
||||
return inner.weight.dtype
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
"""RMS normalization with learnable scale."""
|
||||
|
||||
@@ -73,8 +82,8 @@ class WanSelfAttention(nn.Module):
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
x_w = x.astype(w_dtype)
|
||||
|
||||
q = self.q(x_w)
|
||||
@@ -154,8 +163,8 @@ class WanCrossAttention(nn.Module):
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
# Cast to weight dtype for efficient matmul
|
||||
w_dtype = self.k.weight.dtype
|
||||
# Cast to compute dtype for efficient matmul
|
||||
w_dtype = _linear_dtype(self.k)
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
@@ -174,8 +183,8 @@ class WanCrossAttention(nn.Module):
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
q = self.q(x.astype(w_dtype))
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
|
||||
@@ -6,14 +6,15 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||
"""Load and initialize WanModel, with optional quantization support.
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
|
||||
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||
|
||||
Args:
|
||||
model_path: Path to model safetensors file
|
||||
config: WanModelConfig
|
||||
quantization: Optional dict with 'bits' and 'group_size' keys.
|
||||
If provided, creates QuantizedLinear stubs before loading.
|
||||
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||
"""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
@@ -30,6 +31,27 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||
)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
|
||||
if loras:
|
||||
if quantization:
|
||||
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
|
||||
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
|
||||
from mlx_video.convert_wan import _load_lora_configs
|
||||
from mlx_video.lora import apply_loras_to_model
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
module_to_loras = _load_lora_configs(loras)
|
||||
apply_loras_to_model(model, module_to_loras)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
else:
|
||||
# Weight merging: fold LoRA into bf16 weights before loading
|
||||
from mlx_video.convert_wan import load_and_apply_loras
|
||||
|
||||
weights = load_and_apply_loras(dict(weights), loras)
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
@@ -4,7 +4,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .attention import WanLayerNorm
|
||||
from .attention import WanLayerNorm, _linear_dtype
|
||||
from .config import WanModelConfig
|
||||
from .rope import rope_params, rope_precompute_cos_sin
|
||||
from .transformer import WanAttentionBlock
|
||||
@@ -54,7 +54,7 @@ class Head(nn.Module):
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x)
|
||||
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
||||
return self.head(x_mod.astype(self.head.weight.dtype))
|
||||
return self.head(x_mod.astype(_linear_dtype(self.head)))
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
@@ -79,7 +79,7 @@ class WanModel(nn.Module):
|
||||
|
||||
# Text embedding MLP
|
||||
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
|
||||
self.text_embedding_act = nn.GELU(approx="precise")
|
||||
self.text_embedding_act = nn.GELU(approx="tanh")
|
||||
self.text_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time embedding MLP
|
||||
@@ -149,7 +149,7 @@ class WanModel(nn.Module):
|
||||
|
||||
# Project and cast to model dtype to prevent float32 cascade from input latents
|
||||
patches = self.patch_embedding_proj(x) # [L, dim]
|
||||
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
|
||||
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
|
||||
patches = patches[None, :, :] # [1, L, dim]
|
||||
|
||||
return patches, (f_out, h_out, w_out)
|
||||
@@ -186,7 +186,7 @@ class WanModel(nn.Module):
|
||||
Returns:
|
||||
Embedded context [B, text_len, dim] in model dtype
|
||||
"""
|
||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
||||
model_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
context_padded = []
|
||||
for ctx in context:
|
||||
pad_len = self.text_len - ctx.shape[0]
|
||||
@@ -231,7 +231,7 @@ class WanModel(nn.Module):
|
||||
Returns:
|
||||
(cos_f, sin_f) precomputed frequency tensors
|
||||
"""
|
||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -348,7 +348,7 @@ class WanModel(nn.Module):
|
||||
|
||||
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
||||
attn_mask = None
|
||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
if any(sl < seq_len for sl in seq_lens_list):
|
||||
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
||||
for i, sl in enumerate(seq_lens_list):
|
||||
|
||||
@@ -146,7 +146,7 @@ class T5FeedForward(nn.Module):
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = nn.GELU(approx="precise")
|
||||
self.gate_act = nn.GELU(approx="tanh")
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
@@ -84,10 +84,10 @@ class WanFFN(nn.Module):
|
||||
def __init__(self, dim: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, ffn_dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(self.fc1.weight.dtype)
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(_linear_dtype(self.fc1))
|
||||
return self.fc2(self.act(self.fc1(x_w)))
|
||||
|
||||
@@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
|
||||
conversion (channels-first → channels-last) is needed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization for z_dim=48 latent space
|
||||
@@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
|
||||
Maps PyTorch nn.Sequential indices to our named layers.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
# Skip encoder and conv1 unless requested
|
||||
if not include_encoder:
|
||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
new_key = key
|
||||
@@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
|
||||
value = mx.array(np.array(value).squeeze())
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
Reference in New Issue
Block a user