This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -3,16 +3,16 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# T5 Encoder Tests
# ---------------------------------------------------------------------------
class TestT5LayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
@@ -22,6 +22,7 @@ class TestT5LayerNorm:
def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0
out = norm(x)
@@ -35,6 +36,7 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10)
mx.eval(out)
@@ -42,6 +44,7 @@ class TestT5RelativeEmbedding:
def test_asymmetric_lengths(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12)
mx.eval(out)
@@ -50,6 +53,7 @@ class TestT5RelativeEmbedding:
def test_symmetry(self):
"""Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6)
mx.eval(out)
@@ -64,6 +68,7 @@ class TestT5RelativeEmbedding:
class TestT5Attention:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
out = attn(x)
@@ -73,12 +78,14 @@ class TestT5Attention:
def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale")
def test_with_position_bias(self):
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4)
x = mx.random.normal((1, 10, 64))
@@ -89,6 +96,7 @@ class TestT5Attention:
def test_with_mask(self):
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
mask = mx.ones((1, 10))
@@ -101,6 +109,7 @@ class TestT5Attention:
class TestT5FeedForward:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64))
out = ffn(x)
@@ -110,6 +119,7 @@ class TestT5FeedForward:
def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj")
assert hasattr(ffn, "fc1")
@@ -122,9 +132,16 @@ class TestT5Encoder:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
ids = mx.array([[1, 5, 10, 0, 0]])
mask = mx.array([[1, 1, 1, 0, 0]])
@@ -134,9 +151,16 @@ class TestT5Encoder:
def test_shared_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=True,
)
assert encoder.pos_embedding is not None
for block in encoder.blocks:
@@ -144,9 +168,16 @@ class TestT5Encoder:
def test_per_layer_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
assert encoder.pos_embedding is None
for block in encoder.blocks:
@@ -154,18 +185,32 @@ class TestT5Encoder:
def test_param_count(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
assert num_params > 0
def test_without_mask(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
ids = mx.array([[1, 5, 10]])
out = encoder(ids)