perf(wan): Add mx.compile and fix first-frame artifacts
This commit is contained in:
@@ -81,7 +81,8 @@ class CausalConv3d(nn.Module):
|
||||
y = mx.conv_general(x_flat, w2d) + self.bias
|
||||
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
||||
|
||||
# Causal temporal padding (left only)
|
||||
# Causal temporal padding (left only) — zeros match the reference
|
||||
# implementation and what the model was trained with.
|
||||
if self._causal_pad_t > 0:
|
||||
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
|
||||
x = mx.concatenate([pad_t, x], axis=1)
|
||||
|
||||
Reference in New Issue
Block a user