Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -13,7 +13,7 @@ import pytest
class TestFlowMatchEulerScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000
@@ -21,7 +21,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas is None
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
@@ -30,7 +30,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas.shape == (41,) # 40 steps + terminal
def test_timesteps_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
@@ -40,7 +40,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
def test_sigmas_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0)
@@ -49,7 +49,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -58,7 +58,7 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler()
@@ -70,7 +70,7 @@ class TestFlowMatchEulerScheduler:
assert mean2 > mean1, "Higher shift should push sigmas higher"
def test_step_euler(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -95,7 +95,7 @@ class TestFlowMatchEulerScheduler:
)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -108,7 +108,7 @@ class TestFlowMatchEulerScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -121,7 +121,7 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0)
@@ -131,7 +131,7 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -153,26 +153,26 @@ class TestComputeSigmas:
"""Tests for the shared _compute_sigmas helper."""
def test_length(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0)
@@ -185,7 +185,7 @@ class TestComputeSigmas:
sigma_max/sigma_min come from the *unshifted* training schedule, and the
shift is applied only once (single-shift).
"""
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
@@ -200,7 +200,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(sigmas, official, atol=1e-6)
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
@@ -210,7 +210,7 @@ class TestComputeSigmas:
def test_all_schedulers_same_sigmas(self):
"""All three schedulers should produce identical sigma schedules."""
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -229,7 +229,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
def test_all_schedulers_same_timesteps(self):
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -255,14 +255,14 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -271,7 +271,7 @@ class TestFlowDPMPP2MScheduler:
assert sched.sigmas.shape == (21,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -284,7 +284,7 @@ class TestFlowDPMPP2MScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -296,7 +296,7 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -309,7 +309,7 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
@@ -324,7 +324,7 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
@@ -348,7 +348,7 @@ class TestFlowDPMPP2MScheduler:
def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -363,7 +363,7 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0)
@@ -373,7 +373,7 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -394,7 +394,7 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000
@@ -402,7 +402,7 @@ class TestFlowUniPCScheduler:
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0)
@@ -411,7 +411,7 @@ class TestFlowUniPCScheduler:
assert sched.sigmas.shape == (31,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -422,7 +422,7 @@ class TestFlowUniPCScheduler:
assert sched._step_index == 1
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -435,7 +435,7 @@ class TestFlowUniPCScheduler:
assert all(m is None for m in sched._model_outputs)
def test_full_loop_finite(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -448,7 +448,7 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history)."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -462,7 +462,7 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -475,7 +475,7 @@ class TestFlowUniPCScheduler:
assert sched._lower_order_nums >= 2
def test_denoise_to_target(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -490,7 +490,7 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0)
@@ -500,7 +500,7 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0)
@@ -513,7 +513,7 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self):
"""Order 3 should work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -531,7 +531,7 @@ class TestFlowUniPCScheduler:
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(50, shift=5.0)
@@ -597,7 +597,7 @@ class TestSchedulerCoherence:
@staticmethod
def _make_schedulers(steps=10, shift=5.0):
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -780,7 +780,7 @@ class TestSchedulerCoherence:
def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
@@ -800,7 +800,7 @@ class TestSchedulerCoherence:
def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
@@ -902,7 +902,7 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
@@ -947,14 +947,14 @@ class TestUniPCCorrectorDefault:
def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched._use_corrector is True
def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
@@ -978,7 +978,7 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)