fix(wan): Fix scheduler sigma schedule and add debug flags

This commit is contained in:
Daniel
2026-03-11 07:52:07 +01:00
parent afd15018b7
commit d207275fea
4 changed files with 121 additions and 34 deletions

View File

@@ -149,10 +149,11 @@ class TestComputeSigmas:
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_at_one(self):
def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-6)
# Shift is applied once to sigma_max = 0.999, so sigma[0] ≈ 0.9998 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
@@ -160,22 +161,33 @@ class TestComputeSigmas:
assert np.all(np.diff(sigmas) <= 0)
def test_matches_official_wan22(self):
"""Sigma schedule should match the official Wan2.2 get_sampling_sigmas."""
"""Sigma schedule should match the official Wan2.2 FlowUniPCMultistepScheduler.
The reference creates the scheduler with shift=1 (identity) in the
constructor, then passes the actual shift to set_timesteps. This means
sigma_max/sigma_min come from the *unshifted* training schedule, and the
shift is applied only once (single-shift).
"""
from mlx_video.models.wan.scheduler import _compute_sigmas
steps, shift = 50, 5.0
sigmas = _compute_sigmas(steps, shift)
# Official: sigma = linspace(1, 0, steps+1)[:steps]; sigma = shift*sigma/(1+(shift-1)*sigma)
official = np.linspace(1, 0, steps + 1)[:steps]
official = shift * official / (1 + (shift - 1) * official)
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
# Official single-shift: unshifted bounds, then shift once
alphas = np.linspace(1.0, 1.0 / N, N)[::-1]
sigmas_unshifted = 1.0 - alphas
sigma_max = float(sigmas_unshifted[0]) # 0.999
sigma_min = float(sigmas_unshifted[-1]) # 0.0
official = np.linspace(sigma_max, sigma_min, steps + 1)[:-1]
official = shift * official / (1.0 + (shift - 1.0) * official)
official = np.append(official, 0.0).astype(np.float32)
np.testing.assert_allclose(sigmas, official, atol=1e-6)
def test_shift_one_is_linear(self):
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, so schedule is linear from 1 to 0
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
# so schedule is nearly linear from ~0.999 to 0
expected = np.linspace(1, 0, 11).astype(np.float32)
np.testing.assert_allclose(sigmas, expected, atol=1e-6)
np.testing.assert_allclose(sigmas, expected, atol=2e-3)
def test_all_schedulers_same_sigmas(self):
"""All three schedulers should produce identical sigma schedules."""
@@ -655,10 +667,12 @@ class TestSchedulerCoherence:
errors[name] = float(mx.mean(mx.abs(latents)).item())
# Higher-order solvers should not be significantly worse than Euler
assert errors["dpm++"] <= errors["euler"] * 1.5, (
# (add small epsilon to handle near-zero errors from floating point noise)
eps = 1e-6
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert errors["unipc"] <= errors["euler"] * 1.5, (
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
)
@@ -746,7 +760,7 @@ class TestSchedulerCoherence:
scheds = self._make_schedulers(steps, shift=shift)
sigma_next = float(scheds["euler"].sigmas[1].item())
sigma_cur = float(scheds["euler"].sigmas[0].item())
assert abs(sigma_cur - 1.0) < 1e-6, "First sigma should be ~1.0"
assert abs(sigma_cur - 1.0) < 1e-3, "First sigma should be ~1.0"
x0 = sample - sigma_cur * vel
expected = sigma_next * sample + (1.0 - sigma_next) * x0
@@ -756,7 +770,7 @@ class TestSchedulerCoherence:
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
mx.eval(result)
np.testing.assert_allclose(
np.array(result), np.array(expected), atol=1e-5,
np.array(result), np.array(expected), atol=5e-4,
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
)