perf(wan): Add mx.compile and fix first-frame artifacts
This commit is contained in:
@@ -51,17 +51,20 @@ class WanAttentionBlock(nn.Module):
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
# Modulation in float32 (e is already float32 from model forward)
|
||||
mod = self.modulation + e
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
e3 = mod[:, :, 3, :] # shift for ffn
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
# Modulation: compute in float32 for precision, cast to working dtype
|
||||
# to avoid promoting the full hidden state (seq_len × dim) to float32
|
||||
w_dtype = _linear_dtype(self.self_attn.q)
|
||||
mod = (self.modulation + e).astype(w_dtype)
|
||||
e0, e1, e2, e3, e4, e5 = (
|
||||
mod[:, :, 0, :], # shift for self-attn
|
||||
mod[:, :, 1, :], # scale for self-attn
|
||||
mod[:, :, 2, :], # gate for self-attn
|
||||
mod[:, :, 3, :], # shift for ffn
|
||||
mod[:, :, 4, :], # scale for ffn
|
||||
mod[:, :, 5, :], # gate for ffn
|
||||
)
|
||||
|
||||
# Self-attention with modulation
|
||||
# Type promotion handles bf16→f32 automatically when multiplied with f32 modulation
|
||||
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
||||
x = x + y * e2
|
||||
|
||||
Reference in New Issue
Block a user