fix(wan): Fix scheduler sigma schedule and add debug flags

This commit is contained in:
Daniel
2026-03-11 07:52:07 +01:00
parent afd15018b7
commit d207275fea
4 changed files with 121 additions and 34 deletions

View File

@@ -118,14 +118,12 @@ class WanModel(nn.Module):
rope_params(1024, 2 * (d // 6)),
], axis=1)
# Precompute sinusoidal inv_freq for time embedding
# Use numpy float64 for precision (matches reference torch.float64),
# then store as float32 since MLX GPU doesn't support float64.
# Precompute sinusoidal inv_freq for time embedding.
half = config.freq_dim // 2
inv_freq_np = np.power(
10000.0, -np.arange(half, dtype=np.float64) / half
self._inv_freq = mx.array(
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
).astype(np.float32)
)
self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
def _patchify(self, x: mx.array) -> tuple:
@@ -311,13 +309,16 @@ class WanModel(nn.Module):
axis=0,
) # [B, seq_len, dim]
# Time embedding (use cached inv_freq to avoid recomputing each step)
# Time embedding: sinusoidal from precomputed inv_freq.
# inv_freq was computed in float64 for precision, then stored as float32.
# Assuming integer timesteps (as in the reference), float32 sin/cos is precise enough here.
if t.ndim == 0:
t = t[None]
pos = t.astype(mx.float32)
sinusoid = pos[..., None] * self._inv_freq
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate(
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]