Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -10,7 +10,7 @@ import numpy as np
class TestWanFFN:
def test_output_shape(self):
from mlx_video.models.wan.transformer import WanFFN
from mlx_video.models.wan2.transformer import WanFFN
ffn = WanFFN(64, 256)
x = mx.random.normal((2, 10, 64))
@@ -20,7 +20,7 @@ class TestWanFFN:
def test_gelu_activation(self):
"""FFN should use GELU activation (non-linearity)."""
from mlx_video.models.wan.transformer import WanFFN
from mlx_video.models.wan2.transformer import WanFFN
ffn = WanFFN(32, 128)
x = mx.ones((1, 1, 32)) * 2.0
@@ -40,8 +40,8 @@ class TestWanAttentionBlock:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -68,13 +68,13 @@ class TestWanAttentionBlock:
assert out.shape == (B, L, self.dim)
def test_modulation_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
assert block.modulation.shape == (1, 6, self.dim)
def test_with_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -85,7 +85,7 @@ class TestWanAttentionBlock:
assert block.norm3 is not None
def test_without_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -97,8 +97,8 @@ class TestWanAttentionBlock:
def test_residual_connection(self):
"""Output should differ from zero even with small random init."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
B, L = 1, 8
@@ -129,15 +129,15 @@ class TestFloat32Modulation:
def test_block_modulation_in_float32(self):
"""Modulation param starts random but should be usable as float32."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
assert block.modulation.dtype == mx.float32
def test_block_output_float32_with_bf16_modulation_input(self):
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4)
B, L = 1, 8
@@ -153,7 +153,7 @@ class TestFloat32Modulation:
def test_head_modulation_float32(self):
"""Head modulation should be float32 even with bf16 e input."""
from mlx_video.models.wan.model import Head
from mlx_video.models.wan2.model import Head
head = Head(self.dim, 4, (1, 2, 2))
x = mx.random.normal((1, 8, self.dim))
@@ -164,7 +164,7 @@ class TestFloat32Modulation:
def test_model_time_embedding_float32(self):
"""sinusoidal_embedding_1d output must be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
t = mx.array([500.0])
emb = sinusoidal_embedding_1d(256, t)
@@ -173,7 +173,7 @@ class TestFloat32Modulation:
def test_model_per_token_time_embedding_float32(self):
"""Per-token time embeddings (I2V) should also be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
emb = sinusoidal_embedding_1d(256, t)