fix(wan): Fix scheduler sigma schedule and add debug flags

This commit is contained in:
Daniel
2026-03-11 07:52:07 +01:00
parent afd15018b7
commit d207275fea
4 changed files with 121 additions and 34 deletions

View File

@@ -12,13 +12,30 @@ import numpy as np
import mlx.core as mx
def _compute_sigmas(num_steps: int, shift: float = 1.0) -> np.ndarray:
"""Compute shifted sigma schedule matching official Wan2.2 code.
def _compute_sigmas(
num_steps: int, shift: float = 1.0, num_train_timesteps: int = 1000
) -> np.ndarray:
"""Compute shifted sigma schedule matching official Wan2.2 scheduler.
The reference creates FlowUniPCMultistepScheduler with shift=1 (identity)
in the constructor, deriving sigma_max/sigma_min from the unshifted
training schedule. Then set_timesteps() builds a linspace between those
unshifted bounds and applies the actual shift once.
Returns num_steps+1 values (the last being 0.0 for the terminal state).
"""
sigmas = np.linspace(1.0, 0.0, num_steps + 1)[:num_steps]
# sigma bounds from unshifted training schedule (constructor uses shift=1)
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
::-1
]
sigmas_unshifted = 1.0 - alphas
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
sigma_min = float(sigmas_unshifted[-1]) # 0.0
# Interpolate, then apply shift once (matching set_timesteps)
sigmas = np.linspace(sigma_max, sigma_min, num_steps + 1)[:-1]
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
return np.append(sigmas, 0.0).astype(np.float32)
@@ -31,9 +48,12 @@ class FlowMatchEulerScheduler:
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift)
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
# Integer timesteps to match reference (model trained with int timesteps)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
# Store as Python floats to avoid .item() sync in step()
self._sigmas_float = sigmas.tolist()
self._step_index = 0
@@ -73,9 +93,11 @@ class FlowDPMPP2MScheduler:
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift)
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
# Store sigmas as Python floats for scalar math
self._sigmas_float = sigmas.tolist()
self._step_index = 0
@@ -198,9 +220,11 @@ class FlowUniPCScheduler:
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift)
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
self._sigmas_float = sigmas.tolist()
self._step_index = 0
self._num_steps = num_steps