Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -11,7 +11,7 @@ import numpy as np
class TestT5LayerNorm:
def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5LayerNorm
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -21,7 +21,7 @@ class TestT5LayerNorm:
def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan2.text_encoder import T5LayerNorm
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0
@@ -35,7 +35,7 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding:
def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10)
@@ -43,7 +43,7 @@ class TestT5RelativeEmbedding:
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
def test_asymmetric_lengths(self):
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12)
@@ -52,7 +52,7 @@ class TestT5RelativeEmbedding:
def test_symmetry(self):
"""Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6)
@@ -67,7 +67,7 @@ class TestT5RelativeEmbedding:
class TestT5Attention:
def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5Attention
from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
@@ -77,14 +77,14 @@ class TestT5Attention:
def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan2.text_encoder import T5Attention
from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale")
def test_with_position_bias(self):
from mlx_video.models.wan2.text_encoder import T5Attention, T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4)
@@ -95,7 +95,7 @@ class TestT5Attention:
assert out.shape == (1, 10, 64)
def test_with_mask(self):
from mlx_video.models.wan2.text_encoder import T5Attention
from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
@@ -108,7 +108,7 @@ class TestT5Attention:
class TestT5FeedForward:
def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5FeedForward
from mlx_video.models.wan_2.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64))
@@ -118,7 +118,7 @@ class TestT5FeedForward:
def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan2.text_encoder import T5FeedForward
from mlx_video.models.wan_2.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj")
@@ -131,7 +131,7 @@ class TestT5Encoder:
mx.random.seed(42)
def test_output_shape(self):
from mlx_video.models.wan2.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -150,7 +150,7 @@ class TestT5Encoder:
assert out.shape == (1, 5, 64)
def test_shared_pos(self):
from mlx_video.models.wan2.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -167,7 +167,7 @@ class TestT5Encoder:
assert block.pos_embedding is None
def test_per_layer_pos(self):
from mlx_video.models.wan2.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -184,7 +184,7 @@ class TestT5Encoder:
assert block.pos_embedding is not None
def test_param_count(self):
from mlx_video.models.wan2.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -200,7 +200,7 @@ class TestT5Encoder:
assert num_params > 0
def test_without_mask(self):
from mlx_video.models.wan2.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,