feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
@@ -84,10 +84,10 @@ class WanFFN(nn.Module):
|
||||
def __init__(self, dim: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, ffn_dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(self.fc1.weight.dtype)
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(_linear_dtype(self.fc1))
|
||||
return self.fc2(self.act(self.fc1(x_w)))
|
||||
|
||||
Reference in New Issue
Block a user