Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -11,7 +11,7 @@ import numpy as np
class TestT5LayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5LayerNorm
from mlx_video.models.wan2.text_encoder import T5LayerNorm
norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -21,7 +21,7 @@ class TestT5LayerNorm:
def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan.text_encoder import T5LayerNorm
from mlx_video.models.wan2.text_encoder import T5LayerNorm
norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0
@@ -35,7 +35,7 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10)
@@ -43,7 +43,7 @@ class TestT5RelativeEmbedding:
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
def test_asymmetric_lengths(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12)
@@ -52,7 +52,7 @@ class TestT5RelativeEmbedding:
def test_symmetry(self):
"""Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6)
@@ -67,7 +67,7 @@ class TestT5RelativeEmbedding:
class TestT5Attention:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
@@ -77,14 +77,14 @@ class TestT5Attention:
def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale")
def test_with_position_bias(self):
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4)
@@ -95,7 +95,7 @@ class TestT5Attention:
assert out.shape == (1, 10, 64)
def test_with_mask(self):
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
@@ -108,7 +108,7 @@ class TestT5Attention:
class TestT5FeedForward:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5FeedForward
from mlx_video.models.wan2.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64))
@@ -118,7 +118,7 @@ class TestT5FeedForward:
def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan.text_encoder import T5FeedForward
from mlx_video.models.wan2.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj")
@@ -131,7 +131,7 @@ class TestT5Encoder:
mx.random.seed(42)
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -150,7 +150,7 @@ class TestT5Encoder:
assert out.shape == (1, 5, 64)
def test_shared_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -167,7 +167,7 @@ class TestT5Encoder:
assert block.pos_embedding is None
def test_per_layer_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -184,7 +184,7 @@ class TestT5Encoder:
assert block.pos_embedding is not None
def test_param_count(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -200,7 +200,7 @@ class TestT5Encoder:
assert num_params > 0
def test_without_mask(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,