Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -11,7 +11,7 @@ import numpy as np
|
||||
|
||||
class TestT5LayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5LayerNorm
|
||||
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
|
||||
|
||||
norm = T5LayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
@@ -21,7 +21,7 @@ class TestT5LayerNorm:
|
||||
|
||||
def test_rms_normalization(self):
|
||||
"""After T5LayerNorm with weight=1, RMS should be ~1."""
|
||||
from mlx_video.models.wan2.text_encoder import T5LayerNorm
|
||||
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
|
||||
|
||||
norm = T5LayerNorm(128)
|
||||
x = mx.random.normal((1, 5, 128)) * 5.0
|
||||
@@ -35,7 +35,7 @@ class TestT5LayerNorm:
|
||||
|
||||
class TestT5RelativeEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
|
||||
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(10, 10)
|
||||
@@ -43,7 +43,7 @@ class TestT5RelativeEmbedding:
|
||||
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
|
||||
|
||||
def test_asymmetric_lengths(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
|
||||
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(8, 12)
|
||||
@@ -52,7 +52,7 @@ class TestT5RelativeEmbedding:
|
||||
|
||||
def test_symmetry(self):
|
||||
"""Position bias should have structure (not all zeros/random)."""
|
||||
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
|
||||
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
|
||||
out = rel_emb(6, 6)
|
||||
@@ -67,7 +67,7 @@ class TestT5RelativeEmbedding:
|
||||
|
||||
class TestT5Attention:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Attention
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
@@ -77,14 +77,14 @@ class TestT5Attention:
|
||||
|
||||
def test_no_scaling(self):
|
||||
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
|
||||
from mlx_video.models.wan2.text_encoder import T5Attention
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
# No scale attribute (unlike standard attention)
|
||||
assert not hasattr(attn, "scale")
|
||||
|
||||
def test_with_position_bias(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Attention, T5RelativeEmbedding
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention, T5RelativeEmbedding
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
rel_emb = T5RelativeEmbedding(32, 4)
|
||||
@@ -95,7 +95,7 @@ class TestT5Attention:
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
def test_with_mask(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Attention
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
@@ -108,7 +108,7 @@ class TestT5Attention:
|
||||
|
||||
class TestT5FeedForward:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5FeedForward
|
||||
from mlx_video.models.wan_2.text_encoder import T5FeedForward
|
||||
|
||||
ffn = T5FeedForward(64, 256)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
@@ -118,7 +118,7 @@ class TestT5FeedForward:
|
||||
|
||||
def test_gated_structure(self):
|
||||
"""T5 FFN is gated: gate(x) * fc1(x)."""
|
||||
from mlx_video.models.wan2.text_encoder import T5FeedForward
|
||||
from mlx_video.models.wan_2.text_encoder import T5FeedForward
|
||||
|
||||
ffn = T5FeedForward(32, 64)
|
||||
assert hasattr(ffn, "gate_proj")
|
||||
@@ -131,7 +131,7 @@ class TestT5Encoder:
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Encoder
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
@@ -150,7 +150,7 @@ class TestT5Encoder:
|
||||
assert out.shape == (1, 5, 64)
|
||||
|
||||
def test_shared_pos(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Encoder
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
@@ -167,7 +167,7 @@ class TestT5Encoder:
|
||||
assert block.pos_embedding is None
|
||||
|
||||
def test_per_layer_pos(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Encoder
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
@@ -184,7 +184,7 @@ class TestT5Encoder:
|
||||
assert block.pos_embedding is not None
|
||||
|
||||
def test_param_count(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Encoder
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
@@ -200,7 +200,7 @@ class TestT5Encoder:
|
||||
assert num_params > 0
|
||||
|
||||
def test_without_mask(self):
|
||||
from mlx_video.models.wan2.text_encoder import T5Encoder
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
|
||||
Reference in New Issue
Block a user