perf(wan): Add mx.compile and fix first-frame artifacts

This commit is contained in:
Daniel
2026-03-01 18:15:25 +01:00
parent 849cc45d84
commit 9597b7c9c5
4 changed files with 52 additions and 38 deletions

View File

@@ -48,13 +48,14 @@ class Head(nn.Module):
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
# modulation already float32; e already float32 from model forward
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim]
# Compute modulation in float32 for precision, cast to working dtype
w_dtype = _linear_dtype(self.head)
mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
return self.head(x_mod.astype(_linear_dtype(self.head)))
x_mod = x_norm * (1 + e1) + e0
return self.head(x_mod)
class WanModel(nn.Module):
@@ -322,18 +323,14 @@ class WanModel(nn.Module):
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
# Keep e and e0 in float32 — official asserts float32 for modulation
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
e0 = e0.reshape(batch_size, 1, 6, self.dim)
else:
# I2V: per-token timesteps [B, L]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
# Keep e and e0 in float32 — official asserts float32 for modulation
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
e0 = e0.reshape(batch_size, -1, 6, self.dim)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):