format
This commit is contained in:
@@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# Cross-attention (with optional norm on context)
|
||||
self.norm3 = (
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True)
|
||||
if cross_attn_norm
|
||||
else None
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
|
||||
)
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
||||
|
||||
@@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module):
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
||||
y = self.self_attn(
|
||||
x_mod,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
|
||||
Reference in New Issue
Block a user