Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -13,7 +13,7 @@ import pytest
class TestFlowMatchEulerScheduler:
def test_initialization(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000
@@ -21,7 +21,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas is None
def test_set_timesteps(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
@@ -30,7 +30,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas.shape == (41,) # 40 steps + terminal
def test_timesteps_decreasing(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
@@ -40,7 +40,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
def test_sigmas_decreasing(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0)
@@ -49,7 +49,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -58,7 +58,7 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler()
@@ -70,7 +70,7 @@ class TestFlowMatchEulerScheduler:
assert mean2 > mean1, "Higher shift should push sigmas higher"
def test_step_euler(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -95,7 +95,7 @@ class TestFlowMatchEulerScheduler:
)
def test_step_index_increments(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -108,7 +108,7 @@ class TestFlowMatchEulerScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -121,7 +121,7 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0)
@@ -131,7 +131,7 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -153,26 +153,26 @@ class TestComputeSigmas:
"""Tests for the shared _compute_sigmas helper."""
def test_length(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_near_one(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0)
@@ -185,7 +185,7 @@ class TestComputeSigmas:
sigma_max/sigma_min come from the *unshifted* training schedule, and the
shift is applied only once (single-shift).
"""
from mlx_video.models.wan2.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
@@ -200,7 +200,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(sigmas, official, atol=1e-6)
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan2.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
@@ -210,7 +210,7 @@ class TestComputeSigmas:
def test_all_schedulers_same_sigmas(self):
"""All three schedulers should produce identical sigma schedules."""
from mlx_video.models.wan2.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -229,7 +229,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
def test_all_schedulers_same_timesteps(self):
from mlx_video.models.wan2.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -255,14 +255,14 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler:
def test_initialization(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -271,7 +271,7 @@ class TestFlowDPMPP2MScheduler:
assert sched.sigmas.shape == (21,)
def test_step_index_increments(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -284,7 +284,7 @@ class TestFlowDPMPP2MScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -296,7 +296,7 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -309,7 +309,7 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
@@ -324,7 +324,7 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
@@ -348,7 +348,7 @@ class TestFlowDPMPP2MScheduler:
def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -363,7 +363,7 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0)
@@ -373,7 +373,7 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -394,7 +394,7 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler:
def test_initialization(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000
@@ -402,7 +402,7 @@ class TestFlowUniPCScheduler:
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0)
@@ -411,7 +411,7 @@ class TestFlowUniPCScheduler:
assert sched.sigmas.shape == (31,)
def test_step_index_increments(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -422,7 +422,7 @@ class TestFlowUniPCScheduler:
assert sched._step_index == 1
def test_reset(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -435,7 +435,7 @@ class TestFlowUniPCScheduler:
assert all(m is None for m in sched._model_outputs)
def test_full_loop_finite(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -448,7 +448,7 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history)."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -462,7 +462,7 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -475,7 +475,7 @@ class TestFlowUniPCScheduler:
assert sched._lower_order_nums >= 2
def test_denoise_to_target(self):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -490,7 +490,7 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0)
@@ -500,7 +500,7 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0)
@@ -513,7 +513,7 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self):
"""Order 3 should work without error."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -531,7 +531,7 @@ class TestFlowUniPCScheduler:
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
from mlx_video.models.wan2.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(50, shift=5.0)
@@ -597,7 +597,7 @@ class TestSchedulerCoherence:
@staticmethod
def _make_schedulers(steps=10, shift=5.0):
from mlx_video.models.wan2.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -780,7 +780,7 @@ class TestSchedulerCoherence:
def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
from mlx_video.models.wan2.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
@@ -800,7 +800,7 @@ class TestSchedulerCoherence:
def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
@@ -902,7 +902,7 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
from mlx_video.models.wan2.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
@@ -947,14 +947,14 @@ class TestUniPCCorrectorDefault:
def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched._use_corrector is True
def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
@@ -978,7 +978,7 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)