feat(wan): Add LoRA with improved quantization pipeline

This commit is contained in:
Daniel
2026-02-28 14:11:13 +01:00
parent dbab95ec45
commit 849cc45d84
17 changed files with 1852 additions and 111 deletions

View File

@@ -4,7 +4,7 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .attention import WanLayerNorm
from .attention import WanLayerNorm, _linear_dtype
from .config import WanModelConfig
from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock
@@ -54,7 +54,7 @@ class Head(nn.Module):
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
return self.head(x_mod.astype(self.head.weight.dtype))
return self.head(x_mod.astype(_linear_dtype(self.head)))
class WanModel(nn.Module):
@@ -79,7 +79,7 @@ class WanModel(nn.Module):
# Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="precise")
self.text_embedding_act = nn.GELU(approx="tanh")
self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP
@@ -149,7 +149,7 @@ class WanModel(nn.Module):
# Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out)
@@ -186,7 +186,7 @@ class WanModel(nn.Module):
Returns:
Embedded context [B, text_len, dim] in model dtype
"""
model_dtype = self.patch_embedding_proj.weight.dtype
model_dtype = _linear_dtype(self.patch_embedding_proj)
context_padded = []
for ctx in context:
pad_len = self.text_len - ctx.shape[0]
@@ -231,7 +231,7 @@ class WanModel(nn.Module):
Returns:
(cos_f, sin_f) precomputed frequency tensors
"""
w_dtype = self.patch_embedding_proj.weight.dtype
w_dtype = _linear_dtype(self.patch_embedding_proj)
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__(
@@ -348,7 +348,7 @@ class WanModel(nn.Module):
# Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None
w_dtype = self.patch_embedding_proj.weight.dtype
w_dtype = _linear_dtype(self.patch_embedding_proj)
if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list):