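Remove the old `wan2` model files (configuration, attention mechanisms, and utility functions) and update imports to reference the `wan_2` package instead (e.g. `mlx_video.models.wan2.wan2` becomes `mlx_video.models.wan_2.wan_2`). This cleanup streamlines the codebase, eliminates unused components, and improves maintainability by keeping the focus on the core functionality of the Wan2 module.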

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config
class TestSinusoidalEmbedding:
def test_output_shape(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.arange(10).astype(mx.float32)
emb = sinusoidal_embedding_1d(256, pos)
@@ -21,7 +21,7 @@ class TestSinusoidalEmbedding:
def test_position_zero(self):
"""Position 0 should have cos=1 for all dims and sin=0."""
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0])
emb = sinusoidal_embedding_1d(64, pos)
@@ -33,7 +33,7 @@ class TestSinusoidalEmbedding:
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
def test_different_positions_differ(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 999.0])
emb = sinusoidal_embedding_1d(128, pos)
@@ -50,7 +50,7 @@ class TestSinusoidalEmbedding:
class TestHead:
def test_output_shape(self):
from mlx_video.models.wan2.wan2 import Head
from mlx_video.models.wan_2.wan_2 import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
B, L = 1, 24
@@ -62,7 +62,7 @@ class TestHead:
assert out.shape == (B, L, expected_proj_dim)
def test_modulation_shape(self):
from mlx_video.models.wan2.wan2 import Head
from mlx_video.models.wan_2.wan_2 import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
assert head.modulation.shape == (1, 2, 64)
@@ -78,7 +78,7 @@ class TestWanModel:
mx.random.seed(42)
def test_instantiation(self):
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -86,7 +86,7 @@ class TestWanModel:
assert num_params > 0
def test_patchify_shape(self):
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -99,7 +99,7 @@ class TestWanModel:
assert patches.shape == (1, 1 * 2 * 2, config.dim)
def test_patchify_various_sizes(self):
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -115,7 +115,7 @@ class TestWanModel:
def test_unpatchify_inverse(self):
"""Patchify then unpatchify should reconstruct original spatial dims."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -131,7 +131,7 @@ class TestWanModel:
assert out[0].shape == (config.out_dim, F, H, W)
def test_forward_pass(self):
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -149,7 +149,7 @@ class TestWanModel:
assert out[0].shape == (C, F, H, W)
def test_forward_batch(self):
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -171,7 +171,7 @@ class TestWanModel:
assert o.shape == (C, F, H, W)
def test_output_is_float32(self):
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -200,7 +200,7 @@ class TestWan21Model:
def _make_tiny_wan21_config(self):
"""Create a tiny config mimicking Wan2.1 (single model)."""
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
# Override to tiny values
@@ -217,7 +217,7 @@ class TestWan21Model:
def _make_tiny_wan21_1_3b_config(self):
"""Create a tiny config mimicking Wan2.1 1.3B."""
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
# Override to tiny values (preserve 1.3B head structure: 12 heads)
@@ -234,7 +234,7 @@ class TestWan21Model:
def test_wan21_tiny_model_forward(self):
"""Forward pass with Wan2.1 tiny config."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = self._make_tiny_wan21_config()
model = WanModel(config)
@@ -252,7 +252,7 @@ class TestWan21Model:
def test_wan21_1_3b_tiny_model_forward(self):
"""Forward pass with Wan2.1 1.3B tiny config."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = self._make_tiny_wan21_1_3b_config()
model = WanModel(config)
@@ -270,8 +270,8 @@ class TestWan21Model:
def test_wan21_single_model_loop(self):
"""Full diffusion loop with single model (Wan2.1 style)."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
config = self._make_tiny_wan21_config()
model = WanModel(config)
@@ -305,7 +305,7 @@ class TestWan21Model:
def test_wan21_vs_wan22_config_differences(self):
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
c21 = WanModelConfig.wan21_t2v_14b()
c22 = WanModelConfig.wan22_t2v_14b()
@@ -333,21 +333,21 @@ class TestPerTokenTimestep:
"""Tests for per-token sinusoidal embedding."""
def test_1d_unchanged(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 500.0])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (3, 256)
def test_2d_per_token(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (2, 3, 256)
def test_consistency(self):
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos_1d = mx.array([0.0, 100.0])
emb_1d = sinusoidal_embedding_1d(256, pos_1d)