# Files
# mlx-video/tests/test_wan.py
# 2026-03-11 09:08:10 +01:00
#
# 2713 lines
# 104 KiB
# Python

"""Comprehensive tests for Wan2.2 model components.
All tests use small/tiny configurations to avoid needing actual weights.
"""
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Config Tests
# ---------------------------------------------------------------------------
class TestWanModelConfig:
    """Checks the default field values and helpers of ``WanModelConfig``."""

    def test_default_values(self):
        """A default-constructed config exposes the expected field values."""
        from mlx_video.models.wan.config import WanModelConfig

        cfg = WanModelConfig()
        expected_fields = {
            "dim": 5120,
            "ffn_dim": 13824,
            "num_heads": 40,
            "num_layers": 40,
            "in_dim": 16,
            "out_dim": 16,
            "patch_size": (1, 2, 2),
            "vae_stride": (4, 8, 8),
            "vae_z_dim": 16,
            "boundary": 0.875,
            "sample_shift": 12.0,
            "sample_steps": 40,
            "sample_guide_scale": (3.0, 4.0),
            "num_train_timesteps": 1000,
        }
        for field_name, expected in expected_fields.items():
            assert getattr(cfg, field_name) == expected
        # Boolean flags are compared by identity, not mere truthiness.
        assert cfg.qk_norm is True
        assert cfg.cross_attn_norm is True
        assert cfg.text_len == 512

    def test_head_dim_property(self):
        """``head_dim`` equals ``dim // num_heads`` for the defaults."""
        from mlx_video.models.wan.config import WanModelConfig

        cfg = WanModelConfig()
        assert cfg.head_dim == 128  # 5120 // 40

    def test_to_dict_roundtrip(self):
        """``to_dict`` returns a plain dict carrying the config values."""
        from mlx_video.models.wan.config import WanModelConfig

        as_dict = WanModelConfig().to_dict()
        assert isinstance(as_dict, dict)
        assert as_dict["dim"] == 5120
        assert as_dict["patch_size"] == (1, 2, 2)
        assert as_dict["boundary"] == 0.875

    def test_t5_config_values(self):
        """The T5 text-encoder hyper-parameters have the expected defaults."""
        from mlx_video.models.wan.config import WanModelConfig

        cfg = WanModelConfig()
        expected_t5 = {
            "t5_vocab_size": 256384,
            "t5_dim": 4096,
            "t5_dim_attn": 4096,
            "t5_dim_ffn": 10240,
            "t5_num_heads": 64,
            "t5_num_layers": 24,
            "t5_num_buckets": 32,
        }
        for field_name, expected in expected_t5.items():
            assert getattr(cfg, field_name) == expected
# ---------------------------------------------------------------------------
# RoPE Tests
# ---------------------------------------------------------------------------
class TestRoPE:
    """Tests for 3-way factorized RoPE."""

    def test_rope_params_shape(self):
        """The frequency table has shape ``(max_seq_len, dim // 2, 2)``."""
        from mlx_video.models.wan.rope import rope_params

        table = rope_params(1024, 64)
        mx.eval(table)
        assert table.shape == (1024, 32, 2)  # [max_seq_len, dim//2, 2]

    def test_rope_params_different_dims(self):
        """The second axis of the table tracks ``dim // 2`` for several dims."""
        from mlx_video.models.wan.rope import rope_params

        for head_dim in (32, 64, 128):
            table = rope_params(512, head_dim)
            mx.eval(table)
            assert table.shape == (512, head_dim // 2, 2)

    def test_rope_params_cos_sin_range(self):
        """Both stored components stay within ``[-1, 1]``."""
        from mlx_video.models.wan.rope import rope_params

        table = rope_params(256, 64)
        mx.eval(table)
        for component in (0, 1):  # 0 -> cos, 1 -> sin
            values = np.array(table[:, :, component])
            assert np.all(values >= -1.0)
            assert np.all(values <= 1.0)

    def test_rope_params_position_zero(self):
        """Position 0 must yield cos == 1 and sin == 0."""
        from mlx_video.models.wan.rope import rope_params

        table = rope_params(10, 64)
        mx.eval(table)
        first_row = table[0]
        np.testing.assert_allclose(np.array(first_row[:, 0]), 1.0, atol=1e-6)
        np.testing.assert_allclose(np.array(first_row[:, 1]), 0.0, atol=1e-6)

    def test_rope_apply_output_shape(self):
        """Applying RoPE keeps the input shape."""
        from mlx_video.models.wan.rope import rope_params, rope_apply

        batch, seq, heads, head_dim = 1, 24, 4, 32
        tokens = mx.random.normal((batch, seq, heads, head_dim))
        table = rope_params(1024, head_dim)
        grid = [(2, 3, 4)]  # F*H*W = 24 = seq
        rotated = rope_apply(tokens, grid, table)
        mx.eval(rotated)
        assert rotated.shape == (batch, seq, heads, head_dim)

    def test_rope_apply_preserves_norm(self):
        """Rotating a vector must not change its Euclidean norm."""
        from mlx_video.models.wan.rope import rope_params, rope_apply

        batch, heads, head_dim = 1, 2, 16
        frames, height, width = 2, 3, 4
        seq = frames * height * width
        tokens = mx.random.normal((batch, seq, heads, head_dim))
        table = rope_params(1024, head_dim)
        rotated = rope_apply(tokens, [(frames, height, width)], table)
        mx.eval(tokens, rotated)

        # Compare per-(token, head) norms in one vectorized check.
        norms_in = np.linalg.norm(np.array(tokens[0]), axis=-1)
        norms_out = np.linalg.norm(np.array(rotated[0]), axis=-1)
        np.testing.assert_allclose(norms_in, norms_out, rtol=1e-4)

    def test_rope_apply_with_padding(self):
        """Tokens beyond the grid's F*H*W length must pass through untouched."""
        from mlx_video.models.wan.rope import rope_params, rope_apply

        batch, heads, head_dim = 1, 2, 16
        frames, height, width = 2, 2, 2
        valid_len = frames * height * width  # 8
        padded_len = valid_len + 4
        tokens = mx.random.normal((batch, padded_len, heads, head_dim))
        table = rope_params(1024, head_dim)
        rotated = rope_apply(tokens, [(frames, height, width)], table)
        mx.eval(tokens, rotated)

        # Padded tokens should be unchanged
        tail_out = np.array(rotated[0, valid_len:])
        tail_in = np.array(tokens[0, valid_len:])
        np.testing.assert_allclose(tail_out, tail_in, atol=1e-6)

    def test_rope_apply_batch(self):
        """Batch of two samples, each with the same (2, 3, 4) grid."""
        from mlx_video.models.wan.rope import rope_params, rope_apply

        batch, heads, head_dim = 2, 2, 16
        per_sample_grids = [(2, 3, 4), (2, 3, 4)]
        seq = 2 * 3 * 4
        tokens = mx.random.normal((batch, seq, heads, head_dim))
        table = rope_params(1024, head_dim)
        rotated = rope_apply(tokens, per_sample_grids, table)
        mx.eval(rotated)
        assert rotated.shape == (batch, seq, heads, head_dim)

    def test_rope_frequency_split(self):
        """Arithmetic check of the 3-way (temporal/height/width) split.

        This only verifies the split formula itself; it does not call
        into the RoPE implementation.
        """
        head_dim = 128  # head_dim for 14B model
        half = head_dim // 2
        spatial_share = half // 3
        d_h = spatial_share
        d_w = spatial_share
        d_t = half - 2 * spatial_share
        assert d_t + d_h + d_w == half
        # Temporal gets more capacity
        assert d_t >= d_h
        assert d_t >= d_w
# ---------------------------------------------------------------------------
# Attention Tests
# ---------------------------------------------------------------------------
class TestWanRMSNorm:
    """Shape, normalization and dtype checks for ``WanRMSNorm``."""

    def test_output_shape(self):
        """The norm leaves the input shape unchanged."""
        from mlx_video.models.wan.attention import WanRMSNorm

        layer = WanRMSNorm(64)
        inputs = mx.random.normal((2, 10, 64))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (2, 10, 64)

    def test_zero_mean_variance(self):
        """Each normalized row should have an RMS of roughly 1."""
        from mlx_video.models.wan.attention import WanRMSNorm

        layer = WanRMSNorm(64)
        inputs = mx.random.normal((1, 5, 64)) * 10.0
        result = layer(inputs)
        mx.eval(result)

        rows = np.array(result[0])
        # After RMS norm with weight=1, RMS should be ~1
        row_rms = np.sqrt(np.mean(rows ** 2, axis=-1))
        np.testing.assert_allclose(row_rms, 1.0, rtol=0.1)

    def test_dtype_preservation(self):
        """A bfloat16 input is expected to come back as float32."""
        from mlx_video.models.wan.attention import WanRMSNorm

        layer = WanRMSNorm(32)
        inputs = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
        result = layer(inputs)
        mx.eval(result)
        # Weight is float32, so multiplication promotes result to float32
        assert result.dtype == mx.float32
class TestWanLayerNorm:
    """Shape and statistics checks for ``WanLayerNorm``."""

    def test_output_shape(self):
        """The norm leaves the input shape unchanged."""
        from mlx_video.models.wan.attention import WanLayerNorm

        layer = WanLayerNorm(64)
        inputs = mx.random.normal((2, 10, 64))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (2, 10, 64)

    def test_without_affine(self):
        """Without affine parameters each row has mean ~0 and std ~1."""
        from mlx_video.models.wan.attention import WanLayerNorm

        layer = WanLayerNorm(64, elementwise_affine=False)
        inputs = mx.random.normal((1, 4, 64))
        result = layer(inputs)
        mx.eval(result)

        # Mean should be ~0, variance should be ~1
        rows = np.array(result[0])
        np.testing.assert_allclose(np.mean(rows, axis=-1), 0.0, atol=0.05)
        np.testing.assert_allclose(np.std(rows, axis=-1), 1.0, rtol=0.1)

    def test_with_affine(self):
        """With affine enabled the layer exposes ``weight`` and ``bias``."""
        from mlx_video.models.wan.attention import WanLayerNorm

        layer = WanLayerNorm(32, elementwise_affine=True)
        for attr_name in ("weight", "bias"):
            assert hasattr(layer, attr_name)
        inputs = mx.random.normal((1, 4, 32))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (1, 4, 32)
class TestWanSelfAttention:
    """Tests for ``WanSelfAttention`` on a small 64-dim, 4-head setup."""

    def setup_method(self):
        """Seed the RNG and fix the layer dimensions used by every test."""
        mx.random.seed(42)
        self.dim = 64
        self.num_heads = 4

    def test_output_shape(self):
        """Self-attention returns a tensor shaped like its input."""
        from mlx_video.models.wan.attention import WanSelfAttention
        from mlx_video.models.wan.rope import rope_params

        layer = WanSelfAttention(self.dim, self.num_heads)
        batch, seq = 1, 24
        grid = (2, 3, 4)
        tokens = mx.random.normal((batch, seq, self.dim))
        table = rope_params(1024, self.dim // self.num_heads)
        result = layer(tokens, seq_lens=[seq], grid_sizes=[grid], freqs=table)
        mx.eval(result)
        assert result.shape == (batch, seq, self.dim)

    def test_with_qk_norm(self):
        """``qk_norm=True`` creates both the query and key norm layers."""
        from mlx_video.models.wan.attention import WanSelfAttention

        layer = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
        assert layer.norm_q is not None
        assert layer.norm_k is not None

    def test_without_qk_norm(self):
        """``qk_norm=False`` leaves both norm attributes as ``None``."""
        from mlx_video.models.wan.attention import WanSelfAttention

        layer = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
        assert layer.norm_q is None
        assert layer.norm_k is None

    def test_masking(self):
        """A shorter ``seq_lens`` entry must change the attention output."""
        from mlx_video.models.wan.attention import WanSelfAttention
        from mlx_video.models.wan.rope import rope_params

        layer = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
        batch, seq = 1, 24
        grid = (2, 3, 4)
        tokens = mx.random.normal((batch, seq, self.dim))
        table = rope_params(1024, self.dim // self.num_heads)

        # Full sequence
        unmasked = layer(tokens, seq_lens=[seq], grid_sizes=[grid], freqs=table)
        # Shorter sequence (mask last 4 tokens)
        masked = layer(tokens, seq_lens=[seq - 4], grid_sizes=[grid], freqs=table)
        mx.eval(unmasked, masked)

        # Outputs should differ when masking is applied
        outputs_match = np.allclose(np.array(unmasked), np.array(masked), atol=1e-5)
        assert not outputs_match
class TestWanCrossAttention:
    """Shape checks for ``WanCrossAttention`` with and without context lengths."""

    def setup_method(self):
        """Seed the RNG and fix the layer dimensions used by every test."""
        mx.random.seed(42)
        self.dim = 64
        self.num_heads = 4

    def test_output_shape(self):
        """Output follows the query length, not the context length."""
        from mlx_video.models.wan.attention import WanCrossAttention

        layer = WanCrossAttention(self.dim, self.num_heads)
        batch, query_len, ctx_len = 1, 24, 16
        queries = mx.random.normal((batch, query_len, self.dim))
        ctx = mx.random.normal((batch, ctx_len, self.dim))
        result = layer(queries, ctx)
        mx.eval(result)
        assert result.shape == (batch, query_len, self.dim)

    def test_with_context_mask(self):
        """Passing ``context_lens`` keeps the output shape unchanged."""
        from mlx_video.models.wan.attention import WanCrossAttention

        layer = WanCrossAttention(self.dim, self.num_heads)
        batch, query_len, ctx_len = 1, 12, 16
        queries = mx.random.normal((batch, query_len, self.dim))
        ctx = mx.random.normal((batch, ctx_len, self.dim))
        result = layer(queries, ctx, context_lens=[10])
        mx.eval(result)
        assert result.shape == (batch, query_len, self.dim)
# ---------------------------------------------------------------------------
# Transformer Block Tests
# ---------------------------------------------------------------------------
class TestWanFFN:
    """Shape and non-linearity checks for ``WanFFN``."""

    def test_output_shape(self):
        """The feed-forward block maps back to the input width."""
        from mlx_video.models.wan.transformer import WanFFN

        layer = WanFFN(64, 256)
        inputs = mx.random.normal((2, 10, 64))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (2, 10, 64)

    def test_gelu_activation(self):
        """Doubling the input must not simply double the output."""
        from mlx_video.models.wan.transformer import WanFFN

        layer = WanFFN(32, 128)
        base_in = mx.ones((1, 1, 32)) * 2.0
        base_out = layer(base_in)
        doubled_in = mx.ones((1, 1, 32)) * 4.0
        doubled_out = layer(doubled_in)
        mx.eval(base_out, doubled_out)

        # Non-linear: 2x input should not give 2x output
        looks_linear = np.allclose(
            np.array(doubled_out), np.array(base_out) * 2.0, rtol=0.1
        )
        assert not looks_linear
class TestWanAttentionBlock:
    """Tests for ``WanAttentionBlock`` using a tiny 64-dim configuration."""

    def setup_method(self):
        """Seed the RNG and fix the block dimensions used by every test."""
        mx.random.seed(42)
        self.dim = 64
        self.ffn_dim = 128
        self.num_heads = 4

    def test_output_shape(self):
        """A full block call returns a tensor shaped like its input."""
        from mlx_video.models.wan.transformer import WanAttentionBlock
        from mlx_video.models.wan.rope import rope_params

        layer = WanAttentionBlock(
            self.dim,
            self.ffn_dim,
            self.num_heads,
            cross_attn_norm=True,
        )
        batch, seq = 1, 24
        grid = (2, 3, 4)
        tokens = mx.random.normal((batch, seq, self.dim))
        modulation_in = mx.random.normal((batch, seq, 6, self.dim))
        ctx = mx.random.normal((batch, 16, self.dim))
        table = rope_params(1024, self.dim // self.num_heads)
        result = layer(
            tokens,
            modulation_in,
            seq_lens=[seq],
            grid_sizes=[grid],
            freqs=table,
            context=ctx,
        )
        mx.eval(result)
        assert result.shape == (batch, seq, self.dim)

    def test_modulation_shape(self):
        """The block's ``modulation`` parameter has shape ``(1, 6, dim)``."""
        from mlx_video.models.wan.transformer import WanAttentionBlock

        layer = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
        assert layer.modulation.shape == (1, 6, self.dim)

    def test_with_cross_attn_norm(self):
        """``cross_attn_norm=True`` populates ``norm3``."""
        from mlx_video.models.wan.transformer import WanAttentionBlock

        layer = WanAttentionBlock(
            self.dim,
            self.ffn_dim,
            self.num_heads,
            cross_attn_norm=True,
        )
        assert layer.norm3 is not None

    def test_without_cross_attn_norm(self):
        """``cross_attn_norm=False`` leaves ``norm3`` as ``None``."""
        from mlx_video.models.wan.transformer import WanAttentionBlock

        layer = WanAttentionBlock(
            self.dim,
            self.ffn_dim,
            self.num_heads,
            cross_attn_norm=False,
        )
        assert layer.norm3 is None

    def test_residual_connection(self):
        """With an all-ones input, the block output must not collapse to zero."""
        from mlx_video.models.wan.transformer import WanAttentionBlock
        from mlx_video.models.wan.rope import rope_params

        layer = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
        batch, seq = 1, 8
        grid = (2, 2, 2)
        tokens = mx.ones((batch, seq, self.dim))
        modulation_in = mx.zeros((batch, seq, 6, self.dim))
        ctx = mx.random.normal((batch, 4, self.dim))
        table = rope_params(1024, self.dim // self.num_heads)
        result = layer(tokens, modulation_in, [seq], [grid], table, ctx)
        mx.eval(result)

        # With residual connections, output should be close to input + corrections
        is_all_zero = np.allclose(np.array(result), 0.0, atol=1e-3)
        assert not is_all_zero
# ---------------------------------------------------------------------------
# Sinusoidal Embedding Tests
# ---------------------------------------------------------------------------
class TestSinusoidalEmbedding:
    """Tests for ``sinusoidal_embedding_1d``."""

    def test_output_shape(self):
        """Ten positions embedded into 256 dims give a ``(10, 256)`` array."""
        from mlx_video.models.wan.model import sinusoidal_embedding_1d

        positions = mx.arange(10).astype(mx.float32)
        embedding = sinusoidal_embedding_1d(256, positions)
        mx.eval(embedding)
        assert embedding.shape == (10, 256)

    def test_position_zero(self):
        """At position 0 the first half is all ones and the second all zeros."""
        from mlx_video.models.wan.model import sinusoidal_embedding_1d

        embedding = sinusoidal_embedding_1d(64, mx.array([0.0]))
        mx.eval(embedding)
        row = np.array(embedding[0])
        cos_half, sin_half = row[:32], row[32:]
        # First half is cos, should be 1 at position 0
        np.testing.assert_allclose(cos_half, 1.0, atol=1e-5)
        # Second half is sin, should be 0 at position 0
        np.testing.assert_allclose(sin_half, 0.0, atol=1e-5)

    def test_different_positions_differ(self):
        """Distinct positions must produce distinct embedding rows."""
        from mlx_video.models.wan.model import sinusoidal_embedding_1d

        positions = mx.array([0.0, 100.0, 999.0])
        embedding = sinusoidal_embedding_1d(128, positions)
        mx.eval(embedding)
        rows = np.array(embedding)
        assert not np.allclose(rows[0], rows[1])
        assert not np.allclose(rows[1], rows[2])
# ---------------------------------------------------------------------------
# Head Tests
# ---------------------------------------------------------------------------
class TestHead:
    """Tests for the output ``Head`` module."""

    def test_output_shape(self):
        """The head projects each token to ``out_dim * prod(patch_size)``."""
        from mlx_video.models.wan.model import Head

        layer = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
        batch, seq = 1, 24
        tokens = mx.random.normal((batch, seq, 64))
        time_emb = mx.random.normal((batch, 64))  # time embedding: [B, dim]
        result = layer(tokens, time_emb)
        mx.eval(result)
        projected_width = 16 * 1 * 2 * 2  # 64
        assert result.shape == (batch, seq, projected_width)

    def test_modulation_shape(self):
        """The head's ``modulation`` parameter has shape ``(1, 2, dim)``."""
        from mlx_video.models.wan.model import Head

        layer = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
        assert layer.modulation.shape == (1, 2, 64)
# ---------------------------------------------------------------------------
# WanModel (Tiny) Tests
# ---------------------------------------------------------------------------
def _make_tiny_config():
    """Build a ``WanModelConfig`` shrunk to tiny sizes for fast tests.

    Starts from the default config and overwrites the size-related
    fields in place, then returns the same config object.
    """
    from mlx_video.models.wan.config import WanModelConfig

    tiny = WanModelConfig()
    # Override to tiny values
    overrides = {
        "dim": 64,
        "ffn_dim": 128,
        "num_heads": 4,
        "num_layers": 2,
        "in_dim": 4,
        "out_dim": 4,
        "patch_size": (1, 2, 2),
        "freq_dim": 32,
        "text_dim": 32,
        "text_len": 8,
    }
    for field_name, value in overrides.items():
        setattr(tiny, field_name, value)
    return tiny
class TestWanModel:
    """End-to-end checks of ``WanModel`` built from the tiny test config."""

    def setup_method(self):
        """Seed the RNG so random inputs are reproducible."""
        mx.random.seed(42)

    def test_instantiation(self):
        """The model can be built and owns at least one parameter."""
        from mlx_video.models.wan.model import WanModel

        cfg = _make_tiny_config()
        net = WanModel(cfg)
        flat_params = nn.utils.tree_flatten(net.parameters())
        total = sum(param.size for _, param in flat_params)
        assert total > 0

    def test_patchify_shape(self):
        """``_patchify`` returns the token tensor and the patch grid size."""
        from mlx_video.models.wan.model import WanModel

        cfg = _make_tiny_config()
        net = WanModel(cfg)
        # Input: [C=4, F=1, H=4, W=4]
        latent = mx.random.normal((4, 1, 4, 4))
        tokens, grid = net._patchify(latent)
        mx.eval(tokens)
        # Patch size (1,2,2): F'=1, H'=2, W'=2
        assert grid == (1, 2, 2)
        assert tokens.shape == (1, 1 * 2 * 2, cfg.dim)

    def test_patchify_various_sizes(self):
        """The grid equals the input size divided by the patch size."""
        from mlx_video.models.wan.model import WanModel

        cfg = _make_tiny_config()
        net = WanModel(cfg)
        input_sizes = [(1, 4, 4), (2, 6, 8), (3, 4, 6)]
        for frames, height, width in input_sizes:
            latent = mx.random.normal((cfg.in_dim, frames, height, width))
            tokens, (grid_f, grid_h, grid_w) = net._patchify(latent)
            mx.eval(tokens)
            patch_t, patch_h, patch_w = cfg.patch_size
            assert grid_f == frames // patch_t
            assert grid_h == height // patch_h
            assert grid_w == width // patch_w
            assert tokens.shape[1] == grid_f * grid_h * grid_w

    def test_unpatchify_inverse(self):
        """``unpatchify`` turns head-shaped tokens back into a ``(C, F, H, W)`` volume."""
        from mlx_video.models.wan.model import WanModel

        cfg = _make_tiny_config()
        net = WanModel(cfg)
        frames, height, width = 2, 4, 6
        patch_t, patch_h, patch_w = cfg.patch_size
        grid = (frames // patch_t, height // patch_h, width // patch_w)
        num_tokens = grid[0] * grid[1] * grid[2]
        token_width = cfg.out_dim * patch_t * patch_h * patch_w
        # Simulated head output
        head_out = mx.random.normal((1, num_tokens, token_width))
        volumes = net.unpatchify(head_out, [grid])
        mx.eval(volumes[0])
        assert volumes[0].shape == (cfg.out_dim, frames, height, width)

    def test_forward_pass(self):
        """A single-sample forward pass returns one tensor shaped like the input."""
        from mlx_video.models.wan.model import WanModel

        cfg = _make_tiny_config()
        net = WanModel(cfg)
        channels, frames, height, width = cfg.in_dim, 1, 4, 4
        patch_t, patch_h, patch_w = cfg.patch_size
        num_tokens = (frames // patch_t) * (height // patch_h) * (width // patch_w)
        latents = [mx.random.normal((channels, frames, height, width))]
        timesteps = mx.array([500.0])
        text_ctx = [mx.random.normal((6, cfg.text_dim))]
        outputs = net(latents, timesteps, text_ctx, num_tokens)
        mx.eval(outputs[0])
        assert len(outputs) == 1
        assert outputs[0].shape == (channels, frames, height, width)

    def test_forward_batch(self):
        """Two samples with different context lengths both come back input-shaped."""
        from mlx_video.models.wan.model import WanModel

        cfg = _make_tiny_config()
        net = WanModel(cfg)
        channels, frames, height, width = cfg.in_dim, 1, 4, 4
        patch_t, patch_h, patch_w = cfg.patch_size
        num_tokens = (frames // patch_t) * (height // patch_h) * (width // patch_w)
        sample_shape = (channels, frames, height, width)
        latents = [mx.random.normal(sample_shape), mx.random.normal(sample_shape)]
        timesteps = mx.array([500.0, 200.0])
        text_ctx = [
            mx.random.normal((6, cfg.text_dim)),
            mx.random.normal((4, cfg.text_dim)),
        ]
        outputs = net(latents, timesteps, text_ctx, num_tokens)
        mx.eval(outputs[0], outputs[1])
        assert len(outputs) == 2
        for volume in outputs:
            assert volume.shape == sample_shape

    def test_output_is_float32(self):
        """The forward pass output dtype is float32."""
        from mlx_video.models.wan.model import WanModel

        cfg = _make_tiny_config()
        net = WanModel(cfg)
        channels, frames, height, width = cfg.in_dim, 1, 4, 4
        num_tokens = (frames // 1) * (height // 2) * (width // 2)
        outputs = net(
            [mx.random.normal((channels, frames, height, width))],
            mx.array([100.0]),
            [mx.random.normal((4, cfg.text_dim))],
            num_tokens,
        )
        mx.eval(outputs[0])
        assert outputs[0].dtype == mx.float32
# ---------------------------------------------------------------------------
# T5 Encoder Tests
# ---------------------------------------------------------------------------
class TestT5LayerNorm:
    """Shape and normalization checks for ``T5LayerNorm``."""

    def test_output_shape(self):
        """The norm leaves the input shape unchanged."""
        from mlx_video.models.wan.text_encoder import T5LayerNorm

        layer = T5LayerNorm(64)
        inputs = mx.random.normal((2, 10, 64))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (2, 10, 64)

    def test_rms_normalization(self):
        """Each normalized row should have an RMS of roughly 1."""
        from mlx_video.models.wan.text_encoder import T5LayerNorm

        layer = T5LayerNorm(128)
        inputs = mx.random.normal((1, 5, 128)) * 5.0
        result = layer(inputs)
        mx.eval(result)

        rows = np.array(result[0])
        row_rms = np.sqrt(np.mean(rows ** 2, axis=-1))
        np.testing.assert_allclose(row_rms, 1.0, rtol=0.1)
class TestT5RelativeEmbedding:
    """Tests for the T5 relative position bias module."""

    def test_output_shape(self):
        """A square query/key length gives a ``(1, heads, lq, lk)`` bias."""
        from mlx_video.models.wan.text_encoder import T5RelativeEmbedding

        bias_module = T5RelativeEmbedding(num_buckets=32, num_heads=4)
        bias = bias_module(10, 10)
        mx.eval(bias)
        assert bias.shape == (1, 4, 10, 10)  # [1, N, lq, lk]

    def test_asymmetric_lengths(self):
        """Different query and key lengths are reflected in the last two axes."""
        from mlx_video.models.wan.text_encoder import T5RelativeEmbedding

        bias_module = T5RelativeEmbedding(num_buckets=32, num_heads=4)
        bias = bias_module(8, 12)
        mx.eval(bias)
        assert bias.shape == (1, 4, 8, 12)

    def test_symmetry(self):
        """Within each head, every diagonal entry of the bias is the same."""
        from mlx_video.models.wan.text_encoder import T5RelativeEmbedding

        bias_module = T5RelativeEmbedding(num_buckets=32, num_heads=2)
        bias = bias_module(6, 6)
        mx.eval(bias)
        per_head = np.array(bias[0])  # [N, lq, lk]
        # Diagonal elements (position i attending to position i) should be consistent
        # (same relative distance = 0 for all diagonal elements)
        for head in range(2):
            diagonal = np.diag(per_head[head])
            np.testing.assert_allclose(diagonal, diagonal[0], atol=1e-5)
class TestT5Attention:
    """Tests for ``T5Attention`` with optional position bias and mask."""

    def test_output_shape(self):
        """Plain self-attention keeps the input shape."""
        from mlx_video.models.wan.text_encoder import T5Attention

        layer = T5Attention(dim=64, dim_attn=64, num_heads=4)
        inputs = mx.random.normal((1, 10, 64))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (1, 10, 64)

    def test_no_scaling(self):
        """The module is expected to carry no ``scale`` attribute."""
        from mlx_video.models.wan.text_encoder import T5Attention

        layer = T5Attention(dim=64, dim_attn=64, num_heads=4)
        # No scale attribute (unlike standard attention)
        has_scale = hasattr(layer, "scale")
        assert not has_scale

    def test_with_position_bias(self):
        """Supplying a relative position bias keeps the output shape."""
        from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding

        layer = T5Attention(dim=64, dim_attn=64, num_heads=4)
        bias_module = T5RelativeEmbedding(32, 4)
        inputs = mx.random.normal((1, 10, 64))
        bias = bias_module(10, 10)
        result = layer(inputs, pos_bias=bias)
        mx.eval(result)
        assert result.shape == (1, 10, 64)

    def test_with_mask(self):
        """Masking out the last three of ten positions keeps the output shape."""
        from mlx_video.models.wan.text_encoder import T5Attention

        layer = T5Attention(dim=64, dim_attn=64, num_heads=4)
        inputs = mx.random.normal((1, 10, 64))
        all_ones = mx.ones((1, 10))
        # Seven visible positions followed by three masked ones.
        padding_mask = mx.concatenate([all_ones[:, :7], mx.zeros((1, 3))], axis=1)
        result = layer(inputs, mask=padding_mask)
        mx.eval(result)
        assert result.shape == (1, 10, 64)
class TestT5FeedForward:
    """Shape and structure checks for ``T5FeedForward``."""

    def test_output_shape(self):
        """The feed-forward block maps back to the input width."""
        from mlx_video.models.wan.text_encoder import T5FeedForward

        layer = T5FeedForward(64, 256)
        inputs = mx.random.normal((1, 10, 64))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (1, 10, 64)

    def test_gated_structure(self):
        """The module exposes ``gate_proj``, ``fc1`` and ``fc2`` sub-layers."""
        from mlx_video.models.wan.text_encoder import T5FeedForward

        layer = T5FeedForward(32, 64)
        for attr_name in ("gate_proj", "fc1", "fc2"):
            assert hasattr(layer, attr_name)
class TestT5Encoder:
    """Tests for a two-layer, 64-dim ``T5Encoder``."""

    def setup_method(self):
        """Seed the RNG so weight initialisation is reproducible."""
        mx.random.seed(42)

    def test_output_shape(self):
        """Encoding five token ids with a mask yields ``(1, 5, dim)``."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        enc = T5Encoder(
            vocab_size=100,
            dim=64,
            dim_attn=64,
            dim_ffn=128,
            num_heads=4,
            num_layers=2,
            num_buckets=32,
            shared_pos=False,
        )
        token_ids = mx.array([[1, 5, 10, 0, 0]])
        attn_mask = mx.array([[1, 1, 1, 0, 0]])
        hidden = enc(token_ids, mask=attn_mask)
        mx.eval(hidden)
        assert hidden.shape == (1, 5, 64)

    def test_shared_pos(self):
        """With ``shared_pos=True`` only the encoder owns a position embedding."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        enc = T5Encoder(
            vocab_size=100,
            dim=64,
            dim_attn=64,
            dim_ffn=128,
            num_heads=4,
            num_layers=2,
            num_buckets=32,
            shared_pos=True,
        )
        assert enc.pos_embedding is not None
        for layer in enc.blocks:
            assert layer.pos_embedding is None

    def test_per_layer_pos(self):
        """With ``shared_pos=False`` each block owns its position embedding."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        enc = T5Encoder(
            vocab_size=100,
            dim=64,
            dim_attn=64,
            dim_ffn=128,
            num_heads=4,
            num_layers=2,
            num_buckets=32,
            shared_pos=False,
        )
        assert enc.pos_embedding is None
        for layer in enc.blocks:
            assert layer.pos_embedding is not None

    def test_param_count(self):
        """The encoder owns at least one parameter."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        enc = T5Encoder(
            vocab_size=100,
            dim=64,
            dim_attn=64,
            dim_ffn=128,
            num_heads=4,
            num_layers=2,
            num_buckets=32,
            shared_pos=False,
        )
        flat_params = nn.utils.tree_flatten(enc.parameters())
        total = sum(param.size for _, param in flat_params)
        assert total > 0

    def test_without_mask(self):
        """The encoder also runs when no mask is supplied."""
        from mlx_video.models.wan.text_encoder import T5Encoder

        enc = T5Encoder(
            vocab_size=100,
            dim=64,
            dim_attn=64,
            dim_ffn=128,
            num_heads=4,
            num_layers=2,
            num_buckets=32,
            shared_pos=False,
        )
        token_ids = mx.array([[1, 5, 10]])
        hidden = enc(token_ids)
        mx.eval(hidden)
        assert hidden.shape == (1, 3, 64)
# ---------------------------------------------------------------------------
# VAE Tests
# ---------------------------------------------------------------------------
class TestCausalConv3d:
    """Shape and causality checks for ``CausalConv3d``.

    Inputs are laid out as ``[B, C, T, H, W]``.
    """

    def test_output_shape_stride1(self):
        """A 3x3x3 conv with stride 1 and padding 1 keeps T, H and W."""
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
        # Initialize weights
        conv.weight = mx.random.normal(conv.weight.shape) * 0.02
        x = mx.random.normal((1, 4, 3, 8, 8))  # [B, C, T, H, W]
        out = conv(x)
        mx.eval(out)
        # With causal padding and padding=1 on spatial, dims should be preserved
        assert out.shape[0] == 1
        assert out.shape[1] == 8  # out_channels
        assert out.shape[2] == 3  # T preserved
        assert out.shape[3] == 8  # H preserved
        assert out.shape[4] == 8  # W preserved

    def test_output_shape_kernel1(self):
        """A 1x1x1 conv only changes the channel count."""
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.02
        x = mx.random.normal((1, 4, 2, 4, 4))
        out = conv(x)
        mx.eval(out)
        assert out.shape == (1, 8, 2, 4, 4)

    def test_causal_padding(self):
        """Causal conv should only use past/current frames, not future.

        Two inputs that differ only in their last frame are run through
        the same conv. Every output frame before the last must be
        identical, and the last output frame must differ (so the check
        is not vacuous).
        """
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.1
        conv.bias = mx.zeros((2,))

        # Baseline input: only the first frame carries signal.
        x_np = np.zeros((1, 2, 4, 4, 4), dtype=np.float32)
        x_np[:, :, 0, :, :] = 1.0
        # Perturbed input: identical except for the last (future-most) frame.
        x_future_np = x_np.copy()
        x_future_np[:, :, -1, :, :] = 1.0

        out = conv(mx.array(x_np))
        out_future = conv(mx.array(x_future_np))
        mx.eval(out, out_future)
        out_np = np.array(out)
        out_future_np = np.array(out_future)

        assert out_np.shape == out_future_np.shape
        # Earlier output frames must not see the change in the last input frame.
        np.testing.assert_allclose(
            out_future_np[:, :, :-1],
            out_np[:, :, :-1],
            atol=1e-5,
        )
        # The last output frame depends on the perturbed input frame.
        assert not np.allclose(
            out_future_np[:, :, -1], out_np[:, :, -1], atol=1e-5
        )
class TestResidualBlock:
    """Shape and shortcut checks for the VAE ``ResidualBlock``."""

    def test_same_dim(self):
        """Equal in/out channels keep the full input shape."""
        from mlx_video.models.wan.vae import ResidualBlock

        layer = ResidualBlock(8, 8)
        inputs = mx.random.normal((1, 8, 2, 4, 4))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (1, 8, 2, 4, 4)

    def test_different_dim(self):
        """Different in/out channels only change the channel axis."""
        from mlx_video.models.wan.vae import ResidualBlock

        layer = ResidualBlock(8, 16)
        inputs = mx.random.normal((1, 8, 2, 4, 4))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (1, 16, 2, 4, 4)

    def test_shortcut_exists_when_dims_differ(self):
        """A channel change requires a ``shortcut`` projection."""
        from mlx_video.models.wan.vae import ResidualBlock

        layer = ResidualBlock(8, 16)
        assert layer.shortcut is not None

    def test_no_shortcut_when_dims_same(self):
        """Matching channels leave ``shortcut`` as ``None``."""
        from mlx_video.models.wan.vae import ResidualBlock

        layer = ResidualBlock(8, 8)
        assert layer.shortcut is None
class TestAttentionBlock:
    """Tests for the VAE ``AttentionBlock``."""

    def test_output_shape(self):
        """The block keeps the 5-D input shape."""
        from mlx_video.models.wan.vae import AttentionBlock

        layer = AttentionBlock(8)
        inputs = mx.random.normal((1, 8, 2, 4, 4))
        result = layer(inputs)
        mx.eval(result)
        assert result.shape == (1, 8, 2, 4, 4)

    def test_residual_connection(self):
        """For a random input the output has at least one non-zero entry."""
        from mlx_video.models.wan.vae import AttentionBlock

        layer = AttentionBlock(8)
        inputs = mx.random.normal((1, 8, 1, 3, 3))
        result = layer(inputs)
        mx.eval(inputs, result)
        # Residual: output should not be zero even with random init
        peak = np.abs(np.array(result)).max()
        assert peak > 0
class TestWanVAE:
    """Construction and normalization-constant checks for ``WanVAE``."""

    def test_instantiation(self):
        """A 16-channel VAE exposes matching ``mean`` and ``std`` vectors."""
        from mlx_video.models.wan.vae import WanVAE

        model = WanVAE(z_dim=16)
        assert model.z_dim == 16
        assert model.mean.shape == (16,)
        assert model.std.shape == (16,)

    def test_normalization_stats(self):
        """The module-level stats have 16 entries and strictly positive stds."""
        from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD

        assert len(VAE_MEAN) == 16
        assert len(VAE_STD) == 16
        assert all(std_value > 0 for std_value in VAE_STD)
# ---------------------------------------------------------------------------
# Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowMatchEulerScheduler:
    """Tests for the flow-matching Euler scheduler."""

    def test_initialization(self):
        """A fresh scheduler has 1000 train timesteps and no schedule yet."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        assert scheduler.num_train_timesteps == 1000
        assert scheduler.timesteps is None
        assert scheduler.sigmas is None

    def test_set_timesteps(self):
        """``set_timesteps(n)`` yields n timesteps and n + 1 sigmas."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(40, shift=12.0)
        mx.eval(scheduler.timesteps, scheduler.sigmas)
        assert scheduler.timesteps.shape == (40,)
        assert scheduler.sigmas.shape == (41,)  # 40 steps + terminal

    def test_timesteps_decreasing(self):
        """Timesteps are strictly decreasing."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(40, shift=12.0)
        mx.eval(scheduler.timesteps)
        ts = np.array(scheduler.timesteps)
        # Timesteps should be monotonically decreasing
        deltas = np.diff(ts)
        assert np.all(deltas < 0), f"Timesteps not decreasing: {ts[:5]}..."

    def test_sigmas_decreasing(self):
        """Sigmas never increase from one step to the next."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(20, shift=1.0)
        mx.eval(scheduler.sigmas)
        deltas = np.diff(np.array(scheduler.sigmas))
        assert np.all(deltas <= 0), "Sigmas not decreasing"

    def test_terminal_sigma_is_zero(self):
        """The final sigma is zero."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(20, shift=5.0)
        mx.eval(scheduler.sigmas)
        terminal = np.array(scheduler.sigmas[-1])
        np.testing.assert_allclose(terminal, 0.0, atol=1e-6)

    def test_shift_effect(self):
        """A larger shift raises the mean of the non-terminal sigmas."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        low_shift = FlowMatchEulerScheduler()
        high_shift = FlowMatchEulerScheduler()
        low_shift.set_timesteps(20, shift=1.0)
        high_shift.set_timesteps(20, shift=12.0)
        mx.eval(low_shift.sigmas, high_shift.sigmas)
        mean_low = np.mean(np.array(low_shift.sigmas[:-1]))
        mean_high = np.mean(np.array(high_shift.sigmas[:-1]))
        assert mean_high > mean_low, "Higher shift should push sigmas higher"

    def test_step_euler(self):
        """One step matches the explicit Euler update with the first two sigmas."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(10, shift=1.0)
        mx.eval(scheduler.sigmas)
        latent = mx.ones((1, 4, 2, 2, 2))
        velocity = mx.ones((1, 4, 2, 2, 2)) * 0.5
        first_t = scheduler.timesteps[0]
        sigma_now = float(np.array(scheduler.sigmas[0]))
        sigma_after = float(np.array(scheduler.sigmas[1]))
        stepped = scheduler.step(velocity, first_t, latent)
        mx.eval(stepped)
        # Euler: x_next = x + (sigma_next - sigma) * v
        expected_value = 1.0 + (sigma_after - sigma_now) * 0.5
        first_element = np.array(stepped).flatten()[0]
        np.testing.assert_allclose(first_element, expected_value, rtol=1e-4)

    def test_step_index_increments(self):
        """Each ``step`` call advances ``_step_index`` by one."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(5, shift=1.0)
        assert scheduler._step_index == 0
        latent = mx.ones((1, 1, 1, 1, 1))
        zero_velocity = mx.zeros((1, 1, 1, 1, 1))
        scheduler.step(zero_velocity, scheduler.timesteps[0], latent)
        assert scheduler._step_index == 1
        scheduler.step(zero_velocity, scheduler.timesteps[1], latent)
        assert scheduler._step_index == 2

    def test_reset(self):
        """``reset`` puts ``_step_index`` back to zero."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(5, shift=1.0)
        latent = mx.ones((1, 1, 1, 1, 1))
        zero_velocity = mx.zeros((1, 1, 1, 1, 1))
        scheduler.step(zero_velocity, scheduler.timesteps[0], latent)
        assert scheduler._step_index == 1
        scheduler.reset()
        assert scheduler._step_index == 0

    @pytest.mark.parametrize("steps", [10, 20, 40, 50])
    def test_various_step_counts(self, steps):
        """Schedule lengths follow the requested step count."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(steps, shift=12.0)
        mx.eval(scheduler.timesteps, scheduler.sigmas)
        assert scheduler.timesteps.shape == (steps,)
        assert scheduler.sigmas.shape == (steps + 1,)

    def test_full_denoise_loop(self):
        """Five steps with zero velocity leave the sample at its initial value."""
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(5, shift=1.0)
        latent = mx.ones((1, 2, 1, 2, 2))
        for step_idx in range(5):
            zero_velocity = mx.zeros_like(latent)
            latent = scheduler.step(
                zero_velocity, scheduler.timesteps[step_idx], latent
            )
        mx.eval(latent)
        # With zero velocity, sample should remain unchanged
        np.testing.assert_allclose(np.array(latent), 1.0, atol=1e-5)
# ---------------------------------------------------------------------------
# Weight Conversion Tests
# ---------------------------------------------------------------------------
class TestSanitizeTransformerWeights:
    """Key renaming and reshaping done by ``sanitize_wan_transformer_weights``."""

    @staticmethod
    def _sanitize(weights):
        """Run the transformer sanitizer (imported lazily) on ``weights``."""
        from mlx_video.convert_wan import sanitize_wan_transformer_weights

        return sanitize_wan_transformer_weights(weights)

    @staticmethod
    def _assert_present(converted, expected_keys):
        """Assert that every key in ``expected_keys`` exists in ``converted``."""
        for key in expected_keys:
            assert key in converted

    def test_patch_embedding_reshape(self):
        """The 5-D patch-embedding kernel is renamed and flattened to 2-D."""
        converted = self._sanitize(
            {
                "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
                "patch_embedding.bias": mx.random.normal((5120,)),
            }
        )
        self._assert_present(
            converted,
            ("patch_embedding_proj.weight", "patch_embedding_proj.bias"),
        )
        flattened_shape = converted["patch_embedding_proj.weight"].shape
        assert flattened_shape == (5120, 16 * 1 * 2 * 2)

    def test_text_embedding_rename(self):
        """``text_embedding.{0,2}`` become ``text_embedding_{0,1}``."""
        converted = self._sanitize(
            {
                "text_embedding.0.weight": mx.zeros((64, 32)),
                "text_embedding.0.bias": mx.zeros((64,)),
                "text_embedding.2.weight": mx.zeros((64, 64)),
                "text_embedding.2.bias": mx.zeros((64,)),
            }
        )
        self._assert_present(
            converted,
            (
                "text_embedding_0.weight",
                "text_embedding_0.bias",
                "text_embedding_1.weight",
                "text_embedding_1.bias",
            ),
        )

    def test_time_embedding_rename(self):
        """``time_embedding.{0,2}`` become ``time_embedding_{0,1}``."""
        converted = self._sanitize(
            {
                "time_embedding.0.weight": mx.zeros((64, 32)),
                "time_embedding.2.weight": mx.zeros((64, 64)),
            }
        )
        self._assert_present(
            converted,
            ("time_embedding_0.weight", "time_embedding_1.weight"),
        )

    def test_time_projection_rename(self):
        """The ``.1`` index is dropped from ``time_projection`` keys."""
        converted = self._sanitize(
            {
                "time_projection.1.weight": mx.zeros((384, 64)),
                "time_projection.1.bias": mx.zeros((384,)),
            }
        )
        self._assert_present(
            converted,
            ("time_projection.weight", "time_projection.bias"),
        )

    def test_ffn_rename(self):
        """Block FFN layers ``ffn.0`` / ``ffn.2`` become ``ffn.fc1`` / ``ffn.fc2``."""
        converted = self._sanitize(
            {
                "blocks.0.ffn.0.weight": mx.zeros((128, 64)),
                "blocks.0.ffn.0.bias": mx.zeros((128,)),
                "blocks.0.ffn.2.weight": mx.zeros((64, 128)),
                "blocks.0.ffn.2.bias": mx.zeros((64,)),
            }
        )
        self._assert_present(
            converted,
            (
                "blocks.0.ffn.fc1.weight",
                "blocks.0.ffn.fc1.bias",
                "blocks.0.ffn.fc2.weight",
                "blocks.0.ffn.fc2.bias",
            ),
        )

    def test_freqs_skipped(self):
        """The ``freqs`` entry is dropped while other keys are kept."""
        converted = self._sanitize(
            {
                "freqs": mx.zeros((1024, 64, 2)),
                "blocks.0.norm1.weight": mx.zeros((64,)),
            }
        )
        assert "freqs" not in converted
        assert "blocks.0.norm1.weight" in converted

    def test_passthrough_keys(self):
        """Keys with no rename rule come out under their original names."""
        square = (64, 64)
        untouched = {
            "blocks.0.self_attn.q.weight": mx.zeros(square),
            "blocks.0.self_attn.k.weight": mx.zeros(square),
            "blocks.0.self_attn.v.weight": mx.zeros(square),
            "blocks.0.self_attn.o.weight": mx.zeros(square),
            "blocks.0.modulation": mx.zeros((1, 6, 64)),
            "head.head.weight": mx.zeros(square),
            "head.modulation": mx.zeros((1, 2, 64)),
        }
        converted = self._sanitize(untouched)
        self._assert_present(converted, untouched)
class TestSanitizeT5Weights:
    """Key handling of ``sanitize_wan_t5_weights``."""

    @staticmethod
    def _sanitize(weights):
        """Run the T5 sanitizer (imported lazily) on ``weights``."""
        from mlx_video.convert_wan import sanitize_wan_t5_weights

        return sanitize_wan_t5_weights(weights)

    def test_gate_rename(self):
        """``ffn.gate.0`` becomes ``ffn.gate_proj``; ``fc1`` / ``fc2`` are kept."""
        converted = self._sanitize(
            {
                "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
                "blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
                "blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
            }
        )
        expected_keys = (
            "blocks.0.ffn.gate_proj.weight",
            "blocks.0.ffn.fc1.weight",
            "blocks.0.ffn.fc2.weight",
        )
        for key in expected_keys:
            assert key in converted

    def test_passthrough(self):
        """Keys without a rename rule come out under their original names."""
        untouched = {
            "token_embedding.weight": mx.zeros((100, 64)),
            "blocks.0.attn.q.weight": mx.zeros((64, 64)),
            "norm.weight": mx.zeros((64,)),
        }
        converted = self._sanitize(untouched)
        for key in untouched:
            assert key in converted
class TestSanitizeVAEWeights:
    """Layout conversion of ``sanitize_wan_vae_weights`` (channels moved last)."""

    @staticmethod
    def _sanitized_shapes(shapes):
        """Sanitize zero tensors of the given shapes and return the output shapes."""
        from mlx_video.convert_wan import sanitize_wan_vae_weights

        weights = {key: mx.zeros(shape) for key, shape in shapes.items()}
        converted = sanitize_wan_vae_weights(weights)
        return {key: converted[key].shape for key in shapes}

    def test_conv3d_transpose(self):
        """5-D kernels go from [O, I, D, H, W] to [O, D, H, W, I]."""
        result = self._sanitized_shapes({"decoder.conv1.weight": (8, 4, 3, 3, 3)})
        assert result["decoder.conv1.weight"] == (8, 3, 3, 3, 4)

    def test_conv2d_transpose(self):
        """4-D kernels go from [O, I, H, W] to [O, H, W, I]."""
        result = self._sanitized_shapes({"decoder.proj.weight": (16, 8, 3, 3)})
        assert result["decoder.proj.weight"] == (16, 3, 3, 8)

    def test_non_conv_passthrough(self):
        """1-D tensors keep their shape."""
        result = self._sanitized_shapes(
            {
                "decoder.norm.weight": (64,),
                "decoder.bias": (16,),
            }
        )
        assert result["decoder.norm.weight"] == (64,)
        assert result["decoder.bias"] == (16,)

    def test_mixed_weights(self):
        """Only 5-D and 4-D tensors are permuted in a mixed dictionary."""
        result = self._sanitized_shapes(
            {
                "conv3d.weight": (8, 4, 3, 3, 3),  # 5D
                "conv2d.weight": (8, 4, 3, 3),  # 4D
                "linear.weight": (8, 4),  # 2D
                "norm.weight": (8,),  # 1D
            }
        )
        assert result["conv3d.weight"] == (8, 3, 3, 3, 4)
        assert result["conv2d.weight"] == (8, 3, 3, 4)
        assert result["linear.weight"] == (8, 4)
        assert result["norm.weight"] == (8,)
# ---------------------------------------------------------------------------
# Integration: end-to-end tiny model forward pass
# ---------------------------------------------------------------------------
class TestEndToEnd:
    """Smoke tests driving a randomly initialised tiny model through the Euler scheduler."""

    @staticmethod
    def _build(seed, num_steps):
        """Seed the RNG, then create model, scheduler, noise and text context.

        The creation order (model, latents, context) is kept fixed so the
        random draws are the same for a given seed.
        """
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        mx.random.seed(seed)
        config = _make_tiny_config()
        model = WanModel(config)

        latent_shape = (config.in_dim, 1, 4, 4)
        _, frames, height, width = latent_shape
        patch_t, patch_h, patch_w = config.patch_size
        seq_len = (frames // patch_t) * (height // patch_h) * (width // patch_w)

        scheduler = FlowMatchEulerScheduler()
        scheduler.set_timesteps(num_steps, shift=3.0)

        noise = mx.random.normal(latent_shape)
        text_context = mx.random.normal((4, config.text_dim))
        return model, scheduler, noise, text_context, seq_len, latent_shape

    @staticmethod
    def _advance(model, scheduler, latents, timestep, text_context, seq_len):
        """Predict the velocity for ``latents`` and take one scheduler step."""
        velocity = model(
            [latents], mx.array([timestep.item()]), [text_context], seq_len
        )[0]
        stepped = scheduler.step(velocity[None], timestep, latents[None])
        return stepped.squeeze(0)

    def test_tiny_model_denoise_step(self):
        """A single denoising step keeps the shape and changes the latents."""
        model, scheduler, noise, text_context, seq_len, latent_shape = self._build(
            seed=42, num_steps=5
        )

        stepped = self._advance(
            model, scheduler, noise, scheduler.timesteps[0], text_context, seq_len
        )
        mx.eval(stepped)

        assert stepped.shape == latent_shape
        # The step must actually move away from the initial noise.
        unchanged = np.allclose(np.array(stepped), np.array(noise), atol=1e-5)
        assert not unchanged

    def test_tiny_model_full_loop(self):
        """A complete three-step loop ends with finite latents of the input shape."""
        total_steps = 3
        model, scheduler, latents, text_context, seq_len, latent_shape = self._build(
            seed=123, num_steps=total_steps
        )

        for step_idx in range(total_steps):
            latents = self._advance(
                model,
                scheduler,
                latents,
                scheduler.timesteps[step_idx],
                text_context,
                seq_len,
            )
            mx.eval(latents)

        assert latents.shape == latent_shape
        assert not mx.any(mx.isnan(latents)).item(), "NaN in output"
        assert not mx.any(mx.isinf(latents)).item(), "Inf in output"
# ---------------------------------------------------------------------------
# Wan2.1 Config & Pipeline Tests
# ---------------------------------------------------------------------------
class TestWan21Config:
    """Checks on the Wan2.1 / Wan2.2 presets exposed by ``WanModelConfig``."""

    @staticmethod
    def _config_cls():
        """Import ``WanModelConfig`` lazily and return the class."""
        from mlx_video.models.wan.config import WanModelConfig

        return WanModelConfig

    @staticmethod
    def _assert_fields(config, expected):
        """Assert ``getattr(config, name) == value`` for every expected pair."""
        for name, value in expected.items():
            assert getattr(config, name) == value

    def test_wan21_14b_factory(self):
        """The Wan2.1 14B preset is single-model with its own sampling defaults."""
        preset = self._config_cls().wan21_t2v_14b()
        assert preset.dual_model is False
        self._assert_fields(
            preset,
            {
                "model_version": "2.1",
                "dim": 5120,
                "ffn_dim": 13824,
                "num_heads": 40,
                "num_layers": 40,
                "head_dim": 128,
                "sample_guide_scale": 5.0,
                "sample_shift": 5.0,
                "sample_steps": 50,
                "boundary": 0.0,
            },
        )

    def test_wan21_1_3b_factory(self):
        """The Wan2.1 1.3B preset uses the smaller architecture."""
        preset = self._config_cls().wan21_t2v_1_3b()
        assert preset.dual_model is False
        self._assert_fields(
            preset,
            {
                "model_version": "2.1",
                "dim": 1536,
                "ffn_dim": 8960,
                "num_heads": 12,
                "num_layers": 30,
                "head_dim": 128,  # 1536 // 12
                "sample_guide_scale": 5.0,
            },
        )

    def test_wan22_14b_factory(self):
        """The Wan2.2 14B preset is dual-model with a two-value guide scale."""
        preset = self._config_cls().wan22_t2v_14b()
        assert preset.dual_model is True
        self._assert_fields(
            preset,
            {
                "model_version": "2.2",
                "dim": 5120,
                "sample_guide_scale": (3.0, 4.0),
                "sample_shift": 12.0,
                "sample_steps": 40,
                "boundary": 0.875,
            },
        )

    def test_wan21_config_to_dict(self):
        """``to_dict`` on the Wan2.1 14B preset carries the version fields."""
        as_dict = self._config_cls().wan21_t2v_14b().to_dict()
        assert as_dict["model_version"] == "2.1"
        assert as_dict["dual_model"] is False
        assert as_dict["sample_guide_scale"] == 5.0

    def test_wan21_1_3b_config_to_dict(self):
        """``to_dict`` on the Wan2.1 1.3B preset carries the architecture fields."""
        as_dict = self._config_cls().wan21_t2v_1_3b().to_dict()
        assert as_dict["dim"] == 1536
        assert as_dict["num_layers"] == 30

    def test_default_config_is_wan22(self):
        """A default-constructed config reports version 2.2 in dual-model mode."""
        default_config = self._config_cls()()
        assert default_config.model_version == "2.2"
        assert default_config.dual_model is True
class TestWan21Model:
    """Test tiny Wan2.1-style model (single model mode)."""

    def setup_method(self):
        """Seed the MLX RNG so every test starts from the same random state."""
        mx.random.seed(42)

    @staticmethod
    def _shrink(config, overrides):
        """Overwrite the given fields on ``config`` in place and return it."""
        for field_name, value in overrides.items():
            setattr(config, field_name, value)
        return config

    @staticmethod
    def _latent_geometry(config):
        """Return the (C, F, H, W) latent shape and the patch sequence length.

        The sequence length is derived from ``config.patch_size`` rather than
        hard-coded divisors, matching how ``TestEndToEnd`` computes it.
        """
        channels, frames, height, width = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (frames // pt) * (height // ph) * (width // pw)
        return (channels, frames, height, width), seq_len

    def _make_tiny_wan21_config(self):
        """Create a tiny config mimicking Wan2.1 (single model)."""
        from mlx_video.models.wan.config import WanModelConfig

        config = WanModelConfig.wan21_t2v_14b()
        # Override to tiny values
        return self._shrink(
            config,
            {
                "dim": 64,
                "ffn_dim": 128,
                "num_heads": 4,
                "num_layers": 2,
                "in_dim": 4,
                "out_dim": 4,
                "freq_dim": 32,
                "text_dim": 32,
                "text_len": 8,
            },
        )

    def _make_tiny_wan21_1_3b_config(self):
        """Create a tiny config derived from the Wan2.1 1.3B preset."""
        from mlx_video.models.wan.config import WanModelConfig

        config = WanModelConfig.wan21_t2v_1_3b()
        # Override to tiny values. This uses 4 heads over dim=48; it does NOT
        # keep the 12 heads of the real 1.3B preset.
        return self._shrink(
            config,
            {
                "dim": 48,
                "ffn_dim": 96,
                "num_heads": 4,
                "num_layers": 2,
                "in_dim": 4,
                "out_dim": 4,
                "freq_dim": 24,
                "text_dim": 24,
                "text_len": 8,
            },
        )

    def _assert_forward_shape(self, config):
        """Run one forward pass and check the output keeps the latent shape."""
        from mlx_video.models.wan.model import WanModel

        model = WanModel(config)
        latent_shape, seq_len = self._latent_geometry(config)
        latents = mx.random.normal(latent_shape)
        context = mx.random.normal((4, config.text_dim))
        t = mx.array([500.0])
        out = model([latents], t, [context], seq_len)
        mx.eval(out)
        assert out[0].shape == latent_shape

    def test_wan21_tiny_model_forward(self):
        """Forward pass with Wan2.1 tiny config."""
        self._assert_forward_shape(self._make_tiny_wan21_config())

    def test_wan21_1_3b_tiny_model_forward(self):
        """Forward pass with Wan2.1 1.3B tiny config."""
        self._assert_forward_shape(self._make_tiny_wan21_1_3b_config())

    def test_wan21_single_model_loop(self):
        """Full diffusion loop with single model (Wan2.1 style)."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        config = self._make_tiny_wan21_config()
        model = WanModel(config)
        latent_shape, seq_len = self._latent_geometry(config)

        # The schedule is built for the preset's full step count, but only the
        # first few steps are executed to keep the test fast.
        sched = FlowMatchEulerScheduler()
        sched.set_timesteps(config.sample_steps, shift=config.sample_shift)
        steps_to_run = 3

        latents = mx.random.normal(latent_shape)
        context = mx.random.normal((4, config.text_dim))
        context_null = mx.zeros((4, config.text_dim))

        gs = config.sample_guide_scale  # Should be float for Wan2.1
        assert isinstance(gs, float), "Wan2.1 guide_scale should be float"

        for i in range(steps_to_run):
            t = sched.timesteps[i]
            t_batch = mx.array([t.item()])
            pred_cond = model([latents], t_batch, [context], seq_len)[0]
            pred_uncond = model([latents], t_batch, [context_null], seq_len)[0]
            # Classifier-free guidance: push away from the unconditional branch.
            pred = pred_uncond + gs * (pred_cond - pred_uncond)
            latents = sched.step(pred[None], t, latents[None]).squeeze(0)
            mx.eval(latents)

        assert latents.shape == latent_shape
        assert not mx.any(mx.isnan(latents)).item()

    def test_wan21_vs_wan22_config_differences(self):
        """Verify key differences between Wan2.1 and Wan2.2 configs."""
        from mlx_video.models.wan.config import WanModelConfig

        c21 = WanModelConfig.wan21_t2v_14b()
        c22 = WanModelConfig.wan22_t2v_14b()

        # Same architecture
        assert c21.dim == c22.dim
        assert c21.num_heads == c22.num_heads
        assert c21.num_layers == c22.num_layers

        # Different pipeline settings
        assert c21.dual_model is False
        assert c22.dual_model is True
        assert isinstance(c21.sample_guide_scale, float)
        assert isinstance(c22.sample_guide_scale, tuple)
        assert c21.sample_shift != c22.sample_shift
        assert c21.sample_steps != c22.sample_steps
class TestWan21Convert:
    """Tests for Wan2.1 conversion support.

    NOTE(review): the auto-detect tests replicate the detection rule locally
    instead of calling the converter, so they only pin the rule itself.
    Confirm it matches what ``mlx_video.convert_wan`` actually does.
    """

    @staticmethod
    def _detect_version(model_dir):
        """Return "2.2" if ``model_dir`` has a ``low_noise_model`` entry, else "2.1"."""
        return "2.2" if (model_dir / "low_noise_model").exists() else "2.1"

    def test_auto_detect_wan21(self, tmp_path):
        """Auto-detect single-model directory as Wan2.1."""
        # Create a Wan2.1-style directory (no low_noise_model subdir)
        (tmp_path / "dummy.safetensors").touch()
        assert not (tmp_path / "low_noise_model").exists()
        assert self._detect_version(tmp_path) == "2.1"

    def test_auto_detect_wan22(self, tmp_path):
        """Auto-detect dual-model directory as Wan2.2."""
        (tmp_path / "low_noise_model").mkdir()
        (tmp_path / "high_noise_model").mkdir()
        assert (tmp_path / "low_noise_model").exists()
        assert self._detect_version(tmp_path) == "2.2"

    def test_wan21_config_saved_correctly(self):
        """Verify config dict has correct fields for Wan2.1."""
        from mlx_video.models.wan.config import WanModelConfig

        config = WanModelConfig.wan21_t2v_14b()
        d = config.to_dict()
        assert d["model_version"] == "2.1"
        assert d["dual_model"] is False
        assert d["sample_steps"] == 50
        assert d["sample_shift"] == 5.0
# ---------------------------------------------------------------------------
# Shared Sigma Schedule Tests
# ---------------------------------------------------------------------------
class TestComputeSigmas:
    """Tests for the shared ``_compute_sigmas`` helper."""

    @staticmethod
    def _sigmas(steps, shift):
        """Import ``_compute_sigmas`` lazily and evaluate it."""
        from mlx_video.models.wan.scheduler import _compute_sigmas

        return _compute_sigmas(steps, shift=shift)

    @staticmethod
    def _configured_schedulers(steps, shift):
        """Return Euler, DPM++ and UniPC schedulers set to the same schedule."""
        from mlx_video.models.wan.scheduler import (
            FlowDPMPP2MScheduler,
            FlowMatchEulerScheduler,
            FlowUniPCScheduler,
        )

        schedulers = [
            FlowMatchEulerScheduler(1000),
            FlowDPMPP2MScheduler(1000),
            FlowUniPCScheduler(1000),
        ]
        for scheduler in schedulers:
            scheduler.set_timesteps(steps, shift=shift)
        return schedulers

    def test_length(self):
        """The schedule has one entry per step plus the terminal sigma."""
        assert len(self._sigmas(20, 5.0)) == 21  # num_steps + terminal

    def test_terminal_zero(self):
        """The last sigma is exactly zero."""
        schedule = self._sigmas(10, 1.0)
        assert schedule[-1] == 0.0

    def test_starts_at_one(self):
        """The first sigma is one."""
        schedule = self._sigmas(20, 5.0)
        np.testing.assert_allclose(schedule[0], 1.0, atol=1e-6)

    def test_decreasing(self):
        """Consecutive sigmas never increase."""
        deltas = np.diff(self._sigmas(20, 5.0))
        assert np.all(deltas <= 0)

    def test_matches_official_wan22(self):
        """The schedule equals the reference formula written out in NumPy."""
        steps, shift = 50, 5.0
        schedule = self._sigmas(steps, shift)

        # Reference: linear ramp without its endpoint, shifted, then 0 appended.
        ramp = np.linspace(1, 0, steps + 1)[:steps]
        shifted = shift * ramp / (1 + (shift - 1) * ramp)
        reference = np.append(shifted, 0.0).astype(np.float32)

        np.testing.assert_allclose(schedule, reference, atol=1e-6)

    def test_shift_one_is_linear(self):
        """With shift=1 the schedule is a linear ramp from 1 to 0."""
        schedule = self._sigmas(10, 1.0)
        linear_ramp = np.linspace(1, 0, 11).astype(np.float32)
        np.testing.assert_allclose(schedule, linear_ramp, atol=1e-6)

    def test_all_schedulers_same_sigmas(self):
        """All three schedulers expose the same sigma schedule."""
        first, *others = self._configured_schedulers(20, 5.0)
        mx.eval(*[scheduler.sigmas for scheduler in (first, *others)])

        reference = np.array(first.sigmas)
        for scheduler in others:
            np.testing.assert_allclose(
                np.array(scheduler.sigmas), reference, atol=1e-6
            )

    def test_all_schedulers_same_timesteps(self):
        """All three schedulers expose the same timesteps (to within 1e-3)."""
        first, *others = self._configured_schedulers(30, 12.0)
        mx.eval(*[scheduler.timesteps for scheduler in (first, *others)])

        reference = np.array(first.timesteps)
        for scheduler in others:
            np.testing.assert_allclose(
                np.array(scheduler.timesteps), reference, atol=1e-3
            )
# ---------------------------------------------------------------------------
# DPM++ 2M Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowDPMPP2MScheduler:
    """Behavioural tests for ``FlowDPMPP2MScheduler``."""

    @staticmethod
    def _scheduler(steps=None, shift=None):
        """Create a scheduler, calling ``set_timesteps`` when ``steps`` is given."""
        from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler

        sched = FlowDPMPP2MScheduler()
        if steps is not None:
            sched.set_timesteps(steps, shift=shift)
        return sched

    def test_initialization(self):
        """A default scheduler has 1000 train timesteps and lower_order_final on."""
        sched = self._scheduler()
        assert sched.num_train_timesteps == 1000
        assert sched.lower_order_final is True

    def test_set_timesteps(self):
        """20 steps give 20 timesteps and 21 sigmas."""
        sched = self._scheduler(20, shift=5.0)
        mx.eval(sched.timesteps, sched.sigmas)
        assert sched.timesteps.shape == (20,)
        assert sched.sigmas.shape == (21,)

    def test_step_index_increments(self):
        """Each ``step`` call advances ``_step_index`` by one."""
        sched = self._scheduler(5, shift=1.0)
        sample = mx.ones((1, 4, 1, 2, 2))
        vel = mx.zeros_like(sample)
        assert sched._step_index == 0
        sched.step(vel, sched.timesteps[0], sample)
        assert sched._step_index == 1
        sched.step(vel, sched.timesteps[1], sample)
        assert sched._step_index == 2

    def test_reset(self):
        """``reset`` clears the step counter and the stored x0 prediction."""
        sched = self._scheduler(5, shift=1.0)
        sample = mx.ones((1, 1, 1, 1, 1))
        # Pass the scheduler's own first timestep, as every other test does,
        # instead of a literal 0.
        sched.step(mx.zeros_like(sample), sched.timesteps[0], sample)
        sched.reset()
        assert sched._step_index == 0
        assert sched._prev_x0 is None

    def test_full_loop_finite(self):
        """Full loop with constant velocity should produce finite output."""
        sched = self._scheduler(10, shift=1.0)
        sample = mx.ones((1, 2, 1, 2, 2))
        for i in range(10):
            vel = mx.ones_like(sample) * 0.1
            sample = sched.step(vel, sched.timesteps[i], sample)
            mx.eval(sample)
        assert np.isfinite(np.array(sample)).all()

    def test_first_step_is_first_order(self):
        """No x0 history exists before the first step; one is stored after it."""
        sched = self._scheduler(10, shift=5.0)
        sample = mx.random.normal((1, 4, 2, 4, 4))
        vel = mx.random.normal(sample.shape)

        # Before first step, no prev_x0
        assert sched._prev_x0 is None
        result = sched.step(vel, sched.timesteps[0], sample)
        mx.eval(result)
        # After first step, prev_x0 should be set
        assert sched._prev_x0 is not None

    def test_second_step_uses_correction(self):
        """The stored x0 prediction is replaced on the second step."""
        sched = self._scheduler(10, shift=5.0)
        sample = mx.random.normal((1, 4, 1, 2, 2))
        vel = mx.random.normal(sample.shape)

        # Step 1
        sample = sched.step(vel, sched.timesteps[0], sample)
        mx.eval(sample)
        x0_after_first = sched._prev_x0

        # Step 2
        vel = mx.random.normal(sample.shape)
        sample = sched.step(vel, sched.timesteps[1], sample)
        mx.eval(sample)

        # prev_x0 should have been updated
        x0_after_second = sched._prev_x0
        assert x0_after_second is not None
        # The stored x0 should differ from the first step's
        assert not np.allclose(
            np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
        )

    def test_denoise_to_target(self):
        """Perfect oracle should denoise to target with any solver."""
        sched = self._scheduler(20, shift=5.0)
        target = mx.zeros((1, 2, 1, 4, 4))
        latents = mx.random.normal(target.shape)
        for i in range(20):
            sigma = float(sched.sigmas[i].item())
            v = latents / max(sigma, 1e-6)  # perfect velocity for target=0
            latents = sched.step(v, sched.timesteps[i], latents)
            mx.eval(latents)
        np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)

    @pytest.mark.parametrize("steps", [5, 10, 20, 50])
    def test_various_step_counts(self, steps):
        """``set_timesteps`` yields ``steps`` timesteps and ``steps + 1`` sigmas."""
        sched = self._scheduler(steps, shift=5.0)
        mx.eval(sched.timesteps, sched.sigmas)
        assert sched.timesteps.shape == (steps,)
        assert sched.sigmas.shape == (steps + 1,)

    def test_terminal_sigma_produces_x0(self):
        """When sigma_next=0 the scheduler should return x0 directly."""
        num_steps = 5
        sched = self._scheduler(num_steps, shift=1.0)
        sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
        vel = mx.ones_like(sample) * 2.0

        expected_x0 = None
        # Run through all steps; the last step has sigma_next=0
        for i in range(num_steps):
            if i == num_steps - 1:
                # Flow-matching data prediction for the final step:
                # x0 = x - sigma * v, captured before the step consumes x.
                sigma_last = float(sched.sigmas[i].item())
                expected_x0 = np.array(sample) - sigma_last * np.array(vel)
            sample = sched.step(vel, sched.timesteps[i], sample)
            mx.eval(sample)

        final = np.array(sample)
        # Final value should be finite
        assert np.isfinite(final).all()
        # ... and equal to the x0 prediction of the last step, as documented.
        np.testing.assert_allclose(final, expected_x0, atol=1e-4)
# ---------------------------------------------------------------------------
# UniPC Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowUniPCScheduler:
    """Behavioural tests for ``FlowUniPCScheduler``."""

    @staticmethod
    def _scheduler(steps=None, shift=None, **kwargs):
        """Create a scheduler with ``kwargs``; set timesteps when ``steps`` is given."""
        from mlx_video.models.wan.scheduler import FlowUniPCScheduler

        sched = FlowUniPCScheduler(**kwargs)
        if steps is not None:
            sched.set_timesteps(steps, shift=shift)
        return sched

    def test_initialization(self):
        """Defaults: 1000 train timesteps, solver order 2, lower_order_final on."""
        sched = self._scheduler()
        assert sched.num_train_timesteps == 1000
        assert sched.solver_order == 2
        assert sched.lower_order_final is True

    def test_set_timesteps(self):
        """30 steps give 30 timesteps and 31 sigmas."""
        sched = self._scheduler(30, shift=12.0)
        mx.eval(sched.timesteps, sched.sigmas)
        assert sched.timesteps.shape == (30,)
        assert sched.sigmas.shape == (31,)

    def test_step_index_increments(self):
        """A ``step`` call advances ``_step_index`` from 0 to 1."""
        sched = self._scheduler(5, shift=1.0)
        sample = mx.ones((1, 1, 1, 1, 1))
        vel = mx.zeros_like(sample)
        assert sched._step_index == 0
        # Use the scheduler's own first timestep rather than a literal 0,
        # consistent with the other scheduler tests.
        sched.step(vel, sched.timesteps[0], sample)
        assert sched._step_index == 1

    def test_reset(self):
        """``reset`` clears the counters, the last sample and the output history."""
        sched = self._scheduler(5, shift=1.0)
        sample = mx.ones((1, 1, 1, 1, 1))
        sched.step(mx.zeros_like(sample), sched.timesteps[0], sample)
        sched.reset()
        assert sched._step_index == 0
        assert sched._lower_order_nums == 0
        assert sched._last_sample is None
        assert all(m is None for m in sched._model_outputs)

    def test_full_loop_finite(self):
        """A full loop with constant velocity ends with finite values."""
        sched = self._scheduler(10, shift=1.0)
        sample = mx.ones((1, 2, 1, 2, 2))
        for i in range(10):
            vel = mx.ones_like(sample) * 0.1
            sample = sched.step(vel, sched.timesteps[i], sample)
            mx.eval(sample)
        assert np.isfinite(np.array(sample)).all()

    def test_corrector_not_applied_first_step(self):
        """First step should skip the corrector (no history)."""
        sched = self._scheduler(10, shift=5.0, use_corrector=True)
        sample = mx.random.normal((1, 4, 1, 2, 2))
        vel = mx.random.normal(sample.shape)

        # Before step 0: no last_sample
        assert sched._last_sample is None
        sched.step(vel, sched.timesteps[0], sample)
        # After step 0: last_sample should be set for corrector on step 1
        assert sched._last_sample is not None

    def test_corrector_applied_after_first_step(self):
        """Steps after the first should use the corrector when enabled."""
        sched = self._scheduler(10, shift=5.0, use_corrector=True)
        sample = mx.random.normal((1, 2, 1, 4, 4))
        for i in range(3):
            vel = mx.random.normal(sample.shape)
            sample = sched.step(vel, sched.timesteps[i], sample)
            mx.eval(sample)
        # lower_order_nums should have increased
        assert sched._lower_order_nums >= 2

    def test_denoise_to_target(self):
        """With the oracle velocity v = x / sigma the loop ends near zero."""
        sched = self._scheduler(20, shift=5.0)
        target = mx.zeros((1, 2, 1, 4, 4))
        latents = mx.random.normal(target.shape)
        for i in range(20):
            sigma = float(sched.sigmas[i].item())
            v = latents / max(sigma, 1e-6)
            latents = sched.step(v, sched.timesteps[i], latents)
            mx.eval(latents)
        np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)

    @pytest.mark.parametrize("steps", [5, 10, 20, 50])
    def test_various_step_counts(self, steps):
        """``set_timesteps`` yields ``steps`` timesteps and ``steps + 1`` sigmas."""
        sched = self._scheduler(steps, shift=5.0)
        mx.eval(sched.timesteps, sched.sigmas)
        assert sched.timesteps.shape == (steps,)
        assert sched.sigmas.shape == (steps + 1,)

    def test_disable_corrector(self):
        """Disabling corrector on step 0 should still work without error."""
        sched = self._scheduler(
            5, shift=1.0, use_corrector=True, disable_corrector=[0]
        )
        sample = mx.ones((1, 1, 1, 2, 2))
        for i in range(5):
            vel = mx.ones_like(sample) * 0.1
            sample = sched.step(vel, sched.timesteps[i], sample)
            mx.eval(sample)
        assert np.isfinite(np.array(sample)).all()

    def test_solver_order_3(self):
        """Order 3 should work without error."""
        sched = self._scheduler(10, shift=5.0, solver_order=3, use_corrector=True)
        sample = mx.random.normal((1, 2, 1, 2, 2))
        for i in range(10):
            vel = mx.random.normal(sample.shape)
            sample = sched.step(vel, sched.timesteps[i], sample)
            mx.eval(sample)
        assert np.isfinite(np.array(sample)).all()

    def test_corrector_rhos_c_not_hardcoded(self):
        """Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5.

        NOTE(review): this recomputes the order-2 corrector coefficients in
        NumPy from the sigma schedule; it does not call the scheduler's own
        corrector, so it pins the expected values rather than the
        implementation.
        """
        # For 50-step schedule with shift=5.0, order 2 corrector at step 5:
        # rhos_c[0] (history) should be ~0.07, NOT 0.5
        # rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
        from mlx_video.models.wan.scheduler import _compute_sigmas

        sigmas = _compute_sigmas(50, shift=5.0)

        def _lambda(sigma):
            # log((1 - sigma) / sigma), with infinite limits at the endpoints.
            if sigma >= 1.0:
                return -math.inf
            if sigma <= 0.0:
                return math.inf
            return math.log(1 - sigma) - math.log(sigma)

        for step_idx in [5, 10, 25, 45]:
            sigma_s0 = sigmas[step_idx - 1]
            sigma_t = sigmas[step_idx]
            lambda_s0 = _lambda(sigma_s0)
            lambda_t = _lambda(sigma_t)
            h = lambda_t - lambda_s0
            hh = -h

            # Relative log-SNR position of the one history point (order 2).
            sigma_sk = sigmas[step_idx - 2]
            lambda_sk = _lambda(sigma_sk)
            rk = (lambda_sk - lambda_s0) / h
            rks = np.array([rk, 1.0])

            h_phi_1 = math.expm1(hh)
            B_h = h_phi_1
            h_phi_k = h_phi_1 / hh - 1.0
            factorial_i = 1

            # Build the 2x2 system R @ rhos_c = b row by row.
            R_rows, b_vals = [], []
            for j in range(1, 3):
                R_rows.append(rks ** (j - 1))
                b_vals.append(h_phi_k * factorial_i / B_h)
                factorial_i *= j + 1
                h_phi_k = h_phi_k / hh - 1.0 / factorial_i

            R = np.stack(R_rows)
            b = np.array(b_vals)
            rhos_c = np.linalg.solve(R, b)

            # History weight should be small (~0.07-0.09), not 0.5
            assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
            assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
            # D1_t weight should be ~0.42-0.45, not 0.5
            assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
class TestSchedulerCoherence:
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
All three schedulers should agree on shared structure (sigma schedules,
first-step behavior) and converge to the same result given perfect
velocity oracles, even though they use different update rules.
"""
@staticmethod
def _make_schedulers(steps=10, shift=5.0):
    """Return Euler, DPM++ and UniPC schedulers keyed by name, all configured
    with the same ``steps`` and ``shift``."""
    from mlx_video.models.wan.scheduler import (
        FlowDPMPP2MScheduler,
        FlowMatchEulerScheduler,
        FlowUniPCScheduler,
    )

    # Insertion order (euler, dpm++, unipc) is relied on by callers that
    # iterate over the returned dict.
    schedulers = dict(
        [
            ("euler", FlowMatchEulerScheduler()),
            ("dpm++", FlowDPMPP2MScheduler()),
            ("unipc", FlowUniPCScheduler()),
        ]
    )
    for scheduler in schedulers.values():
        scheduler.set_timesteps(steps, shift=shift)
    return schedulers
def test_identical_sigma_schedules(self):
    """DPM++ and UniPC must expose the same sigma schedule as Euler."""
    schedulers = self._make_schedulers(20, shift=5.0)
    euler_sigmas = np.array(schedulers["euler"].sigmas)
    for name in ("dpm++", "unipc"):
        candidate = np.array(schedulers[name].sigmas)
        np.testing.assert_allclose(
            candidate,
            euler_sigmas,
            atol=1e-6,
            err_msg=f"{name} sigma schedule differs from Euler",
        )
def test_identical_timesteps(self):
    """DPM++ and UniPC must expose the same timestep sequence as Euler."""
    schedulers = self._make_schedulers(20, shift=5.0)
    euler_timesteps = np.array(schedulers["euler"].timesteps)
    for name in ("dpm++", "unipc"):
        candidate = np.array(schedulers[name].timesteps)
        np.testing.assert_allclose(
            candidate,
            euler_timesteps,
            atol=1e-6,
            err_msg=f"{name} timesteps differ from Euler",
        )
def test_first_step_matches_euler(self):
    """The first step of DPM++ and UniPC must equal Euler's to within 1e-5."""
    mx.random.seed(42)
    shape = (1, 4, 1, 4, 4)
    noise = mx.random.normal(shape)
    vel = mx.random.normal(shape)

    first_steps = {}
    for name, sched in self._make_schedulers(10, shift=5.0).items():
        stepped = sched.step(vel, sched.timesteps[0], noise)
        mx.eval(stepped)
        first_steps[name] = np.array(stepped)

    for key, label in (("dpm++", "DPM++"), ("unipc", "UniPC")):
        np.testing.assert_allclose(
            first_steps[key],
            first_steps["euler"],
            atol=1e-5,
            err_msg=f"{label} step 0 should match Euler",
        )
def test_first_step_matches_across_shifts(self):
    """The first-step agreement with Euler must hold for several shift values."""
    mx.random.seed(99)
    shape = (1, 2, 1, 2, 2)
    noise = mx.random.normal(shape)
    vel = mx.random.normal(shape)

    for shift in (1.0, 5.0, 12.0):
        # Dict order is euler, dpm++, unipc, so steps run in that order.
        stepped = {
            name: sched.step(vel, sched.timesteps[0], noise)
            for name, sched in self._make_schedulers(10, shift=shift).items()
        }
        mx.eval(*stepped.values())

        euler_result = np.array(stepped["euler"])
        for key, label in (("dpm++", "DPM++"), ("unipc", "UniPC")):
            np.testing.assert_allclose(
                np.array(stepped[key]),
                euler_result,
                atol=1e-5,
                err_msg=f"{label} step 0 differs from Euler at shift={shift}",
            )
def test_oracle_all_converge_to_target(self):
    """With the oracle velocity v = x / sigma every solver must end within
    1e-3 of zero (the target)."""
    mx.random.seed(7)
    num_steps = 20
    noise = mx.random.normal((1, 2, 1, 4, 4))

    for name, sched in self._make_schedulers(num_steps, shift=5.0).items():
        latents = noise
        for step_idx in range(num_steps):
            sigma = float(sched.sigmas[step_idx].item())
            # Clamp sigma so the division stays defined.
            oracle_velocity = latents / max(sigma, 1e-8)
            latents = sched.step(oracle_velocity, sched.timesteps[step_idx], latents)
            mx.eval(latents)
        np.testing.assert_allclose(
            np.array(latents),
            0.0,
            atol=1e-3,
            err_msg=f"{name} did not converge to target with oracle",
        )
def test_oracle_higher_order_closer_to_target(self):
    """With few steps and the oracle velocity, the DPM++ and UniPC mean
    absolute error must not exceed 1.5x the Euler error."""
    mx.random.seed(12)
    noise = mx.random.normal((1, 2, 1, 4, 4))
    steps = 5

    errors = {}
    for name, sched in self._make_schedulers(steps, shift=5.0).items():
        latents = noise
        for step_idx in range(steps):
            sigma = float(sched.sigmas[step_idx].item())
            oracle_velocity = latents / max(sigma, 1e-8)
            latents = sched.step(oracle_velocity, sched.timesteps[step_idx], latents)
            mx.eval(latents)
        # The target is zero, so the mean magnitude is the error.
        errors[name] = float(mx.mean(mx.abs(latents)).item())

    # Higher-order solvers should not be significantly worse than Euler
    euler_budget = errors["euler"] * 1.5
    for key, label in (("dpm++", "DPM++"), ("unipc", "UniPC")):
        assert errors[key] <= euler_budget, (
            f"{label} error {errors[key]:.6f} much worse than Euler {errors['euler']:.6f}"
        )
def test_multistep_trajectory_similar_magnitude(self):
    """Over a full loop with velocity = 0.1 * latents, the final mean
    magnitudes of all solvers must stay within a factor of 10 of each other."""
    mx.random.seed(42)
    noise = mx.random.normal((1, 4, 1, 4, 4))
    steps = 20

    final_means = {}
    for name, sched in self._make_schedulers(steps, shift=5.0).items():
        latents = noise
        for step_idx in range(steps):
            latents = sched.step(latents * 0.1, sched.timesteps[step_idx], latents)
            mx.eval(latents)
        final_means[name] = float(mx.mean(mx.abs(latents)).item())

    # All solvers should produce results within the same order of magnitude
    largest = max(final_means.values())
    smallest = min(final_means.values())
    ratio = largest / max(smallest, 1e-10)
    assert ratio < 10.0, (
        f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
    )
def test_intermediate_values_finite(self):
    """No solver may emit NaN or inf at any step of a random-velocity run."""
    num_steps = 15
    shape = (1, 2, 1, 2, 2)
    mx.random.seed(0)
    noise = mx.random.normal(shape)
    for name, solver in self._make_schedulers(num_steps, shift=5.0).items():
        current = noise
        for i in range(num_steps):
            # The velocity is redrawn at every step.
            random_vel = mx.random.normal(shape)
            current = solver.step(random_vel, solver.timesteps[i], current)
            mx.eval(current)
            all_finite = np.isfinite(np.array(current)).all()
            assert all_finite, (
                f"{name} produced non-finite values at step {i}"
            )
def test_lambda_boundary_values(self):
    """``_lambda`` is -inf at sigma=1.0, +inf at sigma=0.0 and 0.0 at sigma=0.5."""
    from mlx_video.models.wan.scheduler import (
        FlowDPMPP2MScheduler,
        FlowUniPCScheduler,
    )

    solver_classes = (FlowDPMPP2MScheduler, FlowUniPCScheduler)
    for cls in solver_classes:
        at_one = cls._lambda(1.0)
        assert at_one == -math.inf, (
            f"{cls.__name__}._lambda(1.0) should be -inf"
        )

        at_zero = cls._lambda(0.0)
        assert at_zero == math.inf, (
            f"{cls.__name__}._lambda(0.0) should be +inf"
        )

        # Interior values should be finite
        lam = cls._lambda(0.5)
        midpoint_ok = math.isfinite(lam) and lam == 0.0
        assert midpoint_ok, (
            f"{cls.__name__}._lambda(0.5) should be 0.0"
        )
def test_lambda_monotonically_decreasing(self):
    """``_lambda`` must strictly decrease over an increasing list of sigmas
    (more noise → lower SNR)."""
    from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler

    sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
    lambdas = list(map(FlowDPMPP2MScheduler._lambda, sigmas))
    # Walk over neighbouring pairs; `i` is kept for the failure message.
    for i, (lam_here, lam_next) in enumerate(zip(lambdas, lambdas[1:])):
        assert lam_here > lam_next, (
            f"_lambda not decreasing: _lambda({sigmas[i]})={lambdas[i]} "
            f"vs _lambda({sigmas[i+1]})={lambdas[i+1]}"
        )
def test_step0_is_ddim_formula(self):
    """First DPM++/UniPC step equals the DDIM update when sigma starts at 1.0.

    Expected update: ``x_next = sigma_next * x + (1 - sigma_next) * x0``
    with ``x0 = x - sigma * v``.  Checked for two (steps, shift) settings.
    """
    mx.random.seed(55)
    shape = (1, 2, 1, 2, 2)
    sample = mx.random.normal(shape)
    vel = mx.random.normal(shape)

    configs = [(10, 5.0), (20, 12.0)]
    for steps, shift in configs:
        solvers = self._make_schedulers(steps, shift=shift)
        euler_sigmas = solvers["euler"].sigmas
        sigma_next = float(euler_sigmas[1].item())
        sigma_cur = float(euler_sigmas[0].item())
        assert abs(sigma_cur - 1.0) < 1e-6, "First sigma should be ~1.0"

        # Closed-form reference for the first update.
        clean = sample - sigma_cur * vel
        expected = sigma_next * sample + (1.0 - sigma_next) * clean
        mx.eval(expected)
        expected_np = np.array(expected)

        for name in ("dpm++", "unipc"):
            solver = solvers[name]
            result = solver.step(vel, solver.timesteps[0], sample)
            mx.eval(result)
            np.testing.assert_allclose(
                np.array(result), expected_np, atol=1e-5,
                err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
            )
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_coherent_across_step_counts(self, steps):
    """For each step count, DPM++ and UniPC must match Euler on step 0."""
    mx.random.seed(77)
    shape = (1, 2, 1, 2, 2)
    noise = mx.random.normal(shape)
    vel = mx.random.normal(shape)

    def first_step(solver):
        # Output of the very first update as a NumPy array.
        stepped = solver.step(vel, solver.timesteps[0], noise)
        mx.eval(stepped)
        return np.array(stepped)

    results = {
        label: first_step(solver)
        for label, solver in self._make_schedulers(steps, shift=5.0).items()
    }

    for label in ("dpm++", "unipc"):
        np.testing.assert_allclose(
            results[label], results["euler"], atol=1e-5,
        )
def test_dpmpp_unipc_agree_on_step1(self):
    """After warmup, DPM++ and UniPC step 1 should be similar
    (both use 2nd-order corrections based on the same model outputs).

    Every solver is advanced through step 0 exactly once; the Euler output
    of that step is the shared starting sample for step 1.
    """
    mx.random.seed(42)
    shape = (1, 4, 1, 4, 4)
    noise = mx.random.normal(shape)
    scheds = self._make_schedulers(10, shift=5.0)

    # Run step 0 with the same velocity on every solver and keep the
    # outputs, so that no scheduler has to be stepped a second time for
    # the same timestep (which would be wrong for a solver that tracks
    # its position in the schedule).
    vel0 = mx.random.normal(shape)
    step0_out = {
        name: sched.step(vel0, sched.timesteps[0], noise)
        for name, sched in scheds.items()
    }
    mx.eval(*step0_out.values())

    # Run step 1 from the same sample with the same velocity
    sample1 = step0_out["euler"]
    vel1 = mx.random.normal(shape)
    r_dpm = scheds["dpm++"].step(vel1, scheds["dpm++"].timesteps[1], sample1)
    r_unipc = scheds["unipc"].step(vel1, scheds["unipc"].timesteps[1], sample1)
    mx.eval(r_dpm, r_unipc)

    # They won't be identical (different correction formulas) but should
    # be in the same ballpark: mean magnitudes within a factor of 2.
    mean_dpm = float(mx.mean(mx.abs(r_dpm)).item())
    mean_unipc = float(mx.mean(mx.abs(r_unipc)).item())
    ratio = max(mean_dpm, mean_unipc) / max(min(mean_dpm, mean_unipc), 1e-10)
    assert ratio < 2.0, (
        f"DPM++ and UniPC step 1 differ too much: "
        f"DPM++={mean_dpm:.4f}, UniPC={mean_unipc:.4f}"
    )
def test_reset_makes_solvers_reproducible(self):
    """Two identical denoising loops separated by ``reset()`` must agree."""
    mx.random.seed(42)
    noise = mx.random.normal((1, 2, 1, 2, 2))
    from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler

    num_steps = 5

    def run_loop(solver):
        # One full pass with velocity 0.1 * latents, returned as NumPy.
        current = noise
        for idx in range(num_steps):
            current = solver.step(current * 0.1, solver.timesteps[idx], current)
            mx.eval(current)
        return np.array(current)

    for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
        sched = cls()
        sched.set_timesteps(num_steps, shift=5.0)

        # First run
        result1 = run_loop(sched)

        # Reset and run again
        sched.reset()
        result2 = run_loop(sched)

        np.testing.assert_allclose(result1, result2, atol=1e-5,
                                   err_msg=f"{cls.__name__} not reproducible after reset()")
# ---------------------------------------------------------------------------
# Wan2.2 VAE Component Tests
# ---------------------------------------------------------------------------
class TestVAE22CausalConv3d:
    """Tests for vae22.CausalConv3d (channels-last)."""

    def test_output_shape_k3(self):
        """kernel_size=3 with padding=1 keeps T/H/W and only changes channels."""
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
        x = mx.random.normal((1, 4, 8, 8, 8))  # [B, T, H, W, C]
        out = conv(x)
        mx.eval(out)
        assert out.shape == (1, 4, 8, 8, 16)

    def test_output_shape_k1(self):
        """kernel_size=1 (no padding) keeps T/H/W and only changes channels."""
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(8, 16, kernel_size=1)
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = conv(x)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 16)

    def test_temporal_causal(self):
        """Output at t=0 should not depend on t>0.

        Every frame after t=0 is perturbed, including the adjacent frame
        t=1.  With kernel_size=3 a symmetrically padded (non-causal)
        convolution reads t=1 when producing t=0, so leaving t=1 untouched
        would let such an implementation pass unnoticed.
        """
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.1
        conv.bias = mx.zeros(conv.bias.shape)

        x = mx.zeros((1, 4, 4, 4, 2))
        out_zero = conv(x)
        mx.eval(out_zero)

        # Modify t=1..3; output at t=0 should be unchanged
        x_mod = mx.concatenate([
            x[:, :1],
            mx.ones((1, 3, 4, 4, 2)),
        ], axis=1)
        out_mod = conv(x_mod)
        mx.eval(out_mod)

        # Sanity check: the perturbation must be visible in later frames,
        # otherwise the causality assertion below would be meaningless.
        assert not np.allclose(
            np.array(out_zero[0, 1:]), np.array(out_mod[0, 1:]), atol=1e-5
        ), "Perturbing t>0 did not change any later output frame"

        t0_ref = np.array(out_zero[0, 0])
        t0_mod = np.array(out_mod[0, 0])
        np.testing.assert_allclose(t0_ref, t0_mod, atol=1e-5)

    def test_channels_last_format(self):
        """Verify input/output are channels-last [B, T, H, W, C]."""
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
        x = mx.random.normal((2, 3, 6, 6, 4))
        out = conv(x)
        mx.eval(out)
        assert out.shape[-1] == 8  # last dim = out_channels
class TestRMSNorm:
    """Checks for vae22.RMS_norm, which acts as a scaled L2 normalization."""

    def test_output_shape(self):
        """Normalization must leave the tensor shape untouched."""
        from mlx_video.models.wan.vae22 import RMS_norm

        layer = RMS_norm(16)
        inp = mx.random.normal((2, 4, 4, 4, 16))
        result = layer(inp)
        mx.eval(result)
        assert result.shape == inp.shape

    def test_l2_normalization(self):
        """With unit gamma the output vector has L2 norm sqrt(dim)."""
        from mlx_video.models.wan.vae22 import RMS_norm

        dim = 32
        layer = RMS_norm(dim)
        # Large-magnitude input so a missing normalization would be obvious.
        inp = mx.random.normal((1, 1, 1, 1, dim)) * 5.0
        result = layer(inp)
        mx.eval(result)

        # After L2 norm * scale(=sqrt(dim)) * gamma(=1): ||out|| = sqrt(dim)
        measured = np.linalg.norm(np.array(result).flatten())
        np.testing.assert_allclose(measured, math.sqrt(dim), rtol=1e-3)

    def test_scale_invariant(self):
        """Multiplying the input by a constant must not change the output."""
        from mlx_video.models.wan.vae22 import RMS_norm

        layer = RMS_norm(8)
        inp = mx.random.normal((1, 1, 1, 1, 8))
        base = layer(inp)
        scaled = layer(inp * 10.0)
        mx.eval(base, scaled)
        np.testing.assert_allclose(np.array(base), np.array(scaled), atol=1e-4)

    def test_gamma_effect(self):
        """A gamma of 2 doubles every output component."""
        from mlx_video.models.wan.vae22 import RMS_norm

        layer = RMS_norm(4)
        layer.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
        result = layer(mx.ones((1, 1, 1, 1, 4)))
        mx.eval(result)

        # With gamma=2, each component is 2 * sqrt(4) * x/||x|| = 2 * 2 * 1/2 = 2
        flat = np.array(result).flatten()
        np.testing.assert_allclose(flat, 2.0, atol=1e-4)
class TestDupUp3D:
    """Shape checks for vae22.DupUp3D spatial/temporal upsampling."""

    def test_spatial_only(self):
        """factor_t=1, factor_s=2: H and W double, T is kept, 8 -> 4 channels."""
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(8, 4, factor_t=1, factor_s=2)
        result = layer(mx.random.normal((1, 3, 4, 4, 8)))
        mx.eval(result)
        assert result.shape == (1, 3, 8, 8, 4)

    def test_temporal_and_spatial(self):
        """factor_t=2, factor_s=2: T, H and W all double, 16 -> 8 channels."""
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(16, 8, factor_t=2, factor_s=2)
        result = layer(mx.random.normal((1, 3, 4, 4, 16)))
        mx.eval(result)
        assert result.shape == (1, 6, 8, 8, 8)

    def test_first_chunk_trims(self):
        """first_chunk=True yields one temporal frame fewer when factor_t=2."""
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(8, 4, factor_t=2, factor_s=2)
        inp = mx.random.normal((1, 3, 4, 4, 8))
        full = layer(inp, first_chunk=False)
        trimmed = layer(inp, first_chunk=True)
        mx.eval(full, trimmed)

        # first_chunk removes factor_t-1=1 temporal frame
        frames_full, frames_trimmed = full.shape[1], trimmed.shape[1]
        assert frames_full == 6
        assert frames_trimmed == 5

    def test_no_temporal_first_chunk_noop(self):
        """With factor_t=1, first_chunk does not change the output shape."""
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(8, 4, factor_t=1, factor_s=2)
        inp = mx.random.normal((1, 3, 4, 4, 8))
        full = layer(inp, first_chunk=False)
        trimmed = layer(inp, first_chunk=True)
        mx.eval(full, trimmed)

        # factor_t=1, so first_chunk removes 0 frames
        assert full.shape == trimmed.shape
class TestVAE22Resample:
    """Tests for vae22.Resample (spatial/temporal upsampling)."""

    def test_upsample2d_shape(self):
        """upsample2d doubles H and W and keeps T."""
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample2d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = r(x)
        mx.eval(out)
        assert out.shape == (1, 2, 8, 8, 8)  # 2x spatial, same temporal

    def test_upsample3d_shape(self):
        """upsample3d doubles T, H and W."""
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample3d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = r(x)
        mx.eval(out)
        assert out.shape == (1, 4, 8, 8, 8)  # 2x spatial + 2x temporal

    def test_upsample3d_first_chunk(self):
        """first_chunk=True on T frames yields 2T-1 output frames."""
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample3d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = r(x, first_chunk=True)
        mx.eval(out)
        # first_chunk: 1 (bypass) + 2*(T-1) (interleaved) = 2T-1 = 3
        assert out.shape == (1, 3, 8, 8, 8)

    def test_upsample3d_first_chunk_single_frame(self):
        """Single-frame input with first_chunk is handled like a regular
        chunk: the one frame is temporally upsampled to two frames."""
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample3d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 1, 4, 4, 8))
        out = r(x, first_chunk=True)
        mx.eval(out)
        # Single frame with first_chunk: falls through to non-first path
        # time_conv on 1 frame → 2 interleaved
        assert out.shape == (1, 2, 8, 8, 8)

    def test_upsample3d_first_frame_bypasses_time_conv(self):
        """First frame of first_chunk should NOT go through time_conv.

        Official Wan2.2 skips time_conv for the very first frame entirely.
        We verify this by checking that the first output frame equals the
        spatial-only upsample of the first input frame, even though
        time_conv carries large weights.  The spatial convolution must have
        non-zero weights for this comparison to mean anything: with an
        all-zero spatial conv every output frame is zero and the check
        passes regardless of what happened before it.
        """
        from mlx_video.models.wan.vae22 import Resample

        C = 8
        r = Resample(C, "upsample3d")
        # Set time_conv weights to large values so its effect is detectable
        r.time_conv.weight = mx.ones(r.time_conv.weight.shape) * 10.0
        r.time_conv.bias = mx.zeros(r.time_conv.bias.shape)
        # Non-zero spatial conv so that differences upstream stay visible
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.1
        r.resample_bias = mx.zeros(r.resample_bias.shape)

        x = mx.random.normal((1, 3, 2, 2, C))
        out = r(x, first_chunk=True)
        mx.eval(out)
        # Output: 5 frames (1 bypass + 4 interleaved from 2 remaining)
        assert out.shape[1] == 5

        # First frame should be spatial upsample of x[:, 0:1] only.
        # Run just the first frame through spatial upsample for reference
        first_only = x[:, 0:1]
        ref = r._upsample2x(first_only.reshape(1, 2, 2, C))
        ref = mx.pad(ref, [(0, 0), (1, 1), (1, 1), (0, 0)])
        ref = mx.conv_general(ref, r.resample_weight) + r.resample_bias
        mx.eval(ref)

        # Guard against a vacuous 0 == 0 comparison
        assert float(mx.max(mx.abs(ref)).item()) > 0.0, \
            "Reference frame is all zeros; the bypass check would be vacuous"

        # Compare first output frame to reference
        first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
        mx.eval(first_out)
        assert mx.allclose(first_out, ref, atol=1e-5).item(), \
            "First frame should bypass time_conv and match spatial-only upsample"
class TestVAE22ResidualBlock:
    """Shape and shortcut checks for vae22.ResidualBlock."""

    def test_same_dim(self):
        """Equal in/out channels: output shape equals input shape."""
        from mlx_video.models.wan.vae22 import ResidualBlock

        layer = ResidualBlock(8, 8)
        result = layer(mx.random.normal((1, 2, 4, 4, 8)))
        mx.eval(result)
        assert result.shape == (1, 2, 4, 4, 8)

    def test_different_dim(self):
        """8 -> 16 channels: only the channel axis changes."""
        from mlx_video.models.wan.vae22 import ResidualBlock

        layer = ResidualBlock(8, 16)
        result = layer(mx.random.normal((1, 2, 4, 4, 8)))
        mx.eval(result)
        assert result.shape == (1, 2, 4, 4, 16)

    def test_shortcut_when_dims_differ(self):
        """A shortcut is created when in/out channels differ."""
        from mlx_video.models.wan.vae22 import ResidualBlock

        layer = ResidualBlock(8, 16)
        assert layer.shortcut is not None

    def test_no_shortcut_same_dim(self):
        """No shortcut is created when in/out channels match."""
        from mlx_video.models.wan.vae22 import ResidualBlock

        layer = ResidualBlock(8, 8)
        assert layer.shortcut is None
class TestResidualBlockLayers:
    """Tests for vae22.ResidualBlockLayers naming convention."""

    def test_layer_names_no_underscore_prefix(self):
        """Layer names must NOT start with underscore (MLX ignores them).

        Because underscore-prefixed attributes are left out of
        ``parameters()``, scanning the returned keys alone cannot detect a
        badly named layer.  The test therefore also requires the expected
        layers to show up in ``parameters()``.
        """
        from mlx_video.models.wan.vae22 import ResidualBlockLayers

        block = ResidualBlockLayers(8, 8)
        params = dict(block.parameters())

        # The norm and conv layers must be registered under public names.
        for name in ("layer_0", "layer_2", "layer_3", "layer_6"):
            assert name in params, f"Layer {name} missing from parameters()"

        # All param keys should use layer_N, not _layer_N
        for key in params:
            assert not key.startswith("_"), f"Parameter {key} starts with underscore"

    def test_has_expected_layers(self):
        """The block exposes its norm and conv layers as layer_N attributes."""
        from mlx_video.models.wan.vae22 import ResidualBlockLayers

        block = ResidualBlockLayers(8, 16)
        assert hasattr(block, "layer_0")  # first RMS_norm
        assert hasattr(block, "layer_2")  # first CausalConv3d
        assert hasattr(block, "layer_3")  # second RMS_norm
        assert hasattr(block, "layer_6")  # second CausalConv3d

    def test_forward_shape(self):
        """8 -> 16 channels: only the channel axis changes."""
        from mlx_video.models.wan.vae22 import ResidualBlockLayers

        block = ResidualBlockLayers(8, 16)
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = block(x)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 16)
class TestVAE22AttentionBlock:
    """Checks for vae22.AttentionBlock (per-frame 2D self-attention)."""

    def test_output_shape(self):
        """The block returns a tensor of the same shape as its input."""
        from mlx_video.models.wan.vae22 import AttentionBlock

        attn = AttentionBlock(16)
        attn.to_qkv_weight = mx.random.normal(attn.to_qkv_weight.shape) * 0.01
        attn.proj_weight = mx.random.normal(attn.proj_weight.shape) * 0.01
        result = attn(mx.random.normal((1, 2, 4, 4, 16)))
        mx.eval(result)
        assert result.shape == (1, 2, 4, 4, 16)

    def test_residual_connection(self):
        """With zeroed qkv/proj weights the block returns its input."""
        from mlx_video.models.wan.vae22 import AttentionBlock

        attn = AttentionBlock(8)
        attn.to_qkv_weight = mx.zeros(attn.to_qkv_weight.shape)
        attn.proj_weight = mx.zeros(attn.proj_weight.shape)
        inp = mx.ones((1, 1, 2, 2, 8))
        result = attn(inp)
        mx.eval(result)

        # With zero weights, attention output is 0 → residual is identity
        np.testing.assert_allclose(np.array(result), np.array(inp), atol=1e-5)
class TestHead22:
    """Tests for vae22.Head22 output head."""

    def test_output_shape(self):
        """16 -> 12 channels: only the channel axis changes."""
        from mlx_video.models.wan.vae22 import Head22

        head = Head22(16, out_channels=12)
        x = mx.random.normal((1, 2, 4, 4, 16))
        out = head(x)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 12)

    def test_layer_names_no_underscore(self):
        """Head layers must not use underscore prefix.

        Underscore-prefixed attributes are left out of ``parameters()``, so
        the key scan alone cannot fail; the expected layers are therefore
        also required to be present in ``parameters()``.
        """
        from mlx_video.models.wan.vae22 import Head22

        head = Head22(8)
        assert hasattr(head, "layer_0")  # RMS_norm
        assert hasattr(head, "layer_2")  # CausalConv3d

        params = dict(head.parameters())
        for name in ("layer_0", "layer_2"):
            assert name in params, f"Head layer {name} missing from parameters()"
        for key in params:
            assert not key.startswith("_"), f"Head param {key} starts with underscore"
class TestUnpatchify:
    """Checks for the vae22._unpatchify rearrangement."""

    def test_basic_shape(self):
        """patch_size=2 turns 12 channels into 3 and doubles H and W."""
        from mlx_video.models.wan.vae22 import _unpatchify

        patched = mx.random.normal((1, 2, 4, 4, 12))  # 12 = 3 * 2 * 2
        result = _unpatchify(patched, patch_size=2)
        mx.eval(result)
        assert result.shape == (1, 2, 8, 8, 3)

    def test_patch_size_1_noop(self):
        """patch_size=1 returns the input values unchanged."""
        from mlx_video.models.wan.vae22 import _unpatchify

        patched = mx.random.normal((1, 2, 4, 4, 3))
        result = _unpatchify(patched, patch_size=1)
        mx.eval(result)
        np.testing.assert_array_equal(np.array(result), np.array(patched))

    def test_preserves_content(self):
        """Every input value must appear in the output: nothing lost or added."""
        from mlx_video.models.wan.vae22 import _unpatchify

        patched = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
        result = _unpatchify(patched, patch_size=2)
        mx.eval(result)

        # All elements should be preserved
        result_np = np.array(result)
        assert result_np.size == 48
        seen_values = set(result_np.flatten().tolist())
        assert seen_values == set(range(48))
class TestDenormalizeLatents:
    """Tests for vae22.denormalize_latents."""

    def test_output_shape(self):
        """Denormalization keeps the latent shape."""
        from mlx_video.models.wan.vae22 import denormalize_latents

        z = mx.random.normal((1, 2, 4, 4, 48))
        out = denormalize_latents(z)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 48)

    def test_custom_mean_std(self):
        """Explicit mean/std are applied as ``z * std + mean``."""
        from mlx_video.models.wan.vae22 import denormalize_latents

        z = mx.ones((1, 1, 1, 1, 4))
        mean = mx.array([1.0, 2.0, 3.0, 4.0])
        std = mx.array([0.5, 0.5, 0.5, 0.5])
        out = denormalize_latents(z, mean=mean, std=std)
        mx.eval(out)
        # z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
        np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)

    def test_uses_default_constants(self):
        """Without explicit arguments, VAE22_MEAN and VAE22_STD are used.

        z=0 isolates the mean (the std term vanishes); z=1 additionally
        exercises the default std, which a zero input cannot detect.
        """
        from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents

        mean_np = np.array(VAE22_MEAN).flatten()
        std_np = np.array(VAE22_STD).flatten()

        # Should not raise with default constants
        z = mx.zeros((1, 1, 1, 1, 48))
        out = denormalize_latents(z)
        mx.eval(out)
        # z=0 → result = 0 * std + mean = mean
        np.testing.assert_allclose(
            np.array(out).flatten(),
            mean_np,
            atol=1e-5,
        )

        # z=1 → result = 1 * std + mean
        out_ones = denormalize_latents(mx.ones((1, 1, 1, 1, 48)))
        mx.eval(out_ones)
        np.testing.assert_allclose(
            np.array(out_ones).flatten(),
            mean_np + std_np,
            atol=1e-5,
        )
class TestVAE22NormConstants:
    """Sanity checks on the VAE22_MEAN and VAE22_STD constants."""

    def test_dimensions(self):
        """Both constants are 1-D vectors with 48 entries."""
        from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD

        mx.eval(VAE22_MEAN, VAE22_STD)
        expected_shape = (48,)
        assert VAE22_MEAN.shape == expected_shape
        assert VAE22_STD.shape == expected_shape

    def test_std_positive(self):
        """Every standard deviation entry is strictly positive."""
        from mlx_video.models.wan.vae22 import VAE22_STD

        mx.eval(VAE22_STD)
        std_np = np.array(VAE22_STD)
        assert (std_np > 0).all()
class TestWan22VAEDecoder:
    """End-to-end checks for Wan22VAEDecoder using a tiny configuration."""

    def test_output_shape_small(self):
        """A tiny decoder returns a 5-D, 3-channel tensor with values in [-1, 1]."""
        from mlx_video.models.wan.vae22 import Wan22VAEDecoder

        # Use very small dims to keep test fast
        decoder = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
        # Latent: [B=1, T=3, H=2, W=2, C=4]
        # Expected: temporal 3→5→9→9→9 (two temporal upsamples), spatial 2→4→8→16
        latent = mx.random.normal((1, 3, 2, 2, 4)) * 0.1
        decoded = decoder(latent)
        mx.eval(decoded)

        # Output should have 3 RGB channels and be clipped to [-1, 1]
        assert decoded.shape[-1] == 3
        assert decoded.ndim == 5
        lowest = np.array(decoded).min()
        highest = np.array(decoded).max()
        assert lowest >= -1.0
        assert highest <= 1.0

    def test_output_clipped(self):
        """Even for large latents the output stays within [-1, 1]."""
        from mlx_video.models.wan.vae22 import Wan22VAEDecoder

        decoder = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
        latent = mx.random.normal((1, 2, 2, 2, 4)) * 10.0  # large values
        decoded = decoder(latent)
        mx.eval(decoded)

        lowest = np.array(decoded).min()
        highest = np.array(decoded).max()
        assert lowest >= -1.0 - 1e-6
        assert highest <= 1.0 + 1e-6
class TestSanitizeWan22VAEWeights:
    """Key-remapping and layout checks for vae22.sanitize_wan22_vae_weights."""

    def test_skip_encoder(self):
        """Only the decoder.* key survives; encoder.* and bare conv1.* are dropped."""
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        raw = {
            "encoder.layer.weight": mx.zeros((4,)),
            "conv1.weight": mx.zeros((4,)),
            "decoder.conv1.bias": mx.zeros((4,)),
        }
        cleaned = sanitize_wan22_vae_weights(raw)
        for dropped in ("encoder.layer.weight", "conv1.weight"):
            assert dropped not in cleaned
        assert "decoder.conv1.bias" in cleaned

    def test_sequential_index_remapping(self):
        """Numeric sequential indices become layer_N in residual and head keys."""
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        raw = {
            "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
            "decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
            "decoder.head.0.gamma": mx.ones((4,)),
            "decoder.head.2.bias": mx.zeros((12,)),
        }
        cleaned = sanitize_wan22_vae_weights(raw)
        expected_keys = (
            "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma",
            "decoder.upsamples.0.upsamples.0.residual.layer_6.bias",
            "decoder.head.layer_0.gamma",
            "decoder.head.layer_2.bias",
        )
        for key in expected_keys:
            assert key in cleaned

    def test_resample_conv_remapping(self):
        """resample.1.{weight,bias} become resample_weight / resample_bias."""
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        raw = {
            "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
            "decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
        }
        cleaned = sanitize_wan22_vae_weights(raw)
        assert "decoder.upsamples.1.upsamples.3.resample_weight" in cleaned
        assert "decoder.upsamples.1.upsamples.3.resample_bias" in cleaned

    def test_attention_remapping(self):
        """to_qkv.* / proj.* become to_qkv_* / proj_* attribute names."""
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        raw = {
            "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
            "decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
            "decoder.middle.1.proj.weight": mx.zeros((8, 8, 1, 1)),
            "decoder.middle.1.proj.bias": mx.zeros((8,)),
        }
        cleaned = sanitize_wan22_vae_weights(raw)
        expected_keys = (
            "decoder.middle.1.to_qkv_weight",
            "decoder.middle.1.to_qkv_bias",
            "decoder.middle.1.proj_weight",
            "decoder.middle.1.proj_bias",
        )
        for key in expected_keys:
            assert key in cleaned

    def test_conv3d_transpose(self):
        """5-D conv weights are moved to channels-last layout."""
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        # Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
        raw = {"decoder.conv1.weight": mx.zeros((16, 8, 3, 3, 3))}
        cleaned = sanitize_wan22_vae_weights(raw)
        assert cleaned["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)

    def test_conv2d_transpose(self):
        """4-D conv weights are moved to channels-last layout."""
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        # Conv2d weight: [O, I, H, W] → [O, H, W, I]
        raw = {
            "decoder.upsamples.0.upsamples.2.resample.1.weight": mx.zeros((8, 8, 3, 3)),
        }
        cleaned = sanitize_wan22_vae_weights(raw)
        remapped = cleaned["decoder.upsamples.0.upsamples.2.resample_weight"]
        assert remapped.shape == (8, 3, 3, 8)

    def test_gamma_squeeze(self):
        """gamma tensors lose their trailing singleton axes."""
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        # gamma: (dim, 1, 1, 1) → (dim,)
        raw = {
            "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((16, 1, 1, 1)),
        }
        cleaned = sanitize_wan22_vae_weights(raw)
        squeezed = cleaned["decoder.upsamples.0.upsamples.0.residual.layer_0.gamma"]
        assert squeezed.shape == (16,)
class TestUpResidualBlock:
    """Shape checks for vae22.Up_ResidualBlock."""

    def test_no_upsample(self):
        """up_flag=False leaves the shape unchanged."""
        from mlx_video.models.wan.vae22 import Up_ResidualBlock

        stage = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
        result = stage(mx.random.normal((1, 2, 4, 4, 8)))
        mx.eval(result)
        # No upsample: same shape
        assert result.shape == (1, 2, 4, 4, 8)

    def test_spatial_upsample(self):
        """up_flag=True without temporal upsampling doubles H and W only."""
        from mlx_video.models.wan.vae22 import Up_ResidualBlock

        stage = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
        result = stage(mx.random.normal((1, 2, 4, 4, 8)))
        mx.eval(result)
        # 2x spatial upsample, no temporal
        assert result.shape == (1, 2, 8, 8, 4)

    def test_spatial_temporal_upsample(self):
        """up_flag=True with temporal upsampling doubles T, H and W."""
        from mlx_video.models.wan.vae22 import Up_ResidualBlock

        stage = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
        result = stage(mx.random.normal((1, 2, 4, 4, 8)))
        mx.eval(result)
        # 2x spatial + 2x temporal
        assert result.shape == (1, 4, 8, 8, 4)
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # ends with a non-zero exit code when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))