perf(wan): Add mx.compile and fix first-frame artifacts

This commit is contained in:
Daniel
2026-03-01 18:15:25 +01:00
parent 849cc45d84
commit 9597b7c9c5
4 changed files with 52 additions and 38 deletions

View File

@@ -81,7 +81,8 @@ class CausalConv3d(nn.Module):
y = mx.conv_general(x_flat, w2d) + self.bias
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
# Causal temporal padding (left only)
# Causal temporal padding (left only) — zeros match the reference
# implementation and what the model was trained with.
if self._causal_pad_t > 0:
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
x = mx.concatenate([pad_t, x], axis=1)