1454 lines
54 KiB
Python
1454 lines
54 KiB
Python
"""Comprehensive tests for Wan2.2 model components.

All tests use small/tiny configurations to avoid needing actual weights.
"""
|
|
|
|
import math
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Config Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestWanModelConfig:
    """Tests for WanModelConfig dataclass."""

    def test_default_values(self):
        """A default-constructed config carries the expected field defaults."""
        from mlx_video.models.wan.config import WanModelConfig

        config = WanModelConfig()
        assert config.dim == 5120
        assert config.ffn_dim == 13824
        assert config.num_heads == 40
        assert config.num_layers == 40
        assert config.in_dim == 16
        assert config.out_dim == 16
        assert config.patch_size == (1, 2, 2)
        assert config.vae_stride == (4, 8, 8)
        assert config.vae_z_dim == 16
        assert config.boundary == 0.875
        assert config.sample_shift == 12.0
        assert config.sample_steps == 40
        assert config.sample_guide_scale == (3.0, 4.0)
        assert config.num_train_timesteps == 1000
        assert config.qk_norm is True
        assert config.cross_attn_norm is True
        assert config.text_len == 512

    def test_head_dim_property(self):
        """head_dim of the default config equals dim // num_heads."""
        from mlx_video.models.wan.config import WanModelConfig

        config = WanModelConfig()
        assert config.head_dim == 128  # 5120 // 40

    def test_to_dict_roundtrip(self):
        """to_dict() returns a plain dict exposing the config values.

        NOTE: only the config -> dict direction is exercised here; nothing is
        rebuilt from the dict.
        """
        from mlx_video.models.wan.config import WanModelConfig

        config = WanModelConfig()
        d = config.to_dict()
        assert isinstance(d, dict)
        assert d["dim"] == 5120
        assert d["patch_size"] == (1, 2, 2)
        assert d["boundary"] == 0.875

    def test_t5_config_values(self):
        """The default config carries the expected T5 text-encoder settings."""
        from mlx_video.models.wan.config import WanModelConfig

        config = WanModelConfig()
        assert config.t5_vocab_size == 256384
        assert config.t5_dim == 4096
        assert config.t5_dim_attn == 4096
        assert config.t5_dim_ffn == 10240
        assert config.t5_num_heads == 64
        assert config.t5_num_layers == 24
        assert config.t5_num_buckets == 32
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# RoPE Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestRoPE:
    """Tests for 3-way factorized RoPE."""

    def test_rope_params_shape(self):
        """rope_params(max_seq_len, dim) has shape (max_seq_len, dim // 2, 2)."""
        from mlx_video.models.wan.rope import rope_params

        freqs = rope_params(1024, 64)
        mx.eval(freqs)
        assert freqs.shape == (1024, 32, 2)  # [max_seq_len, dim//2, 2]

    def test_rope_params_different_dims(self):
        """The (seq, dim // 2, 2) shape rule holds for several dims."""
        from mlx_video.models.wan.rope import rope_params

        for dim in [32, 64, 128]:
            freqs = rope_params(512, dim)
            mx.eval(freqs)
            assert freqs.shape == (512, dim // 2, 2)

    def test_rope_params_cos_sin_range(self):
        """Both channels of the last axis stay within [-1, 1]."""
        from mlx_video.models.wan.rope import rope_params

        freqs = rope_params(256, 64)
        mx.eval(freqs)
        cos_vals = np.array(freqs[:, :, 0])
        sin_vals = np.array(freqs[:, :, 1])
        assert np.all(cos_vals >= -1.0) and np.all(cos_vals <= 1.0)
        assert np.all(sin_vals >= -1.0) and np.all(sin_vals <= 1.0)

    def test_rope_params_position_zero(self):
        """At position 0, cos should be 1 and sin should be 0."""
        from mlx_video.models.wan.rope import rope_params

        freqs = rope_params(10, 64)
        mx.eval(freqs)
        np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
        np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)

    def test_rope_apply_output_shape(self):
        """rope_apply returns an array with the same shape as its input."""
        from mlx_video.models.wan.rope import rope_params, rope_apply

        B, L, N, D = 1, 24, 4, 32  # batch, seq, heads, head_dim
        x = mx.random.normal((B, L, N, D))
        freqs = rope_params(1024, D)
        grid_sizes = [(2, 3, 4)]  # F*H*W = 24 = L
        out = rope_apply(x, grid_sizes, freqs)
        mx.eval(out)
        assert out.shape == (B, L, N, D)

    def test_rope_apply_preserves_norm(self):
        """RoPE rotation should preserve vector norms."""
        from mlx_video.models.wan.rope import rope_params, rope_apply

        B, N, D = 1, 2, 16
        F, H, W = 2, 3, 4
        L = F * H * W
        x = mx.random.normal((B, L, N, D))
        freqs = rope_params(1024, D)

        out = rope_apply(x, [(F, H, W)], freqs)
        mx.eval(x, out)

        # Compare the L2 norm of every (token, head) vector in a single call
        # instead of looping over tokens and heads.
        norm_in = np.linalg.norm(np.array(x[0]), axis=-1)
        norm_out = np.linalg.norm(np.array(out[0]), axis=-1)
        np.testing.assert_allclose(norm_in, norm_out, rtol=1e-4)

    def test_rope_apply_with_padding(self):
        """When seq_len < L, extra tokens should be preserved unchanged."""
        from mlx_video.models.wan.rope import rope_params, rope_apply

        B, N, D = 1, 2, 16
        F, H, W = 2, 2, 2
        seq_len = F * H * W  # 8
        pad = 4
        L = seq_len + pad
        x = mx.random.normal((B, L, N, D))
        freqs = rope_params(1024, D)

        out = rope_apply(x, [(F, H, W)], freqs)
        mx.eval(x, out)
        # Padded tokens should be unchanged
        np.testing.assert_allclose(
            np.array(out[0, seq_len:]),
            np.array(x[0, seq_len:]),
            atol=1e-6,
        )

    def test_rope_apply_batch(self):
        """Test with batch_size > 1 (one grid entry per batch element).

        Both batch elements use the same grid here; only the output shape is
        checked.
        """
        from mlx_video.models.wan.rope import rope_params, rope_apply

        B, N, D = 2, 2, 16
        grids = [(2, 3, 4), (2, 3, 4)]
        L = 2 * 3 * 4
        x = mx.random.normal((B, L, N, D))
        freqs = rope_params(1024, D)

        out = rope_apply(x, grids, freqs)
        mx.eval(out)
        assert out.shape == (B, L, N, D)

    def test_rope_frequency_split(self):
        """Verify the 3-way frequency dimension split matches Wan2.2 convention.

        This checks the split arithmetic only; it does not call into the
        library.
        """
        D = 128  # head_dim for 14B model
        half_d = D // 2
        d_t = half_d - 2 * (half_d // 3)
        d_h = half_d // 3
        d_w = half_d // 3
        assert d_t + d_h + d_w == half_d
        # Temporal gets more capacity
        assert d_t >= d_h
        assert d_t >= d_w
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Attention Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestWanRMSNorm:
    """Tests for WanRMSNorm."""

    def test_output_shape(self):
        """The norm returns an array with the same shape as its input."""
        from mlx_video.models.wan.attention import WanRMSNorm

        norm = WanRMSNorm(64)
        x = mx.random.normal((2, 10, 64))
        out = norm(x)
        mx.eval(out)
        assert out.shape == (2, 10, 64)

    def test_zero_mean_variance(self):
        """RMS norm should make RMS ≈ 1 before scaling."""
        from mlx_video.models.wan.attention import WanRMSNorm

        norm = WanRMSNorm(64)
        x = mx.random.normal((1, 5, 64)) * 10.0
        out = norm(x)
        mx.eval(out)
        out_np = np.array(out[0])
        # After RMS norm with weight=1, RMS should be ~1 for every row.
        rms = np.sqrt(np.mean(out_np ** 2, axis=-1))
        np.testing.assert_allclose(rms, 1.0, rtol=0.1)

    def test_dtype_preservation(self):
        """RMSNorm weight is float32, so output is promoted to float32."""
        from mlx_video.models.wan.attention import WanRMSNorm

        norm = WanRMSNorm(32)
        x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
        out = norm(x)
        mx.eval(out)
        # Weight is float32, so multiplication promotes result to float32
        assert out.dtype == mx.float32
|
|
|
|
|
|
class TestWanLayerNorm:
    """Tests for WanLayerNorm."""

    def test_output_shape(self):
        """The norm returns an array with the same shape as its input."""
        from mlx_video.models.wan.attention import WanLayerNorm

        norm = WanLayerNorm(64)
        x = mx.random.normal((2, 10, 64))
        out = norm(x)
        mx.eval(out)
        assert out.shape == (2, 10, 64)

    def test_without_affine(self):
        """Without affine parameters each row has mean ~0 and std ~1."""
        from mlx_video.models.wan.attention import WanLayerNorm

        norm = WanLayerNorm(64, elementwise_affine=False)
        x = mx.random.normal((1, 4, 64))
        out = norm(x)
        mx.eval(out)
        # Mean should be ~0, variance should be ~1 (checked per row).
        out_np = np.array(out[0])
        np.testing.assert_allclose(np.mean(out_np, axis=-1), 0.0, atol=0.05)
        np.testing.assert_allclose(np.std(out_np, axis=-1), 1.0, rtol=0.1)

    def test_with_affine(self):
        """With affine enabled the module has weight and bias and keeps shape."""
        from mlx_video.models.wan.attention import WanLayerNorm

        norm = WanLayerNorm(32, elementwise_affine=True)
        assert hasattr(norm, "weight")
        assert hasattr(norm, "bias")
        x = mx.random.normal((1, 4, 32))
        out = norm(x)
        mx.eval(out)
        assert out.shape == (1, 4, 32)
|
|
|
|
|
|
class TestWanSelfAttention:
    """Tests for WanSelfAttention."""

    def setup_method(self):
        """Seed the RNG and fix the small model dimensions used by each test."""
        mx.random.seed(42)
        self.dim = 64
        self.num_heads = 4

    def test_output_shape(self):
        """Self-attention returns an array shaped like its input (B, L, dim)."""
        from mlx_video.models.wan.attention import WanSelfAttention
        from mlx_video.models.wan.rope import rope_params

        attn = WanSelfAttention(self.dim, self.num_heads)
        B, L = 1, 24
        F, H, W = 2, 3, 4
        x = mx.random.normal((B, L, self.dim))
        freqs = rope_params(1024, self.dim // self.num_heads)
        out = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
        mx.eval(out)
        assert out.shape == (B, L, self.dim)

    def test_with_qk_norm(self):
        """qk_norm=True creates both the query and key norm modules."""
        from mlx_video.models.wan.attention import WanSelfAttention

        attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
        assert attn.norm_q is not None
        assert attn.norm_k is not None

    def test_without_qk_norm(self):
        """qk_norm=False leaves both norm attributes as None."""
        from mlx_video.models.wan.attention import WanSelfAttention

        attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
        assert attn.norm_q is None
        assert attn.norm_k is None

    def test_masking(self):
        """Test that masking works: shorter seq_lens should mask later tokens."""
        from mlx_video.models.wan.attention import WanSelfAttention
        from mlx_video.models.wan.rope import rope_params

        attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
        B, L = 1, 24
        F, H, W = 2, 3, 4
        x = mx.random.normal((B, L, self.dim))
        freqs = rope_params(1024, self.dim // self.num_heads)

        # Full sequence
        out_full = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
        # Shorter sequence (mask last 4 tokens)
        out_masked = attn(x, seq_lens=[L - 4], grid_sizes=[(F, H, W)], freqs=freqs)
        mx.eval(out_full, out_masked)

        # Outputs should differ when masking is applied
        assert not np.allclose(np.array(out_full), np.array(out_masked), atol=1e-5)
|
|
|
|
|
|
class TestWanCrossAttention:
    """Tests for WanCrossAttention."""

    def setup_method(self):
        """Seed the RNG and fix the small model dimensions used by each test."""
        mx.random.seed(42)
        self.dim = 64
        self.num_heads = 4

    def test_output_shape(self):
        """Output follows the query length, not the context length."""
        from mlx_video.models.wan.attention import WanCrossAttention

        attn = WanCrossAttention(self.dim, self.num_heads)
        B, L_q, L_kv = 1, 24, 16
        x = mx.random.normal((B, L_q, self.dim))
        context = mx.random.normal((B, L_kv, self.dim))
        out = attn(x, context)
        mx.eval(out)
        assert out.shape == (B, L_q, self.dim)

    def test_with_context_mask(self):
        """Passing context_lens shorter than the context keeps the output shape."""
        from mlx_video.models.wan.attention import WanCrossAttention

        attn = WanCrossAttention(self.dim, self.num_heads)
        B, L_q, L_kv = 1, 12, 16
        x = mx.random.normal((B, L_q, self.dim))
        context = mx.random.normal((B, L_kv, self.dim))
        out = attn(x, context, context_lens=[10])
        mx.eval(out)
        assert out.shape == (B, L_q, self.dim)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Transformer Block Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestWanFFN:
    """Tests for WanFFN."""

    def test_output_shape(self):
        """The FFN maps (B, L, dim) back to (B, L, dim)."""
        from mlx_video.models.wan.transformer import WanFFN

        ffn = WanFFN(64, 256)
        x = mx.random.normal((2, 10, 64))
        out = ffn(x)
        mx.eval(out)
        assert out.shape == (2, 10, 64)

    def test_gelu_activation(self):
        """FFN should use GELU activation (non-linearity)."""
        from mlx_video.models.wan.transformer import WanFFN

        ffn = WanFFN(32, 128)
        x = mx.ones((1, 1, 32)) * 2.0
        out1 = ffn(x)
        x2 = mx.ones((1, 1, 32)) * 4.0
        out2 = ffn(x2)
        mx.eval(out1, out2)
        # Non-linear: 2x input should not give 2x output
        assert not np.allclose(np.array(out2), np.array(out1) * 2.0, rtol=0.1)
|
|
|
|
|
|
class TestWanAttentionBlock:
    """Tests for WanAttentionBlock."""

    def setup_method(self):
        """Seed the RNG and fix the small block dimensions used by each test."""
        mx.random.seed(42)
        self.dim = 64
        self.ffn_dim = 128
        self.num_heads = 4

    def test_output_shape(self):
        """The block returns an array shaped like its input (B, L, dim)."""
        from mlx_video.models.wan.transformer import WanAttentionBlock
        from mlx_video.models.wan.rope import rope_params

        block = WanAttentionBlock(
            self.dim, self.ffn_dim, self.num_heads,
            cross_attn_norm=True,
        )
        B, L = 1, 24
        F, H, W = 2, 3, 4
        x = mx.random.normal((B, L, self.dim))
        e = mx.random.normal((B, L, 6, self.dim))
        context = mx.random.normal((B, 16, self.dim))
        freqs = rope_params(1024, self.dim // self.num_heads)

        out = block(
            x, e, seq_lens=[L], grid_sizes=[(F, H, W)],
            freqs=freqs, context=context,
        )
        mx.eval(out)
        assert out.shape == (B, L, self.dim)

    def test_modulation_shape(self):
        """The block's modulation parameter has shape (1, 6, dim)."""
        from mlx_video.models.wan.transformer import WanAttentionBlock

        block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
        assert block.modulation.shape == (1, 6, self.dim)

    def test_with_cross_attn_norm(self):
        """cross_attn_norm=True creates norm3."""
        from mlx_video.models.wan.transformer import WanAttentionBlock

        block = WanAttentionBlock(
            self.dim, self.ffn_dim, self.num_heads,
            cross_attn_norm=True,
        )
        assert block.norm3 is not None

    def test_without_cross_attn_norm(self):
        """cross_attn_norm=False leaves norm3 as None."""
        from mlx_video.models.wan.transformer import WanAttentionBlock

        block = WanAttentionBlock(
            self.dim, self.ffn_dim, self.num_heads,
            cross_attn_norm=False,
        )
        assert block.norm3 is None

    def test_residual_connection(self):
        """Output should differ from zero even with small random init."""
        from mlx_video.models.wan.transformer import WanAttentionBlock
        from mlx_video.models.wan.rope import rope_params

        block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
        B, L = 1, 8
        F, H, W = 2, 2, 2
        x = mx.ones((B, L, self.dim))
        e = mx.zeros((B, L, 6, self.dim))
        context = mx.random.normal((B, 4, self.dim))
        freqs = rope_params(1024, self.dim // self.num_heads)

        out = block(x, e, [L], [(F, H, W)], freqs, context)
        mx.eval(out)
        # Only checks that the output is not (near) all-zero for an all-ones
        # input; it does not measure how close the output stays to the input.
        assert not np.allclose(np.array(out), 0.0, atol=1e-3)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Sinusoidal Embedding Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestSinusoidalEmbedding:
    """Tests for sinusoidal_embedding_1d."""

    def test_output_shape(self):
        """Embedding N positions into `dim` channels yields shape (N, dim)."""
        from mlx_video.models.wan.model import sinusoidal_embedding_1d

        pos = mx.arange(10).astype(mx.float32)
        emb = sinusoidal_embedding_1d(256, pos)
        mx.eval(emb)
        assert emb.shape == (10, 256)

    def test_position_zero(self):
        """Position 0 should have cos=1 for all dims and sin=0."""
        from mlx_video.models.wan.model import sinusoidal_embedding_1d

        pos = mx.array([0.0])
        emb = sinusoidal_embedding_1d(64, pos)
        mx.eval(emb)
        emb_np = np.array(emb[0])
        # First half is cos, should be 1 at position 0
        np.testing.assert_allclose(emb_np[:32], 1.0, atol=1e-5)
        # Second half is sin, should be 0 at position 0
        np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)

    def test_different_positions_differ(self):
        """Distinct positions produce distinct embedding rows."""
        from mlx_video.models.wan.model import sinusoidal_embedding_1d

        pos = mx.array([0.0, 100.0, 999.0])
        emb = sinusoidal_embedding_1d(128, pos)
        mx.eval(emb)
        emb_np = np.array(emb)
        assert not np.allclose(emb_np[0], emb_np[1])
        assert not np.allclose(emb_np[1], emb_np[2])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Head Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestHead:
    """Tests for the output Head module."""

    def test_output_shape(self):
        """The head projects dim to out_dim * prod(patch_size) per token."""
        from mlx_video.models.wan.model import Head

        head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
        B, L = 1, 24
        x = mx.random.normal((B, L, 64))
        e = mx.random.normal((B, 64))  # time embedding: [B, dim]
        out = head(x, e)
        mx.eval(out)
        expected_proj_dim = 16 * 1 * 2 * 2  # 64
        assert out.shape == (B, L, expected_proj_dim)

    def test_modulation_shape(self):
        """The head's modulation parameter has shape (1, 2, dim)."""
        from mlx_video.models.wan.model import Head

        head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
        assert head.modulation.shape == (1, 2, 64)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# WanModel (Tiny) Tests
# ---------------------------------------------------------------------------
|
|
|
|
def _make_tiny_config():
    """Create a tiny WanModelConfig for testing.

    Returns:
        A default WanModelConfig whose size-related attributes have been
        overwritten with small values.
    """
    from mlx_video.models.wan.config import WanModelConfig

    config = WanModelConfig()
    # Override to tiny values
    config.dim = 64
    config.ffn_dim = 128
    config.num_heads = 4
    config.num_layers = 2
    config.in_dim = 4
    config.out_dim = 4
    config.patch_size = (1, 2, 2)
    config.freq_dim = 32
    config.text_dim = 32
    config.text_len = 8
    return config
|
|
|
|
|
|
class TestWanModel:
    """Tests for WanModel built from the tiny config."""

    def setup_method(self):
        """Seed the RNG so random inputs are reproducible."""
        mx.random.seed(42)

    def test_instantiation(self):
        """The tiny model builds and owns at least one parameter."""
        # tree_flatten is documented under mlx.utils; import it from there
        # rather than reaching through mlx.nn.utils.
        from mlx.utils import tree_flatten
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        model = WanModel(config)
        num_params = sum(p.size for _, p in tree_flatten(model.parameters()))
        assert num_params > 0

    def test_patchify_shape(self):
        """_patchify returns (1, F'*H'*W', dim) patches and the patch grid."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        model = WanModel(config)
        # Input: [C=4, F=1, H=4, W=4]
        x = mx.random.normal((4, 1, 4, 4))
        patches, grid_size = model._patchify(x)
        mx.eval(patches)
        # Patch size (1,2,2): F'=1, H'=2, W'=2
        assert grid_size == (1, 2, 2)
        assert patches.shape == (1, 1 * 2 * 2, config.dim)

    def test_patchify_various_sizes(self):
        """Grid size and token count follow input size // patch size."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        model = WanModel(config)
        pt, ph, pw = config.patch_size
        for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
            x = mx.random.normal((config.in_dim, f, h, w))
            patches, (gf, gh, gw) = model._patchify(x)
            mx.eval(patches)
            assert gf == f // pt
            assert gh == h // ph
            assert gw == w // pw
            assert patches.shape[1] == gf * gh * gw

    def test_unpatchify_inverse(self):
        """Patchify then unpatchify should reconstruct original spatial dims."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        model = WanModel(config)
        F, H, W = 2, 4, 6
        pt, ph, pw = config.patch_size
        F_out, H_out, W_out = F // pt, H // ph, W // pw
        L = F_out * H_out * W_out
        proj_dim = config.out_dim * pt * ph * pw
        # Simulated head output
        x = mx.random.normal((1, L, proj_dim))
        out = model.unpatchify(x, [(F_out, H_out, W_out)])
        mx.eval(out[0])
        assert out[0].shape == (config.out_dim, F, H, W)

    def test_forward_pass(self):
        """A single-sample forward pass returns one output shaped like the input."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        model = WanModel(config)
        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x_list = [mx.random.normal((C, F, H, W))]
        t = mx.array([500.0])
        context = [mx.random.normal((6, config.text_dim))]

        out = model(x_list, t, context, seq_len)
        mx.eval(out[0])
        assert len(out) == 1
        assert out[0].shape == (C, F, H, W)

    def test_forward_batch(self):
        """A two-sample forward pass returns two outputs shaped like the inputs."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        model = WanModel(config)
        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
        t = mx.array([500.0, 200.0])
        # The two context entries deliberately have different lengths (6 and 4).
        context = [
            mx.random.normal((6, config.text_dim)),
            mx.random.normal((4, config.text_dim)),
        ]

        out = model(x_list, t, context, seq_len)
        mx.eval(out[0], out[1])
        assert len(out) == 2
        for o in out:
            assert o.shape == (C, F, H, W)

    def test_output_is_float32(self):
        """The forward pass output dtype is float32."""
        from mlx_video.models.wan.model import WanModel

        config = _make_tiny_config()
        model = WanModel(config)
        C, F, H, W = config.in_dim, 1, 4, 4
        # Derive the token count from the config, as the other forward tests
        # do, instead of hard-coding the (1, 2, 2) patch size.
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)
        out = model(
            [mx.random.normal((C, F, H, W))],
            mx.array([100.0]),
            [mx.random.normal((4, config.text_dim))],
            seq_len,
        )
        mx.eval(out[0])
        assert out[0].dtype == mx.float32
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# T5 Encoder Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestT5LayerNorm:
    """Tests for T5LayerNorm."""

    def test_output_shape(self):
        """The norm returns an array with the same shape as its input."""
        from mlx_video.models.wan.text_encoder import T5LayerNorm

        norm = T5LayerNorm(64)
        x = mx.random.normal((2, 10, 64))
        out = norm(x)
        mx.eval(out)
        assert out.shape == (2, 10, 64)

    def test_rms_normalization(self):
        """After T5LayerNorm with weight=1, RMS should be ~1."""
        from mlx_video.models.wan.text_encoder import T5LayerNorm

        norm = T5LayerNorm(128)
        x = mx.random.normal((1, 5, 128)) * 5.0
        out = norm(x)
        mx.eval(out)
        out_np = np.array(out[0])
        # Per-row RMS, computed for all rows at once.
        rms = np.sqrt(np.mean(out_np ** 2, axis=-1))
        np.testing.assert_allclose(rms, 1.0, rtol=0.1)
|
|
|
|
|
|
class TestT5RelativeEmbedding:
    """Tests for T5RelativeEmbedding."""

    def test_output_shape(self):
        """Calling with (lq, lk) yields a bias of shape (1, num_heads, lq, lk)."""
        from mlx_video.models.wan.text_encoder import T5RelativeEmbedding

        rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
        out = rel_emb(10, 10)
        mx.eval(out)
        assert out.shape == (1, 4, 10, 10)  # [1, N, lq, lk]

    def test_asymmetric_lengths(self):
        """Query and key lengths may differ."""
        from mlx_video.models.wan.text_encoder import T5RelativeEmbedding

        rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
        out = rel_emb(8, 12)
        mx.eval(out)
        assert out.shape == (1, 4, 8, 12)

    def test_symmetry(self):
        """Each head's bias is constant along the diagonal.

        Despite the test name, only the diagonal is checked, not symmetry of
        the matrix.
        """
        from mlx_video.models.wan.text_encoder import T5RelativeEmbedding

        rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
        out = rel_emb(6, 6)
        mx.eval(out)
        out_np = np.array(out[0])  # [N, lq, lk]
        # Diagonal elements (position i attending to position i) should be consistent
        # (same relative distance = 0 for all diagonal elements)
        for h in range(2):
            diag = np.diag(out_np[h])
            np.testing.assert_allclose(diag, diag[0], atol=1e-5)
|
|
|
|
|
|
class TestT5Attention:
    """Tests for T5Attention."""

    def test_output_shape(self):
        """Attention returns an array shaped like its input (B, L, dim)."""
        from mlx_video.models.wan.text_encoder import T5Attention

        attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
        x = mx.random.normal((1, 10, 64))
        out = attn(x)
        mx.eval(out)
        assert out.shape == (1, 10, 64)

    def test_no_scaling(self):
        """T5 attention famously has no sqrt(d) scaling. Verify structure."""
        from mlx_video.models.wan.text_encoder import T5Attention

        attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
        # No scale attribute (unlike standard attention)
        assert not hasattr(attn, "scale")

    def test_with_position_bias(self):
        """A relative position bias can be passed via pos_bias."""
        from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding

        attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
        rel_emb = T5RelativeEmbedding(32, 4)
        x = mx.random.normal((1, 10, 64))
        pos_bias = rel_emb(10, 10)
        out = attn(x, pos_bias=pos_bias)
        mx.eval(out)
        assert out.shape == (1, 10, 64)

    def test_with_mask(self):
        """A (B, L) mask with trailing zeros is accepted and keeps the shape."""
        from mlx_video.models.wan.text_encoder import T5Attention

        attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
        x = mx.random.normal((1, 10, 64))
        # Ones for the first 7 positions, zeros for the last 3.
        mask = mx.ones((1, 10))
        mask = mx.concatenate([mask[:, :7], mx.zeros((1, 3))], axis=1)
        out = attn(x, mask=mask)
        mx.eval(out)
        assert out.shape == (1, 10, 64)
|
|
|
|
|
|
class TestT5FeedForward:
    """Tests for T5FeedForward."""

    def test_output_shape(self):
        """The feed-forward maps (B, L, dim) back to (B, L, dim)."""
        from mlx_video.models.wan.text_encoder import T5FeedForward

        ffn = T5FeedForward(64, 256)
        x = mx.random.normal((1, 10, 64))
        out = ffn(x)
        mx.eval(out)
        assert out.shape == (1, 10, 64)

    def test_gated_structure(self):
        """T5 FFN is gated: gate(x) * fc1(x)."""
        from mlx_video.models.wan.text_encoder import T5FeedForward

        ffn = T5FeedForward(32, 64)
        assert hasattr(ffn, "gate_proj")
        assert hasattr(ffn, "fc1")
        assert hasattr(ffn, "fc2")
|
|
|
|
|
|
class TestT5Encoder:
    """Tests for T5Encoder using a tiny two-layer configuration."""

    def setup_method(self):
        """Seed the RNG so parameter initialisation is reproducible."""
        mx.random.seed(42)

    def test_output_shape(self):
        """Encoding (B, L) token ids with a mask yields (B, L, dim)."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        encoder = T5Encoder(
            vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
            num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
        )
        ids = mx.array([[1, 5, 10, 0, 0]])
        mask = mx.array([[1, 1, 1, 0, 0]])
        out = encoder(ids, mask=mask)
        mx.eval(out)
        assert out.shape == (1, 5, 64)

    def test_shared_pos(self):
        """shared_pos=True puts the position embedding on the encoder only."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        encoder = T5Encoder(
            vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
            num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
        )
        assert encoder.pos_embedding is not None
        for block in encoder.blocks:
            assert block.pos_embedding is None

    def test_per_layer_pos(self):
        """shared_pos=False gives every block its own position embedding."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        encoder = T5Encoder(
            vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
            num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
        )
        assert encoder.pos_embedding is None
        for block in encoder.blocks:
            assert block.pos_embedding is not None

    def test_param_count(self):
        """The tiny encoder owns at least one parameter."""
        # tree_flatten is documented under mlx.utils; import it from there
        # rather than reaching through mlx.nn.utils.
        from mlx.utils import tree_flatten
        from mlx_video.models.wan.text_encoder import T5Encoder

        encoder = T5Encoder(
            vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
            num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
        )
        num_params = sum(p.size for _, p in tree_flatten(encoder.parameters()))
        assert num_params > 0

    def test_without_mask(self):
        """The mask argument is optional."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        encoder = T5Encoder(
            vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
            num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
        )
        ids = mx.array([[1, 5, 10]])
        out = encoder(ids)
        mx.eval(out)
        assert out.shape == (1, 3, 64)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# VAE Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestCausalConv3d:
    """Tests for CausalConv3d."""

    def test_output_shape_stride1(self):
        """Kernel 3, stride 1, padding 1 preserves T, H and W."""
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
        # Initialize weights
        conv.weight = mx.random.normal(conv.weight.shape) * 0.02
        x = mx.random.normal((1, 4, 3, 8, 8))  # [B, C, T, H, W]
        out = conv(x)
        mx.eval(out)
        # With causal padding and padding=1 on spatial, dims should be preserved
        assert out.shape[0] == 1
        assert out.shape[1] == 8  # out_channels
        assert out.shape[2] == 3  # T preserved
        assert out.shape[3] == 8  # H preserved
        assert out.shape[4] == 8  # W preserved

    def test_output_shape_kernel1(self):
        """A 1x1x1 kernel changes only the channel dimension."""
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.02
        x = mx.random.normal((1, 4, 2, 4, 4))
        out = conv(x)
        mx.eval(out)
        assert out.shape == (1, 8, 2, 4, 4)

    def test_causal_padding(self):
        """Causal conv should only use past/current frames, not future."""
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.1
        conv.bias = mx.zeros((2,))

        # Base input: only the first frame (t=0) carries signal.
        x_np = np.zeros((1, 2, 4, 4, 4), dtype=np.float32)
        x_np[:, :, 0, :, :] = 1.0
        # Perturbed input: identical at t=0, different at every later frame.
        x_future_np = x_np.copy()
        x_future_np[:, :, 1:, :, :] = 5.0

        out = conv(mx.array(x_np))
        out_future = conv(mx.array(x_future_np))
        mx.eval(out, out_future)

        # Due to causal padding, the output at t=0 should only depend on t=0,
        # so changing the later frames must leave it untouched.
        np.testing.assert_allclose(
            np.array(out[:, :, 0]),
            np.array(out_future[:, :, 0]),
            atol=1e-5,
        )
|
|
|
|
|
|
class TestResidualBlock:
    """Tests for the VAE ResidualBlock."""

    def test_same_dim(self):
        """Equal in/out channels leave the input shape unchanged."""
        from mlx_video.models.wan.vae import ResidualBlock

        block = ResidualBlock(8, 8)
        x = mx.random.normal((1, 8, 2, 4, 4))
        out = block(x)
        mx.eval(out)
        assert out.shape == (1, 8, 2, 4, 4)

    def test_different_dim(self):
        """Different in/out channels change only the channel dimension."""
        from mlx_video.models.wan.vae import ResidualBlock

        block = ResidualBlock(8, 16)
        x = mx.random.normal((1, 8, 2, 4, 4))
        out = block(x)
        mx.eval(out)
        assert out.shape == (1, 16, 2, 4, 4)

    def test_shortcut_exists_when_dims_differ(self):
        """A shortcut module is created when in/out channels differ."""
        from mlx_video.models.wan.vae import ResidualBlock

        block = ResidualBlock(8, 16)
        assert block.shortcut is not None

    def test_no_shortcut_when_dims_same(self):
        """No shortcut module is created when in/out channels match."""
        from mlx_video.models.wan.vae import ResidualBlock

        block = ResidualBlock(8, 8)
        assert block.shortcut is None
|
|
|
|
|
|
class TestAttentionBlock:
    """Tests for the VAE AttentionBlock."""

    def test_output_shape(self):
        """The block returns an array with the same shape as its input."""
        from mlx_video.models.wan.vae import AttentionBlock

        block = AttentionBlock(8)
        x = mx.random.normal((1, 8, 2, 4, 4))
        out = block(x)
        mx.eval(out)
        assert out.shape == (1, 8, 2, 4, 4)

    def test_residual_connection(self):
        """The output is not identically zero for a random input."""
        from mlx_video.models.wan.vae import AttentionBlock

        block = AttentionBlock(8)
        x = mx.random.normal((1, 8, 1, 3, 3))
        out = block(x)
        mx.eval(x, out)
        # Residual: output should not be zero even with random init
        assert np.abs(np.array(out)).max() > 0
|
|
|
|
|
|
class TestWanVAE:
    """Tests for WanVAE construction and its normalization constants."""

    def test_instantiation(self):
        """WanVAE(z_dim=16) stores z_dim and has length-16 mean/std arrays."""
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16)
        assert vae.z_dim == 16
        assert vae.mean.shape == (16,)
        assert vae.std.shape == (16,)

    def test_normalization_stats(self):
        """VAE_MEAN and VAE_STD each have 16 entries; every std is positive."""
        from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD

        assert len(VAE_MEAN) == 16
        assert len(VAE_STD) == 16
        assert all(s > 0 for s in VAE_STD)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Scheduler Tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestFlowMatchEulerScheduler:
    """Unit tests for ``FlowMatchEulerScheduler`` (no model weights needed)."""

    @staticmethod
    def _scheduler(num_steps=None, shift=None):
        """Return a new scheduler, optionally configured via ``set_timesteps``.

        The import is kept function-local, as in every test of this class.
        """
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        if num_steps is not None:
            scheduler.set_timesteps(num_steps, shift=shift)
        return scheduler

    def test_initialization(self):
        """A fresh scheduler reports 1000 train timesteps and no schedule."""
        fresh = self._scheduler()
        assert fresh.num_train_timesteps == 1000
        assert fresh.timesteps is None
        assert fresh.sigmas is None

    def test_set_timesteps(self):
        """``set_timesteps(40)`` yields 40 timesteps and 41 sigmas."""
        sched = self._scheduler(40, 12.0)
        mx.eval(sched.timesteps, sched.sigmas)
        assert sched.timesteps.shape == (40,)
        # One sigma per step plus the terminal value.
        assert sched.sigmas.shape == (41,)

    def test_timesteps_decreasing(self):
        """Timesteps must be strictly decreasing."""
        sched = self._scheduler(40, 12.0)
        mx.eval(sched.timesteps)
        ts = np.array(sched.timesteps)
        deltas = np.diff(ts)
        assert np.all(deltas < 0), f"Timesteps not decreasing: {ts[:5]}..."

    def test_sigmas_decreasing(self):
        """Sigmas must never increase from one entry to the next."""
        sched = self._scheduler(20, 1.0)
        mx.eval(sched.sigmas)
        deltas = np.diff(np.array(sched.sigmas))
        assert np.all(deltas <= 0), "Sigmas not decreasing"

    def test_terminal_sigma_is_zero(self):
        """The last sigma is (numerically) zero."""
        sched = self._scheduler(20, 5.0)
        mx.eval(sched.sigmas)
        last_sigma = np.array(sched.sigmas[-1])
        np.testing.assert_allclose(last_sigma, 0.0, atol=1e-6)

    def test_shift_effect(self):
        """Larger shift should push sigmas toward higher values."""
        low_shift = self._scheduler(20, 1.0)
        high_shift = self._scheduler(20, 12.0)
        mx.eval(low_shift.sigmas, high_shift.sigmas)
        # The terminal sigma is excluded from both means.
        mean1 = np.mean(np.array(low_shift.sigmas[:-1]))
        mean2 = np.mean(np.array(high_shift.sigmas[:-1]))
        assert mean2 > mean1, "Higher shift should push sigmas higher"

    def test_step_euler(self):
        """One ``step`` matches the Euler update computed by hand."""
        sched = self._scheduler(10, 1.0)
        mx.eval(sched.sigmas)

        shape = (1, 4, 2, 2, 2)
        sample = mx.ones(shape)
        velocity = mx.ones(shape) * 0.5
        first_timestep = sched.timesteps[0]

        sigma_now = float(np.array(sched.sigmas[0]))
        sigma_next = float(np.array(sched.sigmas[1]))

        stepped = sched.step(velocity, first_timestep, sample)
        mx.eval(stepped)

        # Euler: x_next = x + (sigma_next - sigma) * v, with x = 1 and v = 0.5.
        expected = 1.0 + (sigma_next - sigma_now) * 0.5
        first_value = np.array(stepped).flatten()[0]
        np.testing.assert_allclose(first_value, expected, rtol=1e-4)

    def test_step_index_increments(self):
        """``_step_index`` starts at 0 and grows by one per ``step`` call."""
        sched = self._scheduler(5, 1.0)
        assert sched._step_index == 0
        sample = mx.ones((1, 1, 1, 1, 1))
        vel = mx.zeros((1, 1, 1, 1, 1))
        for idx in (0, 1):
            sched.step(vel, sched.timesteps[idx], sample)
            assert sched._step_index == idx + 1

    def test_reset(self):
        """``reset`` puts ``_step_index`` back to 0 after a step."""
        sched = self._scheduler(5, 1.0)
        sample = mx.ones((1, 1, 1, 1, 1))
        vel = mx.zeros((1, 1, 1, 1, 1))
        sched.step(vel, sched.timesteps[0], sample)
        assert sched._step_index == 1
        sched.reset()
        assert sched._step_index == 0

    @pytest.mark.parametrize("steps", [10, 20, 40, 50])
    def test_various_step_counts(self, steps):
        """Schedule lengths follow the requested step count."""
        sched = self._scheduler(steps, 12.0)
        mx.eval(sched.timesteps, sched.sigmas)
        assert sched.timesteps.shape == (steps,)
        assert sched.sigmas.shape == (steps + 1,)

    def test_full_denoise_loop(self):
        """Run a complete denoise loop with zero velocity -> sample unchanged."""
        total_steps = 5
        sched = self._scheduler(total_steps, 1.0)
        sample = mx.ones((1, 2, 1, 2, 2))
        for idx in range(total_steps):
            zero_velocity = mx.zeros_like(sample)
            sample = sched.step(zero_velocity, sched.timesteps[idx], sample)
        mx.eval(sample)
        # Zero velocity means every Euler update adds nothing.
        np.testing.assert_allclose(np.array(sample), 1.0, atol=1e-5)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Weight Conversion Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSanitizeTransformerWeights:
    """Key-renaming and reshape checks for ``sanitize_wan_transformer_weights``."""

    @staticmethod
    def _sanitize(weights):
        """Run the transformer weight sanitizer on ``weights``."""
        from mlx_video.convert_wan import sanitize_wan_transformer_weights

        return sanitize_wan_transformer_weights(weights)

    def test_patch_embedding_reshape(self):
        """5D patch-embedding weight is expected as 2D under ``patch_embedding_proj``."""
        out = self._sanitize(
            {
                "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
                "patch_embedding.bias": mx.random.normal((5120,)),
            }
        )
        for name in ("patch_embedding_proj.weight", "patch_embedding_proj.bias"):
            assert name in out
        assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)

    def test_text_embedding_rename(self):
        """``text_embedding.0`` / ``.2`` are expected as ``text_embedding_0`` / ``_1``."""
        out = self._sanitize(
            {
                "text_embedding.0.weight": mx.zeros((64, 32)),
                "text_embedding.0.bias": mx.zeros((64,)),
                "text_embedding.2.weight": mx.zeros((64, 64)),
                "text_embedding.2.bias": mx.zeros((64,)),
            }
        )
        renamed = (
            "text_embedding_0.weight",
            "text_embedding_0.bias",
            "text_embedding_1.weight",
            "text_embedding_1.bias",
        )
        for name in renamed:
            assert name in out

    def test_time_embedding_rename(self):
        """``time_embedding.0`` / ``.2`` are expected as ``time_embedding_0`` / ``_1``."""
        out = self._sanitize(
            {
                "time_embedding.0.weight": mx.zeros((64, 32)),
                "time_embedding.2.weight": mx.zeros((64, 64)),
            }
        )
        for name in ("time_embedding_0.weight", "time_embedding_1.weight"):
            assert name in out

    def test_time_projection_rename(self):
        """``time_projection.1.*`` is expected without the ``.1`` index."""
        out = self._sanitize(
            {
                "time_projection.1.weight": mx.zeros((384, 64)),
                "time_projection.1.bias": mx.zeros((384,)),
            }
        )
        for name in ("time_projection.weight", "time_projection.bias"):
            assert name in out

    def test_ffn_rename(self):
        """``ffn.0`` / ``ffn.2`` are expected as ``ffn.fc1`` / ``ffn.fc2``."""
        out = self._sanitize(
            {
                "blocks.0.ffn.0.weight": mx.zeros((128, 64)),
                "blocks.0.ffn.0.bias": mx.zeros((128,)),
                "blocks.0.ffn.2.weight": mx.zeros((64, 128)),
                "blocks.0.ffn.2.bias": mx.zeros((64,)),
            }
        )
        renamed = (
            "blocks.0.ffn.fc1.weight",
            "blocks.0.ffn.fc1.bias",
            "blocks.0.ffn.fc2.weight",
            "blocks.0.ffn.fc2.bias",
        )
        for name in renamed:
            assert name in out

    def test_freqs_skipped(self):
        """The ``freqs`` entry is dropped while other keys survive."""
        out = self._sanitize(
            {
                "freqs": mx.zeros((1024, 64, 2)),
                "blocks.0.norm1.weight": mx.zeros((64,)),
            }
        )
        assert "freqs" not in out
        assert "blocks.0.norm1.weight" in out

    def test_passthrough_keys(self):
        """Keys that need no renaming are present unchanged in the output."""
        square = (64, 64)
        weights = {
            "blocks.0.self_attn.q.weight": mx.zeros(square),
            "blocks.0.self_attn.k.weight": mx.zeros(square),
            "blocks.0.self_attn.v.weight": mx.zeros(square),
            "blocks.0.self_attn.o.weight": mx.zeros(square),
            "blocks.0.modulation": mx.zeros((1, 6, 64)),
            "head.head.weight": mx.zeros(square),
            "head.modulation": mx.zeros((1, 2, 64)),
        }
        out = self._sanitize(weights)
        missing = [name for name in weights if name not in out]
        assert not missing
|
|
|
|
|
|
class TestSanitizeT5Weights:
    """Key-renaming checks for ``sanitize_wan_t5_weights``."""

    @staticmethod
    def _sanitize(weights):
        """Run the T5 weight sanitizer on ``weights``."""
        from mlx_video.convert_wan import sanitize_wan_t5_weights

        return sanitize_wan_t5_weights(weights)

    def test_gate_rename(self):
        """``ffn.gate.0`` is expected as ``ffn.gate_proj``; fc1/fc2 keep their names."""
        out = self._sanitize(
            {
                "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
                "blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
                "blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
            }
        )
        expected_keys = (
            "blocks.0.ffn.gate_proj.weight",
            "blocks.0.ffn.fc1.weight",
            "blocks.0.ffn.fc2.weight",
        )
        for name in expected_keys:
            assert name in out

    def test_passthrough(self):
        """Keys that need no renaming are present unchanged in the output."""
        weights = {
            "token_embedding.weight": mx.zeros((100, 64)),
            "blocks.0.attn.q.weight": mx.zeros((64, 64)),
            "norm.weight": mx.zeros((64,)),
        }
        out = self._sanitize(weights)
        missing = [name for name in weights if name not in out]
        assert not missing
|
|
|
|
|
|
class TestSanitizeVAEWeights:
    """Layout checks for ``sanitize_wan_vae_weights`` (channels moved last)."""

    @staticmethod
    def _sanitize(weights):
        """Run the VAE weight sanitizer on ``weights``."""
        from mlx_video.convert_wan import sanitize_wan_vae_weights

        return sanitize_wan_vae_weights(weights)

    def test_conv3d_transpose(self):
        """5D conv weight: [O, I, D, H, W] is expected as [O, D, H, W, I]."""
        key = "decoder.conv1.weight"
        out = self._sanitize({key: mx.zeros((8, 4, 3, 3, 3))})
        assert out[key].shape == (8, 3, 3, 3, 4)

    def test_conv2d_transpose(self):
        """4D conv weight: [O, I, H, W] is expected as [O, H, W, I]."""
        key = "decoder.proj.weight"
        out = self._sanitize({key: mx.zeros((16, 8, 3, 3))})
        assert out[key].shape == (16, 3, 3, 8)

    def test_non_conv_passthrough(self):
        """1D tensors keep their shape (nothing to transpose)."""
        original_shapes = {
            "decoder.norm.weight": (64,),
            "decoder.bias": (16,),
        }
        out = self._sanitize(
            {name: mx.zeros(shape) for name, shape in original_shapes.items()}
        )
        for name, shape in original_shapes.items():
            assert out[name].shape == shape

    def test_mixed_weights(self):
        """Only 5D and 4D tensors are re-laid-out; 2D and 1D are untouched."""
        out = self._sanitize(
            {
                "conv3d.weight": mx.zeros((8, 4, 3, 3, 3)),  # 5D
                "conv2d.weight": mx.zeros((8, 4, 3, 3)),  # 4D
                "linear.weight": mx.zeros((8, 4)),  # 2D
                "norm.weight": mx.zeros((8,)),  # 1D
            }
        )
        expected_shapes = {
            "conv3d.weight": (8, 3, 3, 3, 4),
            "conv2d.weight": (8, 3, 3, 4),
            "linear.weight": (8, 4),
            "norm.weight": (8,),
        }
        for name, shape in expected_shapes.items():
            assert out[name].shape == shape
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration: end-to-end tiny model forward pass
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEndToEnd:
    """End-to-end test with tiny model (no real weights needed)."""

    @staticmethod
    def _build(seed, num_steps):
        """Seed the RNG and create model, scheduler and random inputs.

        Returns ``(model, sched, latents, context, seq_len, latent_shape)``.
        Random draws happen in a fixed order: model construction, then
        latents, then context.
        """
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        mx.random.seed(seed)
        config = _make_tiny_config()
        model = WanModel(config)

        latent_shape = (config.in_dim, 1, 4, 4)
        _, frames, height, width = latent_shape
        pt, ph, pw = config.patch_size
        seq_len = (frames // pt) * (height // ph) * (width // pw)

        sched = FlowMatchEulerScheduler()
        sched.set_timesteps(num_steps, shift=3.0)

        latents = mx.random.normal(latent_shape)
        context = mx.random.normal((4, config.text_dim))
        return model, sched, latents, context, seq_len, latent_shape

    @staticmethod
    def _denoise_once(model, sched, latents, context, seq_len, step_idx):
        """Predict with the model at step ``step_idx`` and apply one scheduler step."""
        t = sched.timesteps[step_idx]
        pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
        # The scheduler is called with a leading batch axis, removed afterwards.
        stepped = sched.step(pred[None], t, latents[None])
        return stepped.squeeze(0)

    def test_tiny_model_denoise_step(self):
        """Simulate one denoising step with tiny model."""
        model, sched, latents, context, seq_len, latent_shape = self._build(42, 5)

        latents_next = self._denoise_once(model, sched, latents, context, seq_len, 0)
        mx.eval(latents_next)

        assert latents_next.shape == latent_shape
        # Should differ from original noise
        unchanged = np.allclose(np.array(latents_next), np.array(latents), atol=1e-5)
        assert not unchanged

    def test_tiny_model_full_loop(self):
        """Run a complete (tiny) diffusion loop."""
        num_steps = 3
        model, sched, latents, context, seq_len, latent_shape = self._build(
            123, num_steps
        )

        for step_idx in range(num_steps):
            latents = self._denoise_once(
                model, sched, latents, context, seq_len, step_idx
            )
            mx.eval(latents)

        assert latents.shape == latent_shape
        assert not mx.any(mx.isnan(latents)).item(), "NaN in output"
        assert not mx.any(mx.isinf(latents)).item(), "Inf in output"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Wan2.1 Config & Pipeline Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestWan21Config:
    """Tests for Wan2.1 config presets."""

    @staticmethod
    def _config_cls():
        """Import and return ``WanModelConfig`` (import kept function-local)."""
        from mlx_video.models.wan.config import WanModelConfig

        return WanModelConfig

    @staticmethod
    def _assert_attrs(config, expected):
        """Assert ``getattr(config, name) == value`` for every expected pair."""
        for name, value in expected.items():
            assert getattr(config, name) == value

    def test_wan21_14b_factory(self):
        """The Wan2.1 14B preset is single-model with its own sampling defaults."""
        config = self._config_cls().wan21_t2v_14b()
        assert config.model_version == "2.1"
        assert config.dual_model is False
        self._assert_attrs(
            config,
            {
                "dim": 5120,
                "ffn_dim": 13824,
                "num_heads": 40,
                "num_layers": 40,
                "head_dim": 128,
                "sample_guide_scale": 5.0,
                "sample_shift": 5.0,
                "sample_steps": 50,
                "boundary": 0.0,
            },
        )

    def test_wan21_1_3b_factory(self):
        """The Wan2.1 1.3B preset is single-model with the smaller architecture."""
        config = self._config_cls().wan21_t2v_1_3b()
        assert config.model_version == "2.1"
        assert config.dual_model is False
        self._assert_attrs(
            config,
            {
                "dim": 1536,
                "ffn_dim": 8960,
                "num_heads": 12,
                "num_layers": 30,
                "head_dim": 128,  # 1536 // 12
                "sample_guide_scale": 5.0,
            },
        )

    def test_wan22_14b_factory(self):
        """The Wan2.2 14B preset is dual-model with a two-value guide scale."""
        config = self._config_cls().wan22_t2v_14b()
        assert config.model_version == "2.2"
        assert config.dual_model is True
        self._assert_attrs(
            config,
            {
                "dim": 5120,
                "sample_guide_scale": (3.0, 4.0),
                "sample_shift": 12.0,
                "sample_steps": 40,
                "boundary": 0.875,
            },
        )

    def test_wan21_config_to_dict(self):
        """``to_dict`` of the 2.1 14B preset carries version and pipeline fields."""
        as_dict = self._config_cls().wan21_t2v_14b().to_dict()
        assert as_dict["model_version"] == "2.1"
        assert as_dict["dual_model"] is False
        assert as_dict["sample_guide_scale"] == 5.0

    def test_wan21_1_3b_config_to_dict(self):
        """``to_dict`` of the 2.1 1.3B preset carries the architecture fields."""
        as_dict = self._config_cls().wan21_t2v_1_3b().to_dict()
        assert as_dict["dim"] == 1536
        assert as_dict["num_layers"] == 30

    def test_default_config_is_wan22(self):
        """Default WanModelConfig() should be Wan2.2 14B."""
        default_config = self._config_cls()()
        assert default_config.model_version == "2.2"
        assert default_config.dual_model is True
|
|
|
|
|
|
class TestWan21Model:
    """Test tiny Wan2.1-style model (single model mode)."""

    def setup_method(self):
        """Seed the MLX RNG before every test so random inputs are repeatable."""
        mx.random.seed(42)

    @staticmethod
    def _shrink(config, **overrides):
        """Set each override as an attribute on ``config`` and return it."""
        for name, value in overrides.items():
            setattr(config, name, value)
        return config

    @staticmethod
    def _seq_len(config, frames, height, width):
        """Number of patches for a latent grid, derived from ``config.patch_size``."""
        pt, ph, pw = config.patch_size
        return (frames // pt) * (height // ph) * (width // pw)

    def _make_tiny_wan21_config(self):
        """Create a tiny config mimicking Wan2.1 (single model)."""
        from mlx_video.models.wan.config import WanModelConfig

        # Start from the 14B preset so the pipeline settings stay Wan2.1-style,
        # then override the sizes with tiny values.
        return self._shrink(
            WanModelConfig.wan21_t2v_14b(),
            dim=64,
            ffn_dim=128,
            num_heads=4,
            num_layers=2,
            in_dim=4,
            out_dim=4,
            freq_dim=32,
            text_dim=32,
            text_len=8,
        )

    def _make_tiny_wan21_1_3b_config(self):
        """Create a tiny config mimicking Wan2.1 1.3B."""
        from mlx_video.models.wan.config import WanModelConfig

        # Override to tiny values. The head count is reduced to 4 here (the
        # real 1.3B preset uses 12), so only the preset's non-size settings
        # are inherited.
        return self._shrink(
            WanModelConfig.wan21_t2v_1_3b(),
            dim=48,
            ffn_dim=96,
            num_heads=4,
            num_layers=2,
            in_dim=4,
            out_dim=4,
            freq_dim=24,
            text_dim=24,
            text_len=8,
        )

    def _assert_forward_preserves_shape(self, config):
        """Run one forward pass and check the output has the latent's shape."""
        from mlx_video.models.wan.model import WanModel

        model = WanModel(config)

        C, F, H, W = config.in_dim, 1, 4, 4
        seq_len = self._seq_len(config, F, H, W)

        latents = mx.random.normal((C, F, H, W))
        context = mx.random.normal((4, config.text_dim))
        t = mx.array([500.0])

        out = model([latents], t, [context], seq_len)
        mx.eval(out)
        assert out[0].shape == (C, F, H, W)

    def test_wan21_tiny_model_forward(self):
        """Forward pass with Wan2.1 tiny config."""
        self._assert_forward_preserves_shape(self._make_tiny_wan21_config())

    def test_wan21_1_3b_tiny_model_forward(self):
        """Forward pass with Wan2.1 1.3B tiny config."""
        self._assert_forward_preserves_shape(self._make_tiny_wan21_1_3b_config())

    def test_wan21_single_model_loop(self):
        """Full diffusion loop with single model (Wan2.1 style)."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        config = self._make_tiny_wan21_config()
        model = WanModel(config)

        C, F, H, W = config.in_dim, 1, 4, 4
        seq_len = self._seq_len(config, F, H, W)

        # The schedule is built with the preset's full step count ...
        sched = FlowMatchEulerScheduler()
        sched.set_timesteps(config.sample_steps, shift=config.sample_shift)

        latents = mx.random.normal((C, F, H, W))
        context = mx.random.normal((4, config.text_dim))
        context_null = mx.zeros((4, config.text_dim))
        gs = config.sample_guide_scale  # Should be float for Wan2.1

        assert isinstance(gs, float), "Wan2.1 guide_scale should be float"

        # ... but only the first 3 steps are executed, for speed.
        for i in range(3):
            t = sched.timesteps[i]
            t_batch = mx.array([t.item()])
            pred_cond = model([latents], t_batch, [context], seq_len)[0]
            pred_uncond = model([latents], t_batch, [context_null], seq_len)[0]
            # Guidance: move from the unconditional prediction toward the
            # conditional one, scaled by ``gs``.
            pred = pred_uncond + gs * (pred_cond - pred_uncond)
            latents = sched.step(pred[None], t, latents[None]).squeeze(0)
            mx.eval(latents)

        assert latents.shape == (C, F, H, W)
        assert not mx.any(mx.isnan(latents)).item()

    def test_wan21_vs_wan22_config_differences(self):
        """Verify key differences between Wan2.1 and Wan2.2 configs."""
        from mlx_video.models.wan.config import WanModelConfig

        c21 = WanModelConfig.wan21_t2v_14b()
        c22 = WanModelConfig.wan22_t2v_14b()

        # Same architecture
        assert c21.dim == c22.dim
        assert c21.num_heads == c22.num_heads
        assert c21.num_layers == c22.num_layers

        # Different pipeline settings
        assert c21.dual_model is False
        assert c22.dual_model is True
        assert isinstance(c21.sample_guide_scale, float)
        assert isinstance(c22.sample_guide_scale, tuple)
        assert c21.sample_shift != c22.sample_shift
        assert c21.sample_steps != c22.sample_steps
|
|
|
|
|
|
class TestWan21Convert:
    """Tests for Wan2.1 conversion support."""

    @staticmethod
    def _detect_version(model_dir):
        """Mirror the auto-detect rule used by these tests.

        A ``low_noise_model`` subdirectory means "2.2"; otherwise "2.1".
        This is a local simulation only: it does not call the converter.
        """
        return "2.2" if (model_dir / "low_noise_model").exists() else "2.1"

    def test_auto_detect_wan21(self, tmp_path):
        """Auto-detect single-model directory as Wan2.1 (simulated rule)."""
        # Create a Wan2.1-style directory (no low_noise_model subdir).
        (tmp_path / "dummy.safetensors").touch()
        assert not (tmp_path / "low_noise_model").exists()
        assert self._detect_version(tmp_path) == "2.1"

    def test_auto_detect_wan22(self, tmp_path):
        """Auto-detect dual-model directory as Wan2.2 (simulated rule)."""
        (tmp_path / "low_noise_model").mkdir()
        (tmp_path / "high_noise_model").mkdir()
        assert (tmp_path / "low_noise_model").exists()
        assert self._detect_version(tmp_path) == "2.2"

    def test_wan21_config_saved_correctly(self):
        """Verify config dict has correct fields for Wan2.1."""
        from mlx_video.models.wan.config import WanModelConfig

        d = WanModelConfig.wan21_t2v_14b().to_dict()
        assert d["model_version"] == "2.1"
        assert d["dual_model"] is False
        assert d["sample_steps"] == 50
        assert d["sample_shift"] == 5.0
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the run's exit code; pass it to the interpreter so
    # running this file as a script exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|