Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -13,7 +13,7 @@ import pytest
|
||||
|
||||
class TestFlowMatchEulerScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
@@ -21,7 +21,7 @@ class TestFlowMatchEulerScheduler:
|
||||
assert sched.sigmas is None
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
@@ -30,7 +30,7 @@ class TestFlowMatchEulerScheduler:
|
||||
assert sched.sigmas.shape == (41,) # 40 steps + terminal
|
||||
|
||||
def test_timesteps_decreasing(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
@@ -40,7 +40,7 @@ class TestFlowMatchEulerScheduler:
|
||||
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
|
||||
|
||||
def test_sigmas_decreasing(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=1.0)
|
||||
@@ -49,7 +49,7 @@ class TestFlowMatchEulerScheduler:
|
||||
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
|
||||
|
||||
def test_terminal_sigma_is_zero(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
@@ -58,7 +58,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_shift_effect(self):
|
||||
"""Larger shift should push sigmas toward higher values."""
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched1 = FlowMatchEulerScheduler()
|
||||
sched2 = FlowMatchEulerScheduler()
|
||||
@@ -70,7 +70,7 @@ class TestFlowMatchEulerScheduler:
|
||||
assert mean2 > mean1, "Higher shift should push sigmas higher"
|
||||
|
||||
def test_step_euler(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
@@ -95,7 +95,7 @@ class TestFlowMatchEulerScheduler:
|
||||
)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -108,7 +108,7 @@ class TestFlowMatchEulerScheduler:
|
||||
assert sched._step_index == 2
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -121,7 +121,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(steps, shift=12.0)
|
||||
@@ -131,7 +131,7 @@ class TestFlowMatchEulerScheduler:
|
||||
|
||||
def test_full_denoise_loop(self):
|
||||
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -153,26 +153,26 @@ class TestComputeSigmas:
|
||||
"""Tests for the shared _compute_sigmas helper."""
|
||||
|
||||
def test_length(self):
|
||||
from mlx_video.models.wan2.scheduler import _compute_sigmas
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert len(sigmas) == 21 # num_steps + terminal
|
||||
|
||||
def test_terminal_zero(self):
|
||||
from mlx_video.models.wan2.scheduler import _compute_sigmas
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
assert sigmas[-1] == 0.0
|
||||
|
||||
def test_starts_near_one(self):
|
||||
from mlx_video.models.wan2.scheduler import _compute_sigmas
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
||||
|
||||
def test_decreasing(self):
|
||||
from mlx_video.models.wan2.scheduler import _compute_sigmas
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert np.all(np.diff(sigmas) <= 0)
|
||||
@@ -185,7 +185,7 @@ class TestComputeSigmas:
|
||||
sigma_max/sigma_min come from the *unshifted* training schedule, and the
|
||||
shift is applied only once (single-shift).
|
||||
"""
|
||||
from mlx_video.models.wan2.scheduler import _compute_sigmas
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
steps, shift, N = 50, 5.0, 1000
|
||||
sigmas = _compute_sigmas(steps, shift, N)
|
||||
@@ -200,7 +200,7 @@ class TestComputeSigmas:
|
||||
np.testing.assert_allclose(sigmas, official, atol=1e-6)
|
||||
|
||||
def test_shift_one_is_near_linear(self):
|
||||
from mlx_video.models.wan2.scheduler import _compute_sigmas
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
||||
@@ -210,7 +210,7 @@ class TestComputeSigmas:
|
||||
|
||||
def test_all_schedulers_same_sigmas(self):
|
||||
"""All three schedulers should produce identical sigma schedules."""
|
||||
from mlx_video.models.wan2.scheduler import (
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
@@ -229,7 +229,7 @@ class TestComputeSigmas:
|
||||
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
|
||||
|
||||
def test_all_schedulers_same_timesteps(self):
|
||||
from mlx_video.models.wan2.scheduler import (
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
@@ -255,14 +255,14 @@ class TestComputeSigmas:
|
||||
|
||||
class TestFlowDPMPP2MScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
@@ -271,7 +271,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
assert sched.sigmas.shape == (21,)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -284,7 +284,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
assert sched._step_index == 2
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -296,7 +296,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
"""Full loop with constant velocity should produce finite output."""
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
@@ -309,7 +309,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_first_step_is_first_order(self):
|
||||
"""First step should use 1st-order (no prev_x0 available)."""
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
@@ -324,7 +324,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_second_step_uses_correction(self):
|
||||
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
@@ -348,7 +348,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
"""Perfect oracle should denoise to target with any solver."""
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
@@ -363,7 +363,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
@@ -373,7 +373,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
def test_terminal_sigma_produces_x0(self):
|
||||
"""When sigma_next=0 the scheduler should return x0 directly."""
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -394,7 +394,7 @@ class TestFlowDPMPP2MScheduler:
|
||||
|
||||
class TestFlowUniPCScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
@@ -402,7 +402,7 @@ class TestFlowUniPCScheduler:
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(30, shift=12.0)
|
||||
@@ -411,7 +411,7 @@ class TestFlowUniPCScheduler:
|
||||
assert sched.sigmas.shape == (31,)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -422,7 +422,7 @@ class TestFlowUniPCScheduler:
|
||||
assert sched._step_index == 1
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -435,7 +435,7 @@ class TestFlowUniPCScheduler:
|
||||
assert all(m is None for m in sched._model_outputs)
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
@@ -448,7 +448,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_corrector_not_applied_first_step(self):
|
||||
"""First step should skip the corrector (no history)."""
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
@@ -462,7 +462,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_corrector_applied_after_first_step(self):
|
||||
"""Steps after the first should use the corrector when enabled."""
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
@@ -475,7 +475,7 @@ class TestFlowUniPCScheduler:
|
||||
assert sched._lower_order_nums >= 2
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
@@ -490,7 +490,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
@@ -500,7 +500,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_disable_corrector(self):
|
||||
"""Disabling corrector on step 0 should still work without error."""
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
@@ -513,7 +513,7 @@ class TestFlowUniPCScheduler:
|
||||
|
||||
def test_solver_order_3(self):
|
||||
"""Order 3 should work without error."""
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
@@ -531,7 +531,7 @@ class TestFlowUniPCScheduler:
|
||||
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
|
||||
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
||||
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
||||
from mlx_video.models.wan2.scheduler import _compute_sigmas
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(50, shift=5.0)
|
||||
|
||||
@@ -597,7 +597,7 @@ class TestSchedulerCoherence:
|
||||
|
||||
@staticmethod
|
||||
def _make_schedulers(steps=10, shift=5.0):
|
||||
from mlx_video.models.wan2.scheduler import (
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
@@ -780,7 +780,7 @@ class TestSchedulerCoherence:
|
||||
|
||||
def test_lambda_boundary_values(self):
|
||||
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
||||
from mlx_video.models.wan2.scheduler import (
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
@@ -800,7 +800,7 @@ class TestSchedulerCoherence:
|
||||
|
||||
def test_lambda_monotonically_decreasing(self):
|
||||
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
||||
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
|
||||
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
|
||||
@@ -902,7 +902,7 @@ class TestSchedulerCoherence:
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
from mlx_video.models.wan2.scheduler import (
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
@@ -947,14 +947,14 @@ class TestUniPCCorrectorDefault:
|
||||
|
||||
def test_corrector_enabled_by_default(self):
|
||||
"""Default construction should have corrector enabled."""
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched._use_corrector is True
|
||||
|
||||
def test_corrector_affects_output(self):
|
||||
"""Corrector should produce different results than no corrector after step 1."""
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
@@ -978,7 +978,7 @@ class TestUniPCCorrectorDefault:
|
||||
|
||||
def test_corrector_does_not_affect_first_step(self):
|
||||
"""Step 0 should be identical regardless of corrector setting."""
|
||||
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
|
||||
Reference in New Issue
Block a user