This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module):
# Cross-attention (with optional norm on context)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else None
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
)
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
@@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module):
self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(
self,
@@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module):
# Self-attention with modulation (hidden state stays in w_dtype)
x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
y = self.self_attn(
x_mod,
seq_lens,
grid_sizes,
freqs,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
)
x = x + y * e2
# Cross-attention (no modulation, just norm)