format
This commit is contained in:
@@ -6,14 +6,15 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Euler Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlowMatchEulerScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.timesteps is None
|
||||
@@ -21,6 +22,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -29,6 +31,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_timesteps_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps)
|
||||
@@ -38,6 +41,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_sigmas_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
@@ -46,6 +50,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_terminal_sigma_is_zero(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.sigmas)
|
||||
@@ -54,6 +59,7 @@ class TestFlowMatchEulerScheduler:
|
||||
def test_shift_effect(self):
|
||||
"""Larger shift should push sigmas toward higher values."""
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched1 = FlowMatchEulerScheduler()
|
||||
sched2 = FlowMatchEulerScheduler()
|
||||
sched1.set_timesteps(20, shift=1.0)
|
||||
@@ -65,6 +71,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_step_euler(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
@@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler:
|
||||
# Euler: x_next = x + (sigma_next - sigma) * v
|
||||
expected = 1.0 + (sigma_next - sigma) * 0.5
|
||||
np.testing.assert_allclose(
|
||||
np.array(result).flatten()[0], expected, rtol=1e-4,
|
||||
np.array(result).flatten()[0],
|
||||
expected,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
assert sched._step_index == 0
|
||||
@@ -99,6 +109,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -111,6 +122,7 @@ class TestFlowMatchEulerScheduler:
|
||||
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(steps, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -120,6 +132,7 @@ class TestFlowMatchEulerScheduler:
|
||||
def test_full_denoise_loop(self):
|
||||
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
@@ -141,22 +154,26 @@ class TestComputeSigmas:
|
||||
|
||||
def test_length(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert len(sigmas) == 21 # num_steps + terminal
|
||||
|
||||
def test_terminal_zero(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
assert sigmas[-1] == 0.0
|
||||
|
||||
def test_starts_near_one(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
||||
|
||||
def test_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert np.all(np.diff(sigmas) <= 0)
|
||||
|
||||
@@ -169,6 +186,7 @@ class TestComputeSigmas:
|
||||
shift is applied only once (single-shift).
|
||||
"""
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
steps, shift, N = 50, 5.0, 1000
|
||||
sigmas = _compute_sigmas(steps, shift, N)
|
||||
# Official single-shift: unshifted bounds, then shift once
|
||||
@@ -183,6 +201,7 @@ class TestComputeSigmas:
|
||||
|
||||
def test_shift_one_is_near_linear(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
||||
# so schedule is nearly linear from ~0.999 to 0
|
||||
@@ -196,6 +215,7 @@ class TestComputeSigmas:
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
@@ -214,6 +234,7 @@ class TestComputeSigmas:
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
@@ -235,12 +256,14 @@ class TestComputeSigmas:
|
||||
class TestFlowDPMPP2MScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -249,6 +272,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 4, 1, 2, 2))
|
||||
@@ -261,6 +285,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -272,6 +297,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_full_loop_finite(self):
|
||||
"""Full loop with constant velocity should produce finite output."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
@@ -284,6 +310,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_first_step_is_first_order(self):
|
||||
"""First step should use 1st-order (no prev_x0 available)."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 2, 4, 4))
|
||||
@@ -298,6 +325,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_second_step_uses_correction(self):
|
||||
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
@@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler:
|
||||
x0_after_second = sched._prev_x0
|
||||
assert x0_after_second is not None
|
||||
# The stored x0 should differ from the first step's
|
||||
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
|
||||
assert not np.allclose(
|
||||
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
|
||||
)
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
"""Perfect oracle should denoise to target with any solver."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
@@ -333,6 +364,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -342,6 +374,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
def test_terminal_sigma_produces_x0(self):
|
||||
"""When sigma_next=0 the scheduler should return x0 directly."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
|
||||
@@ -362,6 +395,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
class TestFlowUniPCScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.solver_order == 2
|
||||
@@ -369,6 +403,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(30, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -377,6 +412,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -387,6 +423,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
@@ -399,6 +436,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
@@ -411,6 +449,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_corrector_not_applied_first_step(self):
|
||||
"""First step should skip the corrector (no history)."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
@@ -424,6 +463,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_corrector_applied_after_first_step(self):
|
||||
"""Steps after the first should use the corrector when enabled."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 4, 4))
|
||||
@@ -436,6 +476,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
@@ -450,6 +491,7 @@ class TestFlowUniPCScheduler:
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
@@ -459,6 +501,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_disable_corrector(self):
|
||||
"""Disabling corrector on step 0 should still work without error."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 2, 2))
|
||||
@@ -471,6 +514,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_solver_order_3(self):
|
||||
"""Order 3 should work without error."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 2, 2))
|
||||
@@ -483,6 +527,7 @@ class TestFlowUniPCScheduler:
|
||||
def test_corrector_rhos_c_not_hardcoded(self):
|
||||
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
|
||||
import math
|
||||
|
||||
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
|
||||
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
||||
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
||||
@@ -525,16 +570,23 @@ class TestFlowUniPCScheduler:
|
||||
rhos_c = np.linalg.solve(R, b)
|
||||
|
||||
# History weight should be small (~0.07-0.09), not 0.5
|
||||
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||
assert (
|
||||
rhos_c[0] < 0.15
|
||||
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||
assert (
|
||||
rhos_c[0] > 0.0
|
||||
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||
# D1_t weight should be ~0.42-0.45, not 0.5
|
||||
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||
assert (
|
||||
0.3 < rhos_c[1] < 0.5
|
||||
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler Coherence Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchedulerCoherence:
|
||||
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
|
||||
|
||||
@@ -599,11 +651,15 @@ class TestSchedulerCoherence:
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"], results["euler"], atol=1e-5,
|
||||
results["dpm++"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
err_msg="DPM++ step 0 should match Euler",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"], results["euler"], atol=1e-5,
|
||||
results["unipc"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
err_msg="UniPC step 0 should match Euler",
|
||||
)
|
||||
|
||||
@@ -621,11 +677,15 @@ class TestSchedulerCoherence:
|
||||
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
|
||||
mx.eval(euler_r, dpm_r, unipc_r)
|
||||
np.testing.assert_allclose(
|
||||
np.array(dpm_r), np.array(euler_r), atol=1e-5,
|
||||
np.array(dpm_r),
|
||||
np.array(euler_r),
|
||||
atol=1e-5,
|
||||
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
np.array(unipc_r), np.array(euler_r), atol=1e-5,
|
||||
np.array(unipc_r),
|
||||
np.array(euler_r),
|
||||
atol=1e-5,
|
||||
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
|
||||
@@ -644,7 +704,9 @@ class TestSchedulerCoherence:
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(
|
||||
np.array(latents), 0.0, atol=1e-3,
|
||||
np.array(latents),
|
||||
0.0,
|
||||
atol=1e-3,
|
||||
err_msg=f"{name} did not converge to target with oracle",
|
||||
)
|
||||
|
||||
@@ -669,12 +731,12 @@ class TestSchedulerCoherence:
|
||||
# Higher-order solvers should not be significantly worse than Euler
|
||||
# (add small epsilon to handle near-zero errors from floating point noise)
|
||||
eps = 1e-6
|
||||
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
assert (
|
||||
errors["dpm++"] <= errors["euler"] * 1.5 + eps
|
||||
), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
assert (
|
||||
errors["unipc"] <= errors["euler"] * 1.5 + eps
|
||||
), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
|
||||
def test_multistep_trajectory_similar_magnitude(self):
|
||||
"""Over a full denoising loop with constant velocity, all solvers
|
||||
@@ -696,9 +758,9 @@ class TestSchedulerCoherence:
|
||||
# All solvers should produce results within the same order of magnitude
|
||||
vals = list(final_means.values())
|
||||
ratio = max(vals) / max(min(vals), 1e-10)
|
||||
assert ratio < 10.0, (
|
||||
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||
)
|
||||
assert (
|
||||
ratio < 10.0
|
||||
), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||
|
||||
def test_intermediate_values_finite(self):
|
||||
"""Every intermediate latent value must be finite for all solvers."""
|
||||
@@ -712,9 +774,9 @@ class TestSchedulerCoherence:
|
||||
vel = mx.random.normal(shape)
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
assert np.isfinite(np.array(latents)).all(), (
|
||||
f"{name} produced non-finite values at step {i}"
|
||||
)
|
||||
assert np.isfinite(
|
||||
np.array(latents)
|
||||
).all(), f"{name} produced non-finite values at step {i}"
|
||||
|
||||
def test_lambda_boundary_values(self):
|
||||
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
||||
@@ -724,17 +786,17 @@ class TestSchedulerCoherence:
|
||||
)
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
assert cls._lambda(1.0) == -math.inf, (
|
||||
f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||
)
|
||||
assert cls._lambda(0.0) == math.inf, (
|
||||
f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||
)
|
||||
assert (
|
||||
cls._lambda(1.0) == -math.inf
|
||||
), f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||
assert (
|
||||
cls._lambda(0.0) == math.inf
|
||||
), f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||
# Interior values should be finite
|
||||
lam = cls._lambda(0.5)
|
||||
assert math.isfinite(lam) and lam == 0.0, (
|
||||
f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||
)
|
||||
assert (
|
||||
math.isfinite(lam) and lam == 0.0
|
||||
), f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||
|
||||
def test_lambda_monotonically_decreasing(self):
|
||||
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
||||
@@ -770,7 +832,9 @@ class TestSchedulerCoherence:
|
||||
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
||||
mx.eval(result)
|
||||
np.testing.assert_allclose(
|
||||
np.array(result), np.array(expected), atol=5e-4,
|
||||
np.array(result),
|
||||
np.array(expected),
|
||||
atol=5e-4,
|
||||
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
||||
)
|
||||
|
||||
@@ -790,10 +854,14 @@ class TestSchedulerCoherence:
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"], results["euler"], atol=1e-5,
|
||||
results["dpm++"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"], results["euler"], atol=1e-5,
|
||||
results["unipc"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
def test_dpmpp_unipc_agree_on_step1(self):
|
||||
@@ -834,7 +902,10 @@ class TestSchedulerCoherence:
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
sched = cls()
|
||||
@@ -857,14 +928,19 @@ class TestSchedulerCoherence:
|
||||
mx.eval(latents)
|
||||
result2 = np.array(latents)
|
||||
|
||||
np.testing.assert_allclose(result1, result2, atol=1e-5,
|
||||
err_msg=f"{cls.__name__} not reproducible after reset()")
|
||||
np.testing.assert_allclose(
|
||||
result1,
|
||||
result2,
|
||||
atol=1e-5,
|
||||
err_msg=f"{cls.__name__} not reproducible after reset()",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UniPC Corrector Default Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUniPCCorrectorDefault:
|
||||
"""Tests that the UniPC corrector is enabled by default,
|
||||
matching official FlowUniPCMultistepScheduler behavior."""
|
||||
@@ -872,12 +948,14 @@ class TestUniPCCorrectorDefault:
|
||||
def test_corrector_enabled_by_default(self):
|
||||
"""Default construction should have corrector enabled."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched._use_corrector is True
|
||||
|
||||
def test_corrector_affects_output(self):
|
||||
"""Corrector should produce different results than no corrector after step 1."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
@@ -901,6 +979,7 @@ class TestUniPCCorrectorDefault:
|
||||
def test_corrector_does_not_affect_first_step(self):
|
||||
"""Step 0 should be identical regardless of corrector setting."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
Reference in New Issue
Block a user