fix(wan): Fix scheduler sigma schedule and add debug flags
This commit is contained in:
@@ -118,14 +118,12 @@ class WanModel(nn.Module):
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding
|
||||
# Use numpy float64 for precision (matches reference torch.float64),
|
||||
# then store as float32 since MLX GPU doesn't support float64.
|
||||
# Precompute sinusoidal inv_freq for time embedding.
|
||||
half = config.freq_dim // 2
|
||||
inv_freq_np = np.power(
|
||||
10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
self._inv_freq = mx.array(
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
).astype(np.float32)
|
||||
)
|
||||
self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
@@ -311,13 +309,16 @@ class WanModel(nn.Module):
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding (use cached inv_freq to avoid recomputing each step)
|
||||
# Time embedding: sinusoidal from precomputed inv_freq.
|
||||
# inv_freq was computed in float64 for precision, stored as float32.
|
||||
# With integer timesteps (matching reference), float32 sin/cos is fine.
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
|
||||
pos = t.astype(mx.float32)
|
||||
sinusoid = pos[..., None] * self._inv_freq
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
||||
sin_emb = mx.concatenate(
|
||||
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
|
||||
)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
|
||||
@@ -12,13 +12,30 @@ import numpy as np
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def _compute_sigmas(num_steps: int, shift: float = 1.0) -> np.ndarray:
|
||||
"""Compute shifted sigma schedule matching official Wan2.2 code.
|
||||
def _compute_sigmas(
|
||||
num_steps: int, shift: float = 1.0, num_train_timesteps: int = 1000
|
||||
) -> np.ndarray:
|
||||
"""Compute shifted sigma schedule matching official Wan2.2 scheduler.
|
||||
|
||||
The reference creates FlowUniPCMultistepScheduler with shift=1 (identity)
|
||||
in the constructor, deriving sigma_max/sigma_min from the unshifted
|
||||
training schedule. Then set_timesteps() builds a linspace between those
|
||||
unshifted bounds and applies the actual shift once.
|
||||
|
||||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||||
"""
|
||||
sigmas = np.linspace(1.0, 0.0, num_steps + 1)[:num_steps]
|
||||
# sigma bounds from unshifted training schedule (constructor uses shift=1)
|
||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
|
||||
::-1
|
||||
]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
|
||||
# Interpolate, then apply shift once (matching set_timesteps)
|
||||
sigmas = np.linspace(sigma_max, sigma_min, num_steps + 1)[:-1]
|
||||
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||||
|
||||
return np.append(sigmas, 0.0).astype(np.float32)
|
||||
|
||||
|
||||
@@ -31,9 +48,12 @@ class FlowMatchEulerScheduler:
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
# Integer timesteps to match reference (model trained with int timesteps);
# astype(np.int64) truncates toward zero rather than rounding.
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
# Store as Python floats to avoid .item() sync in step()
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
@@ -73,9 +93,11 @@ class FlowDPMPP2MScheduler:
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
# Store sigmas as Python floats for scalar math
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
@@ -198,9 +220,11 @@ class FlowUniPCScheduler:
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
self._num_steps = num_steps
|
||||
|
||||
Reference in New Issue
Block a user