feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -4,7 +4,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .attention import WanLayerNorm
|
||||
from .attention import WanLayerNorm, _linear_dtype
|
||||
from .config import WanModelConfig
|
||||
from .rope import rope_params, rope_precompute_cos_sin
|
||||
from .transformer import WanAttentionBlock
|
||||
@@ -54,7 +54,7 @@ class Head(nn.Module):
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x)
|
||||
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
||||
return self.head(x_mod.astype(self.head.weight.dtype))
|
||||
return self.head(x_mod.astype(_linear_dtype(self.head)))
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
@@ -79,7 +79,7 @@ class WanModel(nn.Module):
|
||||
|
||||
# Text embedding MLP
|
||||
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
|
||||
self.text_embedding_act = nn.GELU(approx="precise")
|
||||
self.text_embedding_act = nn.GELU(approx="tanh")
|
||||
self.text_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time embedding MLP
|
||||
@@ -149,7 +149,7 @@ class WanModel(nn.Module):
|
||||
|
||||
# Project and cast to model dtype to prevent float32 cascade from input latents
|
||||
patches = self.patch_embedding_proj(x) # [L, dim]
|
||||
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
|
||||
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
|
||||
patches = patches[None, :, :] # [1, L, dim]
|
||||
|
||||
return patches, (f_out, h_out, w_out)
|
||||
@@ -186,7 +186,7 @@ class WanModel(nn.Module):
|
||||
Returns:
|
||||
Embedded context [B, text_len, dim] in model dtype
|
||||
"""
|
||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
||||
model_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
context_padded = []
|
||||
for ctx in context:
|
||||
pad_len = self.text_len - ctx.shape[0]
|
||||
@@ -231,7 +231,7 @@ class WanModel(nn.Module):
|
||||
Returns:
|
||||
(cos_f, sin_f) precomputed frequency tensors
|
||||
"""
|
||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -348,7 +348,7 @@ class WanModel(nn.Module):
|
||||
|
||||
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
||||
attn_mask = None
|
||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
if any(sl < seq_len for sl in seq_lens_list):
|
||||
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
||||
for i, sl in enumerate(seq_lens_list):
|
||||
|
||||
Reference in New Issue
Block a user