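Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module. (Note: in the diff shown below, the only visible change is that test imports move from `mlx_video.models.wan2` to `mlx_video.models.wan_2`, and the commit totals 354 additions against 354 deletions, so the change reads as a package rename rather than a net removal of code.)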

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -10,7 +10,7 @@ import numpy as np
class TestWanFFN:
def test_output_shape(self):
from mlx_video.models.wan2.transformer import WanFFN
from mlx_video.models.wan_2.transformer import WanFFN
ffn = WanFFN(64, 256)
x = mx.random.normal((2, 10, 64))
@@ -20,7 +20,7 @@ class TestWanFFN:
def test_gelu_activation(self):
"""FFN should use GELU activation (non-linearity)."""
from mlx_video.models.wan2.transformer import WanFFN
from mlx_video.models.wan_2.transformer import WanFFN
ffn = WanFFN(32, 128)
x = mx.ones((1, 1, 32)) * 2.0
@@ -40,8 +40,8 @@ class TestWanAttentionBlock:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -68,13 +68,13 @@ class TestWanAttentionBlock:
assert out.shape == (B, L, self.dim)
def test_modulation_shape(self):
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
assert block.modulation.shape == (1, 6, self.dim)
def test_with_cross_attn_norm(self):
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -85,7 +85,7 @@ class TestWanAttentionBlock:
assert block.norm3 is not None
def test_without_cross_attn_norm(self):
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -97,8 +97,8 @@ class TestWanAttentionBlock:
def test_residual_connection(self):
"""Output should differ from zero even with small random init."""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
B, L = 1, 8
@@ -129,15 +129,15 @@ class TestFloat32Modulation:
def test_block_modulation_in_float32(self):
"""Modulation param starts random but should be usable as float32."""
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
assert block.modulation.dtype == mx.float32
def test_block_output_float32_with_bf16_modulation_input(self):
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4)
B, L = 1, 8
@@ -153,7 +153,7 @@ class TestFloat32Modulation:
def test_head_modulation_float32(self):
"""Head modulation should be float32 even with bf16 e input."""
from mlx_video.models.wan2.wan2 import Head
from mlx_video.models.wan_2.wan_2 import Head
head = Head(self.dim, 4, (1, 2, 2))
x = mx.random.normal((1, 8, self.dim))
@@ -164,7 +164,7 @@ class TestFloat32Modulation:
def test_model_time_embedding_float32(self):
"""sinusoidal_embedding_1d output must be float32."""
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
t = mx.array([500.0])
emb = sinusoidal_embedding_1d(256, t)
@@ -173,7 +173,7 @@ class TestFloat32Modulation:
def test_model_per_token_time_embedding_float32(self):
"""Per-token time embeddings (I2V) should also be float32."""
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
emb = sinusoidal_embedding_1d(256, t)