Remove the Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup improves maintainability and keeps the focus on the core functionality of the Wan2 module.
This commit is contained in:
@@ -12,14 +12,14 @@ class TestRoPE:
|
||||
"""Tests for 3-way factorized RoPE."""
|
||||
|
||||
def test_rope_params_shape(self):
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs = rope_params(1024, 64)
|
||||
mx.eval(freqs)
|
||||
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
||||
|
||||
def test_rope_params_different_dims(self):
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
for dim in [32, 64, 128]:
|
||||
freqs = rope_params(512, dim)
|
||||
@@ -27,7 +27,7 @@ class TestRoPE:
|
||||
assert freqs.shape == (512, dim // 2, 2)
|
||||
|
||||
def test_rope_params_cos_sin_range(self):
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs = rope_params(256, 64)
|
||||
mx.eval(freqs)
|
||||
@@ -38,7 +38,7 @@ class TestRoPE:
|
||||
|
||||
def test_rope_params_position_zero(self):
|
||||
"""At position 0, cos should be 1 and sin should be 0."""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs = rope_params(10, 64)
|
||||
mx.eval(freqs)
|
||||
@@ -46,7 +46,7 @@ class TestRoPE:
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
||||
|
||||
def test_rope_apply_output_shape(self):
|
||||
from mlx_video.models.wan2.rope import rope_apply, rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
@@ -58,7 +58,7 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_preserves_norm(self):
|
||||
"""RoPE rotation should preserve vector norms."""
|
||||
from mlx_video.models.wan2.rope import rope_apply, rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 3, 4
|
||||
@@ -79,7 +79,7 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_with_padding(self):
|
||||
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
||||
from mlx_video.models.wan2.rope import rope_apply, rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 2, 2
|
||||
@@ -100,7 +100,7 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_batch(self):
|
||||
"""Test with batch_size > 1 and different grid sizes."""
|
||||
from mlx_video.models.wan2.rope import rope_apply, rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 2, 2, 16
|
||||
grids = [(2, 3, 4), (2, 3, 4)]
|
||||
@@ -132,7 +132,7 @@ class TestRoPE:
|
||||
|
||||
class TestWanRMSNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.attention import WanRMSNorm
|
||||
from mlx_video.models.wan_2.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
@@ -142,7 +142,7 @@ class TestWanRMSNorm:
|
||||
|
||||
def test_zero_mean_variance(self):
|
||||
"""RMS norm should make RMS ≈ 1 before scaling."""
|
||||
from mlx_video.models.wan2.attention import WanRMSNorm
|
||||
from mlx_video.models.wan_2.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((1, 5, 64)) * 10.0
|
||||
@@ -156,7 +156,7 @@ class TestWanRMSNorm:
|
||||
|
||||
def test_dtype_preservation(self):
|
||||
"""RMSNorm weight is float32, so output is promoted to float32."""
|
||||
from mlx_video.models.wan2.attention import WanRMSNorm
|
||||
from mlx_video.models.wan_2.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(32)
|
||||
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
||||
@@ -168,7 +168,7 @@ class TestWanRMSNorm:
|
||||
|
||||
class TestWanLayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.attention import WanLayerNorm
|
||||
from mlx_video.models.wan_2.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
@@ -177,7 +177,7 @@ class TestWanLayerNorm:
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_without_affine(self):
|
||||
from mlx_video.models.wan2.attention import WanLayerNorm
|
||||
from mlx_video.models.wan_2.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64, elementwise_affine=False)
|
||||
x = mx.random.normal((1, 4, 64))
|
||||
@@ -190,7 +190,7 @@ class TestWanLayerNorm:
|
||||
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
|
||||
|
||||
def test_with_affine(self):
|
||||
from mlx_video.models.wan2.attention import WanLayerNorm
|
||||
from mlx_video.models.wan_2.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(32, elementwise_affine=True)
|
||||
assert hasattr(norm, "weight")
|
||||
@@ -208,8 +208,8 @@ class TestWanSelfAttention:
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
B, L = 1, 24
|
||||
@@ -221,14 +221,14 @@ class TestWanSelfAttention:
|
||||
assert out.shape == (B, L, self.dim)
|
||||
|
||||
def test_with_qk_norm(self):
|
||||
from mlx_video.models.wan2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
||||
assert attn.norm_q is not None
|
||||
assert attn.norm_k is not None
|
||||
|
||||
def test_without_qk_norm(self):
|
||||
from mlx_video.models.wan2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
assert attn.norm_q is None
|
||||
@@ -236,8 +236,8 @@ class TestWanSelfAttention:
|
||||
|
||||
def test_masking(self):
|
||||
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
||||
from mlx_video.models.wan2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
B, L = 1, 24
|
||||
@@ -262,7 +262,7 @@ class TestWanCrossAttention:
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.attention import WanCrossAttention
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 24, 16
|
||||
@@ -273,7 +273,7 @@ class TestWanCrossAttention:
|
||||
assert out.shape == (B, L_q, self.dim)
|
||||
|
||||
def test_with_context_mask(self):
|
||||
from mlx_video.models.wan2.attention import WanCrossAttention
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 12, 16
|
||||
@@ -311,8 +311,8 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_self_attn_casts_to_weight_dtype(self):
|
||||
"""Self-attention should cast input to weight dtype for QKV projections."""
|
||||
from mlx_video.models.wan2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
@@ -326,7 +326,7 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_cross_attn_casts_to_weight_dtype(self):
|
||||
"""Cross-attention should cast input to weight dtype."""
|
||||
from mlx_video.models.wan2.attention import WanCrossAttention
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
@@ -340,7 +340,7 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
||||
"""prepare_kv should cast context to weight dtype."""
|
||||
from mlx_video.models.wan2.attention import WanCrossAttention
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
@@ -353,7 +353,7 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_ffn_casts_to_weight_dtype(self):
|
||||
"""FFN should cast input to weight dtype for linear layers."""
|
||||
from mlx_video.models.wan2.transformer import WanFFN
|
||||
from mlx_video.models.wan_2.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(self.dim, 128)
|
||||
ffn.update(self._to_bf16(ffn.parameters()))
|
||||
@@ -366,8 +366,8 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_self_attn_rope_in_float32(self):
|
||||
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
||||
from mlx_video.models.wan2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
@@ -381,8 +381,8 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_block_float32_residual_with_bf16_weights(self):
|
||||
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
||||
block.update(self._to_bf16(block.parameters()))
|
||||
|
||||
Reference in New Issue
Block a user