Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module. Test imports are updated accordingly, from `mlx_video.models.wan2` to `mlx_video.models.wan_2`.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -12,14 +12,14 @@ class TestRoPE:
"""Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self):
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(1024, 64)
mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self):
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
for dim in [32, 64, 128]:
freqs = rope_params(512, dim)
@@ -27,7 +27,7 @@ class TestRoPE:
assert freqs.shape == (512, dim // 2, 2)
def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(256, 64)
mx.eval(freqs)
@@ -38,7 +38,7 @@ class TestRoPE:
def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(10, 64)
mx.eval(freqs)
@@ -46,7 +46,7 @@ class TestRoPE:
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
def test_rope_apply_output_shape(self):
from mlx_video.models.wan2.rope import rope_apply, rope_params
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D))
@@ -58,7 +58,7 @@ class TestRoPE:
def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms."""
from mlx_video.models.wan2.rope import rope_apply, rope_params
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 3, 4
@@ -79,7 +79,7 @@ class TestRoPE:
def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan2.rope import rope_apply, rope_params
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 2, 2
@@ -100,7 +100,7 @@ class TestRoPE:
def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan2.rope import rope_apply, rope_params
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)]
@@ -132,7 +132,7 @@ class TestRoPE:
class TestWanRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan2.attention import WanRMSNorm
from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -142,7 +142,7 @@ class TestWanRMSNorm:
def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan2.attention import WanRMSNorm
from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0
@@ -156,7 +156,7 @@ class TestWanRMSNorm:
def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan2.attention import WanRMSNorm
from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
@@ -168,7 +168,7 @@ class TestWanRMSNorm:
class TestWanLayerNorm:
def test_output_shape(self):
from mlx_video.models.wan2.attention import WanLayerNorm
from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -177,7 +177,7 @@ class TestWanLayerNorm:
assert out.shape == (2, 10, 64)
def test_without_affine(self):
from mlx_video.models.wan2.attention import WanLayerNorm
from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64))
@@ -190,7 +190,7 @@ class TestWanLayerNorm:
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
def test_with_affine(self):
from mlx_video.models.wan2.attention import WanLayerNorm
from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight")
@@ -208,8 +208,8 @@ class TestWanSelfAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24
@@ -221,14 +221,14 @@ class TestWanSelfAttention:
assert out.shape == (B, L, self.dim)
def test_with_qk_norm(self):
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan_2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None
assert attn.norm_k is not None
def test_without_qk_norm(self):
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan_2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None
@@ -236,8 +236,8 @@ class TestWanSelfAttention:
def test_masking(self):
"""Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24
@@ -262,7 +262,7 @@ class TestWanCrossAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan2.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16
@@ -273,7 +273,7 @@ class TestWanCrossAttention:
assert out.shape == (B, L_q, self.dim)
def test_with_context_mask(self):
from mlx_video.models.wan2.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16
@@ -311,8 +311,8 @@ class TestBFloat16Autocast:
def test_self_attn_casts_to_weight_dtype(self):
"""Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -326,7 +326,7 @@ class TestBFloat16Autocast:
def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan2.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -340,7 +340,7 @@ class TestBFloat16Autocast:
def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan2.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -353,7 +353,7 @@ class TestBFloat16Autocast:
def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan2.transformer import WanFFN
from mlx_video.models.wan_2.transformer import WanFFN
ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters()))
@@ -366,8 +366,8 @@ class TestBFloat16Autocast:
def test_self_attn_rope_in_float32(self):
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -381,8 +381,8 @@ class TestBFloat16Autocast:
def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters()))