fix(wan): Fix scheduler sigma schedule and add debug flags
This commit is contained in:
@@ -149,10 +149,11 @@ class TestComputeSigmas:
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
assert sigmas[-1] == 0.0
|
||||
|
||||
def test_starts_at_one(self):
|
||||
def test_starts_near_one(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-6)
|
||||
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
||||
|
||||
def test_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
@@ -160,22 +161,33 @@ class TestComputeSigmas:
|
||||
assert np.all(np.diff(sigmas) <= 0)
|
||||
|
||||
def test_matches_official_wan22(self):
|
||||
"""Sigma schedule should match the official Wan2.2 get_sampling_sigmas."""
|
||||
"""Sigma schedule should match the official Wan2.2 FlowUniPCMultistepScheduler.
|
||||
|
||||
The reference creates the scheduler with shift=1 (identity) in the
|
||||
constructor, then passes the actual shift to set_timesteps. This means
|
||||
sigma_max/sigma_min come from the *unshifted* training schedule, and the
|
||||
shift is applied only once (single-shift).
|
||||
"""
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
steps, shift = 50, 5.0
|
||||
sigmas = _compute_sigmas(steps, shift)
|
||||
# Official: sigma = linspace(1, 0, steps+1)[:steps]; sigma = shift*sigma/(1+(shift-1)*sigma)
|
||||
official = np.linspace(1, 0, steps + 1)[:steps]
|
||||
official = shift * official / (1 + (shift - 1) * official)
|
||||
steps, shift, N = 50, 5.0, 1000
|
||||
sigmas = _compute_sigmas(steps, shift, N)
|
||||
# Official single-shift: unshifted bounds, then shift once
|
||||
alphas = np.linspace(1.0, 1.0 / N, N)[::-1]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # 0.999
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
official = np.linspace(sigma_max, sigma_min, steps + 1)[:-1]
|
||||
official = shift * official / (1.0 + (shift - 1.0) * official)
|
||||
official = np.append(official, 0.0).astype(np.float32)
|
||||
np.testing.assert_allclose(sigmas, official, atol=1e-6)
|
||||
|
||||
def test_shift_one_is_linear(self):
|
||||
def test_shift_one_is_near_linear(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
# With shift=1, f(sigma)=sigma, so schedule is linear from 1 to 0
|
||||
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
||||
# so schedule is nearly linear from ~0.999 to 0
|
||||
expected = np.linspace(1, 0, 11).astype(np.float32)
|
||||
np.testing.assert_allclose(sigmas, expected, atol=1e-6)
|
||||
np.testing.assert_allclose(sigmas, expected, atol=2e-3)
|
||||
|
||||
def test_all_schedulers_same_sigmas(self):
|
||||
"""All three schedulers should produce identical sigma schedules."""
|
||||
@@ -655,10 +667,12 @@ class TestSchedulerCoherence:
|
||||
errors[name] = float(mx.mean(mx.abs(latents)).item())
|
||||
|
||||
# Higher-order solvers should not be significantly worse than Euler
|
||||
assert errors["dpm++"] <= errors["euler"] * 1.5, (
|
||||
# (add small epsilon to handle near-zero errors from floating point noise)
|
||||
eps = 1e-6
|
||||
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
assert errors["unipc"] <= errors["euler"] * 1.5, (
|
||||
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
|
||||
@@ -746,7 +760,7 @@ class TestSchedulerCoherence:
|
||||
scheds = self._make_schedulers(steps, shift=shift)
|
||||
sigma_next = float(scheds["euler"].sigmas[1].item())
|
||||
sigma_cur = float(scheds["euler"].sigmas[0].item())
|
||||
assert abs(sigma_cur - 1.0) < 1e-6, "First sigma should be ~1.0"
|
||||
assert abs(sigma_cur - 1.0) < 1e-3, "First sigma should be ~1.0"
|
||||
|
||||
x0 = sample - sigma_cur * vel
|
||||
expected = sigma_next * sample + (1.0 - sigma_next) * x0
|
||||
@@ -756,7 +770,7 @@ class TestSchedulerCoherence:
|
||||
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
||||
mx.eval(result)
|
||||
np.testing.assert_allclose(
|
||||
np.array(result), np.array(expected), atol=1e-5,
|
||||
np.array(result), np.array(expected), atol=5e-4,
|
||||
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user