feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -4,6 +4,15 @@ import mlx.nn as nn
|
||||
from .rope import rope_apply
|
||||
|
||||
|
||||
def _linear_dtype(layer) -> mx.Dtype:
|
||||
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
|
||||
# Unwrap LoRA wrapper to get the underlying linear layer
|
||||
inner = getattr(layer, "linear", layer)
|
||||
if isinstance(inner, nn.QuantizedLinear):
|
||||
return inner.scales.dtype
|
||||
return inner.weight.dtype
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
"""RMS normalization with learnable scale."""
|
||||
|
||||
@@ -73,8 +82,8 @@ class WanSelfAttention(nn.Module):
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
x_w = x.astype(w_dtype)
|
||||
|
||||
q = self.q(x_w)
|
||||
@@ -154,8 +163,8 @@ class WanCrossAttention(nn.Module):
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
# Cast to weight dtype for efficient matmul
|
||||
w_dtype = self.k.weight.dtype
|
||||
# Cast to compute dtype for efficient matmul
|
||||
w_dtype = _linear_dtype(self.k)
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
@@ -174,8 +183,8 @@ class WanCrossAttention(nn.Module):
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
q = self.q(x.astype(w_dtype))
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
|
||||
Reference in New Issue
Block a user