Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -12,14 +12,14 @@ class TestRoPE:
"""Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs = rope_params(1024, 64)
mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
for dim in [32, 64, 128]:
freqs = rope_params(512, dim)
@@ -27,7 +27,7 @@ class TestRoPE:
assert freqs.shape == (512, dim // 2, 2)
def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs = rope_params(256, 64)
mx.eval(freqs)
@@ -38,7 +38,7 @@ class TestRoPE:
def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs = rope_params(10, 64)
mx.eval(freqs)
@@ -46,7 +46,7 @@ class TestRoPE:
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
def test_rope_apply_output_shape(self):
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D))
@@ -58,7 +58,7 @@ class TestRoPE:
def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms."""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 3, 4
@@ -79,7 +79,7 @@ class TestRoPE:
def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 2, 2
@@ -100,7 +100,7 @@ class TestRoPE:
def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)]
@@ -132,7 +132,7 @@ class TestRoPE:
class TestWanRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -142,7 +142,7 @@ class TestWanRMSNorm:
def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0
@@ -156,7 +156,7 @@ class TestWanRMSNorm:
def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan2.attention import WanRMSNorm
norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
@@ -168,7 +168,7 @@ class TestWanRMSNorm:
class TestWanLayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan2.attention import WanLayerNorm
norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -177,7 +177,7 @@ class TestWanLayerNorm:
assert out.shape == (2, 10, 64)
def test_without_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan2.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64))
@@ -190,7 +190,7 @@ class TestWanLayerNorm:
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
def test_with_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan2.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight")
@@ -208,8 +208,8 @@ class TestWanSelfAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24
@@ -221,14 +221,14 @@ class TestWanSelfAttention:
assert out.shape == (B, L, self.dim)
def test_with_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None
assert attn.norm_k is not None
def test_without_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None
@@ -236,8 +236,8 @@ class TestWanSelfAttention:
def test_masking(self):
"""Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24
@@ -262,7 +262,7 @@ class TestWanCrossAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16
@@ -273,7 +273,7 @@ class TestWanCrossAttention:
assert out.shape == (B, L_q, self.dim)
def test_with_context_mask(self):
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16
@@ -311,8 +311,8 @@ class TestBFloat16Autocast:
def test_self_attn_casts_to_weight_dtype(self):
"""Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -326,7 +326,7 @@ class TestBFloat16Autocast:
def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -340,7 +340,7 @@ class TestBFloat16Autocast:
def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -353,7 +353,7 @@ class TestBFloat16Autocast:
def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan.transformer import WanFFN
from mlx_video.models.wan2.transformer import WanFFN
ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters()))
@@ -366,8 +366,8 @@ class TestBFloat16Autocast:
def test_self_attn_rope_in_float32(self):
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -381,8 +381,8 @@ class TestBFloat16Autocast:
def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters()))