This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -6,14 +6,15 @@ import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Euler Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowMatchEulerScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000
assert sched.timesteps is None
@@ -21,6 +22,7 @@ class TestFlowMatchEulerScheduler:
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -29,6 +31,7 @@ class TestFlowMatchEulerScheduler:
def test_timesteps_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps)
@@ -38,6 +41,7 @@ class TestFlowMatchEulerScheduler:
def test_sigmas_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0)
mx.eval(sched.sigmas)
@@ -46,6 +50,7 @@ class TestFlowMatchEulerScheduler:
def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.sigmas)
@@ -54,6 +59,7 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler()
sched1.set_timesteps(20, shift=1.0)
@@ -65,6 +71,7 @@ class TestFlowMatchEulerScheduler:
def test_step_euler(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0)
mx.eval(sched.sigmas)
@@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler:
# Euler: x_next = x + (sigma_next - sigma) * v
expected = 1.0 + (sigma_next - sigma) * 0.5
np.testing.assert_allclose(
np.array(result).flatten()[0], expected, rtol=1e-4,
np.array(result).flatten()[0],
expected,
rtol=1e-4,
)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
assert sched._step_index == 0
@@ -99,6 +109,7 @@ class TestFlowMatchEulerScheduler:
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -111,6 +122,7 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -120,6 +132,7 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -141,22 +154,26 @@ class TestComputeSigmas:
def test_length(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0)
@@ -169,6 +186,7 @@ class TestComputeSigmas:
shift is applied only once (single-shift).
"""
from mlx_video.models.wan.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
# Official single-shift: unshifted bounds, then shift once
@@ -183,6 +201,7 @@ class TestComputeSigmas:
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
# so schedule is nearly linear from ~0.999 to 0
@@ -196,6 +215,7 @@ class TestComputeSigmas:
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
@@ -214,6 +234,7 @@ class TestComputeSigmas:
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
@@ -235,12 +256,14 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -249,6 +272,7 @@ class TestFlowDPMPP2MScheduler:
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 4, 1, 2, 2))
@@ -261,6 +285,7 @@ class TestFlowDPMPP2MScheduler:
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -272,6 +297,7 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -284,6 +310,7 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 2, 4, 4))
@@ -298,6 +325,7 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler:
x0_after_second = sched._prev_x0
assert x0_after_second is not None
# The stored x0 should differ from the first step's
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
assert not np.allclose(
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
)
def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
@@ -333,6 +364,7 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -342,6 +374,7 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
@@ -362,6 +395,7 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000
assert sched.solver_order == 2
@@ -369,6 +403,7 @@ class TestFlowUniPCScheduler:
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -377,6 +412,7 @@ class TestFlowUniPCScheduler:
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -387,6 +423,7 @@ class TestFlowUniPCScheduler:
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -399,6 +436,7 @@ class TestFlowUniPCScheduler:
def test_full_loop_finite(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -411,6 +449,7 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history)."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -424,6 +463,7 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 4, 4))
@@ -436,6 +476,7 @@ class TestFlowUniPCScheduler:
def test_denoise_to_target(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
@@ -450,6 +491,7 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -459,6 +501,7 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 2, 2))
@@ -471,6 +514,7 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self):
"""Order 3 should work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 2, 2))
@@ -483,6 +527,7 @@ class TestFlowUniPCScheduler:
def test_corrector_rhos_c_not_hardcoded(self):
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
import math
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
@@ -525,16 +570,23 @@ class TestFlowUniPCScheduler:
rhos_c = np.linalg.solve(R, b)
# History weight should be small (~0.07-0.09), not 0.5
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
assert (
rhos_c[0] < 0.15
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert (
rhos_c[0] > 0.0
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
# D1_t weight should be ~0.42-0.45, not 0.5
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
assert (
0.3 < rhos_c[1] < 0.5
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
# ---------------------------------------------------------------------------
# Scheduler Coherence Tests
# ---------------------------------------------------------------------------
class TestSchedulerCoherence:
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
@@ -599,11 +651,15 @@ class TestSchedulerCoherence:
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
results["dpm++"],
results["euler"],
atol=1e-5,
err_msg="DPM++ step 0 should match Euler",
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
results["unipc"],
results["euler"],
atol=1e-5,
err_msg="UniPC step 0 should match Euler",
)
@@ -621,11 +677,15 @@ class TestSchedulerCoherence:
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
mx.eval(euler_r, dpm_r, unipc_r)
np.testing.assert_allclose(
np.array(dpm_r), np.array(euler_r), atol=1e-5,
np.array(dpm_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
)
np.testing.assert_allclose(
np.array(unipc_r), np.array(euler_r), atol=1e-5,
np.array(unipc_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
)
@@ -644,7 +704,9 @@ class TestSchedulerCoherence:
latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents)
np.testing.assert_allclose(
np.array(latents), 0.0, atol=1e-3,
np.array(latents),
0.0,
atol=1e-3,
err_msg=f"{name} did not converge to target with oracle",
)
@@ -669,12 +731,12 @@ class TestSchedulerCoherence:
# Higher-order solvers should not be significantly worse than Euler
# (add small epsilon to handle near-zero errors from floating point noise)
eps = 1e-6
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert (
errors["dpm++"] <= errors["euler"] * 1.5 + eps
), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
assert (
errors["unipc"] <= errors["euler"] * 1.5 + eps
), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
def test_multistep_trajectory_similar_magnitude(self):
"""Over a full denoising loop with constant velocity, all solvers
@@ -696,9 +758,9 @@ class TestSchedulerCoherence:
# All solvers should produce results within the same order of magnitude
vals = list(final_means.values())
ratio = max(vals) / max(min(vals), 1e-10)
assert ratio < 10.0, (
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
)
assert (
ratio < 10.0
), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
def test_intermediate_values_finite(self):
"""Every intermediate latent value must be finite for all solvers."""
@@ -712,9 +774,9 @@ class TestSchedulerCoherence:
vel = mx.random.normal(shape)
latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents)
assert np.isfinite(np.array(latents)).all(), (
f"{name} produced non-finite values at step {i}"
)
assert np.isfinite(
np.array(latents)
).all(), f"{name} produced non-finite values at step {i}"
def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
@@ -724,17 +786,17 @@ class TestSchedulerCoherence:
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
assert cls._lambda(1.0) == -math.inf, (
f"{cls.__name__}._lambda(1.0) should be -inf"
)
assert cls._lambda(0.0) == math.inf, (
f"{cls.__name__}._lambda(0.0) should be +inf"
)
assert (
cls._lambda(1.0) == -math.inf
), f"{cls.__name__}._lambda(1.0) should be -inf"
assert (
cls._lambda(0.0) == math.inf
), f"{cls.__name__}._lambda(0.0) should be +inf"
# Interior values should be finite
lam = cls._lambda(0.5)
assert math.isfinite(lam) and lam == 0.0, (
f"{cls.__name__}._lambda(0.5) should be 0.0"
)
assert (
math.isfinite(lam) and lam == 0.0
), f"{cls.__name__}._lambda(0.5) should be 0.0"
def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
@@ -770,7 +832,9 @@ class TestSchedulerCoherence:
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
mx.eval(result)
np.testing.assert_allclose(
np.array(result), np.array(expected), atol=5e-4,
np.array(result),
np.array(expected),
atol=5e-4,
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
)
@@ -790,10 +854,14 @@ class TestSchedulerCoherence:
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
results["dpm++"],
results["euler"],
atol=1e-5,
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
results["unipc"],
results["euler"],
atol=1e-5,
)
def test_dpmpp_unipc_agree_on_step1(self):
@@ -834,7 +902,10 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
sched = cls()
@@ -857,14 +928,19 @@ class TestSchedulerCoherence:
mx.eval(latents)
result2 = np.array(latents)
np.testing.assert_allclose(result1, result2, atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()")
np.testing.assert_allclose(
result1,
result2,
atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()",
)
# ---------------------------------------------------------------------------
# UniPC Corrector Default Tests
# ---------------------------------------------------------------------------
class TestUniPCCorrectorDefault:
"""Tests that the UniPC corrector is enabled by default,
matching official FlowUniPCMultistepScheduler behavior."""
@@ -872,12 +948,14 @@ class TestUniPCCorrectorDefault:
def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched._use_corrector is True
def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
@@ -901,6 +979,7 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)