This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module):
v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
q = rope_apply(
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
k = rope_apply(
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
@@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out)
@@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out)