perf(wan): Add mx.compile and fix first-frame artifacts
This commit is contained in:
@@ -48,13 +48,14 @@ class Head(nn.Module):
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
# modulation already float32; e already float32 from model forward
|
||||
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim]
|
||||
# Compute modulation in float32 for precision, cast to working dtype
|
||||
w_dtype = _linear_dtype(self.head)
|
||||
mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x)
|
||||
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
||||
return self.head(x_mod.astype(_linear_dtype(self.head)))
|
||||
x_mod = x_norm * (1 + e1) + e0
|
||||
return self.head(x_mod)
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
@@ -322,18 +323,14 @@ class WanModel(nn.Module):
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
|
||||
e = e.astype(mx.float32)
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim)
|
||||
else:
|
||||
# I2V: per-token timesteps [B, L]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, L, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
|
||||
e = e.astype(mx.float32)
|
||||
e0 = e0.reshape(batch_size, -1, 6, self.dim)
|
||||
|
||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||
if isinstance(context, mx.array):
|
||||
|
||||
@@ -51,17 +51,20 @@ class WanAttentionBlock(nn.Module):
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
# Modulation in float32 (e is already float32 from model forward)
|
||||
mod = self.modulation + e
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
e3 = mod[:, :, 3, :] # shift for ffn
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
# Modulation: compute in float32 for precision, cast to working dtype
|
||||
# to avoid promoting the full hidden state (seq_len × dim) to float32
|
||||
w_dtype = _linear_dtype(self.self_attn.q)
|
||||
mod = (self.modulation + e).astype(w_dtype)
|
||||
e0, e1, e2, e3, e4, e5 = (
|
||||
mod[:, :, 0, :], # shift for self-attn
|
||||
mod[:, :, 1, :], # scale for self-attn
|
||||
mod[:, :, 2, :], # gate for self-attn
|
||||
mod[:, :, 3, :], # shift for ffn
|
||||
mod[:, :, 4, :], # scale for ffn
|
||||
mod[:, :, 5, :], # gate for ffn
|
||||
)
|
||||
|
||||
# Self-attention with modulation
|
||||
# Type promotion handles bf16→f32 automatically when multiplied with f32 modulation
|
||||
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
||||
x = x + y * e2
|
||||
|
||||
@@ -81,7 +81,8 @@ class CausalConv3d(nn.Module):
|
||||
y = mx.conv_general(x_flat, w2d) + self.bias
|
||||
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
||||
|
||||
# Causal temporal padding (left only)
|
||||
# Causal temporal padding (left only) — zeros match the reference
|
||||
# implementation and what the model was trained with.
|
||||
if self._causal_pad_t > 0:
|
||||
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
|
||||
x = mx.concatenate([pad_t, x], axis=1)
|
||||
|
||||
Reference in New Issue
Block a user