fix(wan): Fix scheduler sigma schedule and add debug flags
This commit is contained in:
@@ -118,14 +118,12 @@ class WanModel(nn.Module):
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding
|
||||
# Use numpy float64 for precision (matches reference torch.float64),
|
||||
# then store as float32 since MLX GPU doesn't support float64.
|
||||
# Precompute sinusoidal inv_freq for time embedding.
|
||||
half = config.freq_dim // 2
|
||||
inv_freq_np = np.power(
|
||||
10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
self._inv_freq = mx.array(
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
).astype(np.float32)
|
||||
)
|
||||
self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
@@ -311,13 +309,16 @@ class WanModel(nn.Module):
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding (use cached inv_freq to avoid recomputing each step)
|
||||
# Time embedding: sinusoidal from precomputed inv_freq.
|
||||
# inv_freq was computed in float64 for precision, stored as float32.
|
||||
# With integer timesteps (matching reference), float32 sin/cos is fine.
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
|
||||
pos = t.astype(mx.float32)
|
||||
sinusoid = pos[..., None] * self._inv_freq
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
||||
sin_emb = mx.concatenate(
|
||||
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
|
||||
)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
|
||||
Reference in New Issue
Block a user