Merge branch 'main' into pc/unify-apis
This commit is contained in:
4
tests/conftest.py
Normal file
4
tests/conftest.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
372
tests/test_wan_attention.py
Normal file
372
tests/test_wan_attention.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""Tests for Wan attention components and RoPE."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRoPE:
|
||||
"""Tests for 3-way factorized RoPE."""
|
||||
|
||||
def test_rope_params_shape(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
freqs = rope_params(1024, 64)
|
||||
mx.eval(freqs)
|
||||
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
||||
|
||||
def test_rope_params_different_dims(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
for dim in [32, 64, 128]:
|
||||
freqs = rope_params(512, dim)
|
||||
mx.eval(freqs)
|
||||
assert freqs.shape == (512, dim // 2, 2)
|
||||
|
||||
def test_rope_params_cos_sin_range(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
freqs = rope_params(256, 64)
|
||||
mx.eval(freqs)
|
||||
cos_vals = np.array(freqs[:, :, 0])
|
||||
sin_vals = np.array(freqs[:, :, 1])
|
||||
assert np.all(cos_vals >= -1.0) and np.all(cos_vals <= 1.0)
|
||||
assert np.all(sin_vals >= -1.0) and np.all(sin_vals <= 1.0)
|
||||
|
||||
def test_rope_params_position_zero(self):
|
||||
"""At position 0, cos should be 1 and sin should be 0."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
freqs = rope_params(10, 64)
|
||||
mx.eval(freqs)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
||||
|
||||
def test_rope_apply_output_shape(self):
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
grid_sizes = [(2, 3, 4)] # F*H*W = 24 = L
|
||||
out = rope_apply(x, grid_sizes, freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, N, D)
|
||||
|
||||
def test_rope_apply_preserves_norm(self):
|
||||
"""RoPE rotation should preserve vector norms."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 3, 4
|
||||
L = F * H * W
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
|
||||
out = rope_apply(x, [(F, H, W)], freqs)
|
||||
mx.eval(x, out)
|
||||
|
||||
x_np = np.array(x[0])
|
||||
out_np = np.array(out[0])
|
||||
for i in range(L):
|
||||
for h in range(N):
|
||||
norm_in = np.linalg.norm(x_np[i, h])
|
||||
norm_out = np.linalg.norm(out_np[i, h])
|
||||
np.testing.assert_allclose(norm_in, norm_out, rtol=1e-4)
|
||||
|
||||
def test_rope_apply_with_padding(self):
|
||||
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 2, 2
|
||||
seq_len = F * H * W # 8
|
||||
pad = 4
|
||||
L = seq_len + pad
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
|
||||
out = rope_apply(x, [(F, H, W)], freqs)
|
||||
mx.eval(x, out)
|
||||
# Padded tokens should be unchanged
|
||||
np.testing.assert_allclose(
|
||||
np.array(out[0, seq_len:]),
|
||||
np.array(x[0, seq_len:]),
|
||||
atol=1e-6,
|
||||
)
|
||||
|
||||
def test_rope_apply_batch(self):
|
||||
"""Test with batch_size > 1 and different grid sizes."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
B, N, D = 2, 2, 16
|
||||
grids = [(2, 3, 4), (2, 3, 4)]
|
||||
L = 2 * 3 * 4
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
|
||||
out = rope_apply(x, grids, freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, N, D)
|
||||
|
||||
def test_rope_frequency_split(self):
|
||||
"""Verify the 3-way frequency dimension split matches Wan2.2 convention."""
|
||||
D = 128 # head_dim for 14B model
|
||||
half_d = D // 2
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
assert d_t + d_h + d_w == half_d
|
||||
# Temporal gets more capacity
|
||||
assert d_t >= d_h
|
||||
assert d_t >= d_w
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attention Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWanRMSNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_zero_mean_variance(self):
|
||||
"""RMS norm should make RMS ≈ 1 before scaling."""
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((1, 5, 64)) * 10.0
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0])
|
||||
for i in range(5):
|
||||
rms = np.sqrt(np.mean(out_np[i] ** 2))
|
||||
# After RMS norm with weight=1, RMS should be ~1
|
||||
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
|
||||
|
||||
def test_dtype_preservation(self):
|
||||
"""RMSNorm weight is float32, so output is promoted to float32."""
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
norm = WanRMSNorm(32)
|
||||
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# Weight is float32, so multiplication promotes result to float32
|
||||
assert out.dtype == mx.float32
|
||||
|
||||
|
||||
class TestWanLayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
norm = WanLayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_without_affine(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
norm = WanLayerNorm(64, elementwise_affine=False)
|
||||
x = mx.random.normal((1, 4, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# Mean should be ~0, variance should be ~1
|
||||
out_np = np.array(out[0])
|
||||
for i in range(4):
|
||||
np.testing.assert_allclose(np.mean(out_np[i]), 0.0, atol=0.05)
|
||||
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
|
||||
|
||||
def test_with_affine(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
norm = WanLayerNorm(32, elementwise_affine=True)
|
||||
assert hasattr(norm, "weight")
|
||||
assert hasattr(norm, "bias")
|
||||
x = mx.random.normal((1, 4, 32))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 32)
|
||||
|
||||
|
||||
class TestWanSelfAttention:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
out = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, self.dim)
|
||||
|
||||
def test_with_qk_norm(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
||||
assert attn.norm_q is not None
|
||||
assert attn.norm_k is not None
|
||||
|
||||
def test_without_qk_norm(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
assert attn.norm_q is None
|
||||
assert attn.norm_k is None
|
||||
|
||||
def test_masking(self):
|
||||
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
# Full sequence
|
||||
out_full = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||
# Shorter sequence (mask last 4 tokens)
|
||||
out_masked = attn(x, seq_lens=[L - 4], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||
mx.eval(out_full, out_masked)
|
||||
|
||||
# Outputs should differ when masking is applied
|
||||
assert not np.allclose(np.array(out_full), np.array(out_masked), atol=1e-5)
|
||||
|
||||
|
||||
class TestWanCrossAttention:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 24, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
context = mx.random.normal((B, L_kv, self.dim))
|
||||
out = attn(x, context)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L_q, self.dim)
|
||||
|
||||
def test_with_context_mask(self):
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 12, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
context = mx.random.normal((B, L_kv, self.dim))
|
||||
out = attn(x, context, context_lens=[10])
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L_q, self.dim)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bfloat16 Autocast Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBFloat16Autocast:
|
||||
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
|
||||
for efficient matmul, matching official PyTorch autocast behavior."""
|
||||
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.num_heads = 4
|
||||
|
||||
@staticmethod
|
||||
def _to_bf16(params):
|
||||
"""Recursively cast all arrays in params to bfloat16."""
|
||||
if isinstance(params, dict):
|
||||
return {k: TestBFloat16Autocast._to_bf16(v) for k, v in params.items()}
|
||||
elif isinstance(params, list):
|
||||
return [TestBFloat16Autocast._to_bf16(v) for v in params]
|
||||
elif isinstance(params, mx.array):
|
||||
return params.astype(mx.bfloat16)
|
||||
return params
|
||||
|
||||
def test_self_attn_casts_to_weight_dtype(self):
|
||||
"""Self-attention should cast input to weight dtype for QKV projections."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, self.dim)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_cross_attn_casts_to_weight_dtype(self):
|
||||
"""Cross-attention should cast input to weight dtype."""
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
ctx = mx.random.normal((1, 4, self.dim))
|
||||
out = attn(x, ctx)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, self.dim)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
||||
"""prepare_kv should cast context to weight dtype."""
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
ctx = mx.random.normal((1, 4, self.dim))
|
||||
k, v = attn.prepare_kv(ctx)
|
||||
mx.eval(k, v)
|
||||
assert k.dtype == mx.bfloat16
|
||||
assert v.dtype == mx.bfloat16
|
||||
|
||||
def test_ffn_casts_to_weight_dtype(self):
|
||||
"""FFN should cast input to weight dtype for linear layers."""
|
||||
from mlx_video.models.wan.transformer import WanFFN
|
||||
ffn = WanFFN(self.dim, 128)
|
||||
ffn.update(self._to_bf16(ffn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
out = ffn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, self.dim)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_self_attn_rope_in_float32(self):
|
||||
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
assert freqs.dtype == mx.float32
|
||||
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
|
||||
mx.eval(out)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_block_float32_residual_with_bf16_weights(self):
|
||||
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
||||
block.update(self._to_bf16(block.parameters()))
|
||||
|
||||
B, L = 1, 8
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
e = mx.random.normal((B, L, 6, self.dim))
|
||||
ctx = mx.random.normal((B, 4, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
|
||||
mx.eval(out)
|
||||
assert out.dtype == mx.float32
|
||||
assert np.isfinite(np.array(out)).all()
|
||||
125
tests/test_wan_config.py
Normal file
125
tests/test_wan_config.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Tests for Wan model configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWanModelConfig:
|
||||
"""Tests for WanModelConfig dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig()
|
||||
assert config.dim == 5120
|
||||
assert config.ffn_dim == 13824
|
||||
assert config.num_heads == 40
|
||||
assert config.num_layers == 40
|
||||
assert config.in_dim == 16
|
||||
assert config.out_dim == 16
|
||||
assert config.patch_size == (1, 2, 2)
|
||||
assert config.vae_stride == (4, 8, 8)
|
||||
assert config.vae_z_dim == 16
|
||||
assert config.boundary == 0.875
|
||||
assert config.sample_shift == 12.0
|
||||
assert config.sample_steps == 40
|
||||
assert config.sample_guide_scale == (3.0, 4.0)
|
||||
assert config.num_train_timesteps == 1000
|
||||
assert config.qk_norm is True
|
||||
assert config.cross_attn_norm is True
|
||||
assert config.text_len == 512
|
||||
|
||||
def test_head_dim_property(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig()
|
||||
assert config.head_dim == 128 # 5120 // 40
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig()
|
||||
d = config.to_dict()
|
||||
assert isinstance(d, dict)
|
||||
assert d["dim"] == 5120
|
||||
assert d["patch_size"] == (1, 2, 2)
|
||||
assert d["boundary"] == 0.875
|
||||
|
||||
def test_t5_config_values(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig()
|
||||
assert config.t5_vocab_size == 256384
|
||||
assert config.t5_dim == 4096
|
||||
assert config.t5_dim_attn == 4096
|
||||
assert config.t5_dim_ffn == 10240
|
||||
assert config.t5_num_heads == 64
|
||||
assert config.t5_num_layers == 24
|
||||
assert config.t5_num_buckets == 32
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWan21Config:
|
||||
"""Tests for Wan2.1 config presets."""
|
||||
|
||||
def test_wan21_14b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
assert config.dim == 5120
|
||||
assert config.ffn_dim == 13824
|
||||
assert config.num_heads == 40
|
||||
assert config.num_layers == 40
|
||||
assert config.head_dim == 128
|
||||
assert config.sample_guide_scale == 5.0
|
||||
assert config.sample_shift == 5.0
|
||||
assert config.sample_steps == 50
|
||||
assert config.boundary == 0.0
|
||||
|
||||
def test_wan21_1_3b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
assert config.dim == 1536
|
||||
assert config.ffn_dim == 8960
|
||||
assert config.num_heads == 12
|
||||
assert config.num_layers == 30
|
||||
assert config.head_dim == 128 # 1536 // 12
|
||||
assert config.sample_guide_scale == 5.0
|
||||
|
||||
def test_wan22_14b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
assert config.dim == 5120
|
||||
assert config.sample_guide_scale == (3.0, 4.0)
|
||||
assert config.sample_shift == 12.0
|
||||
assert config.sample_steps == 40
|
||||
assert config.boundary == 0.875
|
||||
|
||||
def test_wan21_config_to_dict(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
assert d["model_version"] == "2.1"
|
||||
assert d["dual_model"] is False
|
||||
assert d["sample_guide_scale"] == 5.0
|
||||
|
||||
def test_wan21_1_3b_config_to_dict(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
d = config.to_dict()
|
||||
assert d["dim"] == 1536
|
||||
assert d["num_layers"] == 30
|
||||
|
||||
def test_default_config_is_wan22(self):
|
||||
"""Default WanModelConfig() should be Wan2.2 14B."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
307
tests/test_wan_convert.py
Normal file
307
tests/test_wan_convert.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""Tests for Wan weight conversion utilities."""
|
||||
|
||||
import logging
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transformer Weight Conversion Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSanitizeTransformerWeights:
|
||||
def test_patch_embedding_reshape(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "patch_embedding_proj.weight" in out
|
||||
assert "patch_embedding_proj.bias" in out
|
||||
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
|
||||
|
||||
def test_text_embedding_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"text_embedding.0.bias": mx.zeros((64,)),
|
||||
"text_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"text_embedding.2.bias": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "text_embedding_0.weight" in out
|
||||
assert "text_embedding_0.bias" in out
|
||||
assert "text_embedding_1.weight" in out
|
||||
assert "text_embedding_1.bias" in out
|
||||
|
||||
def test_time_embedding_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "time_embedding_0.weight" in out
|
||||
assert "time_embedding_1.weight" in out
|
||||
|
||||
def test_time_projection_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
"time_projection.1.bias": mx.zeros((384,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "time_projection.weight" in out
|
||||
assert "time_projection.bias" in out
|
||||
|
||||
def test_ffn_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.0.bias": mx.zeros((128,)),
|
||||
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
||||
"blocks.0.ffn.2.bias": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "blocks.0.ffn.fc1.weight" in out
|
||||
assert "blocks.0.ffn.fc1.bias" in out
|
||||
assert "blocks.0.ffn.fc2.weight" in out
|
||||
assert "blocks.0.ffn.fc2.bias" in out
|
||||
|
||||
def test_freqs_skipped(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
"blocks.0.norm1.weight": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "freqs" not in out
|
||||
assert "blocks.0.norm1.weight" in out
|
||||
|
||||
def test_passthrough_keys(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.self_attn.v.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.self_attn.o.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"head.modulation": mx.zeros((1, 2, 64)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"text_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_transformer_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeT5Weights:
|
||||
def test_gate_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
weights = {
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
}
|
||||
out = sanitize_wan_t5_weights(weights)
|
||||
assert "blocks.0.ffn.gate_proj.weight" in out
|
||||
assert "blocks.0.ffn.fc1.weight" in out
|
||||
assert "blocks.0.ffn.fc2.weight" in out
|
||||
|
||||
def test_passthrough(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_t5_weights(weights)
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_t5_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeVAEWeights:
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
weights = {
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
|
||||
|
||||
def test_non_conv_passthrough(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
weights = {
|
||||
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["decoder.norm.weight"].shape == (64,)
|
||||
assert out["decoder.bias"].shape == (16,)
|
||||
|
||||
def test_mixed_weights(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
weights = {
|
||||
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
||||
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
|
||||
"linear.weight": mx.zeros((8, 4)), # 2D
|
||||
"norm.weight": mx.zeros((8,)), # 1D
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["conv3d.weight"].shape == (8, 3, 3, 3, 4)
|
||||
assert out["conv2d.weight"].shape == (8, 3, 3, 4)
|
||||
assert out["linear.weight"].shape == (8, 4)
|
||||
assert out["norm.weight"].shape == (8,)
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
||||
"decoder.norm.weight": mx.zeros((64,)),
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_vae_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Conversion Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWan21Convert:
|
||||
"""Tests for Wan2.1 conversion support."""
|
||||
|
||||
def test_auto_detect_wan21(self, tmp_path):
|
||||
"""Auto-detect single-model directory as Wan2.1."""
|
||||
# Create a Wan2.1-style directory (no low_noise_model subdir)
|
||||
(tmp_path / "dummy.safetensors").touch()
|
||||
# The auto-detect logic: no low_noise_model dir → 2.1
|
||||
from pathlib import Path
|
||||
low = tmp_path / "low_noise_model"
|
||||
assert not low.exists()
|
||||
# Simulates auto detection
|
||||
version = "2.2" if low.exists() else "2.1"
|
||||
assert version == "2.1"
|
||||
|
||||
def test_auto_detect_wan22(self, tmp_path):
|
||||
"""Auto-detect dual-model directory as Wan2.2."""
|
||||
(tmp_path / "low_noise_model").mkdir()
|
||||
(tmp_path / "high_noise_model").mkdir()
|
||||
from pathlib import Path
|
||||
low = tmp_path / "low_noise_model"
|
||||
assert low.exists()
|
||||
version = "2.2" if low.exists() else "2.1"
|
||||
assert version == "2.2"
|
||||
|
||||
def test_wan21_config_saved_correctly(self):
|
||||
"""Verify config dict has correct fields for Wan2.1."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
assert d["model_version"] == "2.1"
|
||||
assert d["dual_model"] is False
|
||||
assert d["sample_steps"] == 50
|
||||
assert d["sample_shift"] == 5.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder Weight Sanitization Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSanitizeEncoderWeights:
|
||||
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
||||
|
||||
def test_exclude_encoder_by_default(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "conv2.weight" in out
|
||||
assert not any("encoder" in k or k.startswith("conv1") for k in out)
|
||||
|
||||
def test_include_encoder(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "encoder.conv1.weight" in out
|
||||
assert "conv1.weight" in out
|
||||
assert "conv2.weight" in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
238
tests/test_wan_generate.py
Normal file
238
tests/test_wan_generate.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Tests for end-to-end generation and I2V mask construction."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: end-to-end tiny model forward pass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end test with tiny model (no real weights needed)."""
|
||||
|
||||
def test_tiny_model_denoise_step(self):
|
||||
"""Simulate one denoising step with tiny model."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=3.0)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
# One step
|
||||
t = sched.timesteps[0]
|
||||
pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
latents_next = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents_next)
|
||||
|
||||
assert latents_next.shape == (C, F, H, W)
|
||||
# Should differ from original noise
|
||||
assert not np.allclose(np.array(latents_next), np.array(latents), atol=1e-5)
|
||||
|
||||
def test_tiny_model_full_loop(self):
|
||||
"""Run a complete (tiny) diffusion loop."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(123)
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 3
|
||||
sched.set_timesteps(num_steps, shift=3.0)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
for i in range(num_steps):
|
||||
t = sched.timesteps[i]
|
||||
pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item(), "NaN in output"
|
||||
assert not mx.any(mx.isinf(latents)).item(), "Inf in output"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# I2V Mask Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestI2VMask:
|
||||
"""Tests for _build_i2v_mask."""
|
||||
|
||||
def test_mask_shapes(self):
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4) # C, T, H, W
|
||||
patch_size = (1, 2, 2)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
assert mask.shape == z_shape
|
||||
# Tokens: T=5, H/2=2, W/2=2 → 5*2*2 = 20
|
||||
assert mask_tokens.shape == (1, 20)
|
||||
|
||||
def test_first_frame_zero(self):
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
|
||||
mx.eval(mask, mask_tokens)
|
||||
# First temporal position should be 0
|
||||
assert float(mask[:, 0, :, :].max()) == 0.0
|
||||
# Rest should be 1
|
||||
assert float(mask[:, 1:, :, :].min()) == 1.0
|
||||
# First-frame tokens (T=0) should be 0 in mask_tokens
|
||||
# With T=5, H'=2, W'=2: first 4 tokens are frame 0
|
||||
assert float(mask_tokens[0, :4].max()) == 0.0
|
||||
assert float(mask_tokens[0, 4:].min()) == 1.0
|
||||
|
||||
|
||||
class TestI2VMaskAlignment:
|
||||
"""Tests that I2V mask works correctly with various aligned dimensions."""
|
||||
|
||||
def test_mask_with_ti2v_dimensions(self):
|
||||
"""Mask should work with TI2V-5B typical dimensions."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
||||
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
||||
z_shape = (48, 21, 44, 80)
|
||||
patch_size = (1, 2, 2)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
mx.eval(mask, mask_tokens)
|
||||
|
||||
assert mask.shape == z_shape
|
||||
assert float(mask[:, 0].max()) == 0.0
|
||||
assert float(mask[:, 1:].min()) == 1.0
|
||||
|
||||
expected_tokens = 21 * 22 * 40 # T * (H/ph) * (W/pw)
|
||||
assert mask_tokens.shape == (1, expected_tokens)
|
||||
first_frame_tokens = 1 * 22 * 40 # pt=1
|
||||
assert float(mask_tokens[0, :first_frame_tokens].max()) == 0.0
|
||||
assert float(mask_tokens[0, first_frame_tokens:].min()) == 1.0
|
||||
|
||||
def test_mask_per_token_timestep(self):
|
||||
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
z_shape = (4, 3, 4, 4)
|
||||
patch_size = (1, 2, 2)
|
||||
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
mx.eval(mask_tokens)
|
||||
|
||||
timestep_val = 0.8
|
||||
t_tokens = mask_tokens * timestep_val
|
||||
mx.eval(t_tokens)
|
||||
|
||||
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
|
||||
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
|
||||
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension Alignment Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDimensionAlignment:
|
||||
"""Tests for automatic dimension alignment in generate_wan."""
|
||||
|
||||
def test_already_aligned(self):
|
||||
"""Dimensions already divisible by alignment factor should be unchanged."""
|
||||
# patch_size=(1,2,2), vae_stride=(4,16,16) → align = 32
|
||||
align_h = 2 * 16 # 32
|
||||
align_w = 2 * 16 # 32
|
||||
h, w = 704, 1280
|
||||
assert h % align_h == 0
|
||||
assert w % align_w == 0
|
||||
h_aligned = (h // align_h) * align_h
|
||||
w_aligned = (w // align_w) * align_w
|
||||
assert h_aligned == h
|
||||
assert w_aligned == w
|
||||
|
||||
def test_720p_rounds_down(self):
|
||||
"""720p (1280x720) should round height to 704."""
|
||||
align_h = 32
|
||||
align_w = 32
|
||||
h, w = 720, 1280
|
||||
assert h % align_h != 0 # 720 not divisible by 32
|
||||
h_aligned = (h // align_h) * align_h
|
||||
w_aligned = (w // align_w) * align_w
|
||||
assert h_aligned == 704
|
||||
assert w_aligned == 1280
|
||||
|
||||
def test_1080p_rounds_down(self):
|
||||
"""1080p (1920x1080) should round height to 1056."""
|
||||
align = 32
|
||||
h, w = 1080, 1920
|
||||
assert h % align != 0
|
||||
assert (h // align) * align == 1056
|
||||
assert (w // align) * align == 1920
|
||||
|
||||
def test_odd_sizes(self):
|
||||
"""Odd sizes should be safely rounded down."""
|
||||
align = 32
|
||||
for size in [100, 255, 513, 1023]:
|
||||
aligned = (size // align) * align
|
||||
assert aligned % align == 0
|
||||
assert aligned <= size
|
||||
assert aligned + align > size # closest lower multiple
|
||||
|
||||
def test_patchify_valid_after_alignment(self):
|
||||
"""After alignment, patchify should succeed without reshape errors."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
# Simulate 720p-like scenario with tiny config
|
||||
vae_stride = config.vae_stride # (4, 8, 8)
|
||||
patch_size = config.patch_size # (1, 2, 2)
|
||||
align_h = patch_size[1] * vae_stride[1]
|
||||
align_w = patch_size[2] * vae_stride[2]
|
||||
|
||||
# Pick a height not divisible by alignment
|
||||
raw_h = align_h * 3 + 5 # e.g. 53 for align=16
|
||||
raw_w = align_w * 4
|
||||
h = (raw_h // align_h) * align_h # rounds down
|
||||
w = (raw_w // align_w) * align_w
|
||||
|
||||
C = config.in_dim
|
||||
t_latent = 1
|
||||
h_latent = h // vae_stride[1]
|
||||
w_latent = w // vae_stride[2]
|
||||
|
||||
vid = mx.random.normal((C, t_latent, h_latent, w_latent))
|
||||
patches, grid_size = model._patchify(vid)
|
||||
mx.eval(patches)
|
||||
assert patches.ndim == 3 # [1, L, dim]
|
||||
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
|
||||
|
||||
def test_alignment_with_ti2v_config(self):
|
||||
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan22_ti2v_5b()
|
||||
align_h = config.patch_size[1] * config.vae_stride[1]
|
||||
align_w = config.patch_size[2] * config.vae_stride[2]
|
||||
assert align_h == 32
|
||||
assert align_w == 32
|
||||
# 720 not divisible
|
||||
assert 720 % align_h != 0
|
||||
# 704 is
|
||||
assert 704 % align_h == 0
|
||||
570
tests/test_wan_i2v.py
Normal file
570
tests/test_wan_i2v.py
Normal file
@@ -0,0 +1,570 @@
|
||||
"""Tests for Wan2.2 I2V-14B support."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
def _make_tiny_i2v_config():
|
||||
"""Create a tiny I2V-14B config for testing."""
|
||||
config = _make_tiny_config()
|
||||
config.model_type = "i2v"
|
||||
config.in_dim = 9 # 4 noise + 4 image + 1 mask (scaled down from 16+16+4=36)
|
||||
config.out_dim = 4
|
||||
config.vae_z_dim = 4
|
||||
config.vae_stride = (4, 8, 8)
|
||||
config.dual_model = True
|
||||
config.boundary = 0.900
|
||||
config.sample_shift = 5.0
|
||||
config.sample_guide_scale = (3.5, 3.5)
|
||||
return config
|
||||
|
||||
|
||||
class TestI2VConfig:
|
||||
"""Test I2V-14B config preset."""
|
||||
|
||||
def test_wan22_i2v_14b_preset(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
assert config.model_type == "i2v"
|
||||
assert config.in_dim == 36
|
||||
assert config.out_dim == 16
|
||||
assert config.dim == 5120
|
||||
assert config.num_layers == 40
|
||||
assert config.dual_model is True
|
||||
assert config.boundary == 0.900
|
||||
assert config.sample_shift == 5.0
|
||||
assert config.sample_guide_scale == (3.5, 3.5)
|
||||
assert config.vae_stride == (4, 8, 8)
|
||||
assert config.vae_z_dim == 16
|
||||
|
||||
def test_i2v_vs_t2v_differences(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
i2v = WanModelConfig.wan22_i2v_14b()
|
||||
t2v = WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
assert i2v.model_type == "i2v"
|
||||
assert t2v.model_type == "t2v"
|
||||
assert i2v.in_dim == 36 and t2v.in_dim == 16
|
||||
assert i2v.boundary == 0.900 and t2v.boundary == 0.875
|
||||
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
|
||||
|
||||
def test_i2v_serialization_roundtrip(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
d = config.to_dict()
|
||||
restored = WanModelConfig.from_dict(d)
|
||||
assert restored.model_type == "i2v"
|
||||
assert restored.in_dim == 36
|
||||
assert restored.boundary == 0.900
|
||||
|
||||
|
||||
class TestModelYParameter:
|
||||
"""Test y parameter channel concatenation in WanModel."""
|
||||
|
||||
def test_forward_without_y(self):
|
||||
"""Standard T2V forward pass (no y) still works."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_forward_with_y(self):
|
||||
"""I2V forward pass with y channel concatenation."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C_noise = 4 # noise channels
|
||||
C_y = 5 # mask (1) + image (4)
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C_noise, F, H, W))]
|
||||
y_list = [mx.random.normal((C_y, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out = model(x_list, t, context, seq_len, y=y_list)
|
||||
mx.eval(out[0])
|
||||
# Output should match noise channels (out_dim), not concatenated in_dim
|
||||
assert out[0].shape == (config.out_dim, F, H, W)
|
||||
|
||||
def test_y_none_is_noop(self):
|
||||
"""Passing y=None should be identical to not passing y."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
mx.random.seed(42)
|
||||
x = mx.random.normal((C, F, H, W))
|
||||
t = mx.array([500.0])
|
||||
ctx = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out1 = model([x], t, ctx, seq_len)[0]
|
||||
out2 = model([x], t, ctx, seq_len, y=None)[0]
|
||||
mx.eval(out1, out2)
|
||||
assert mx.allclose(out1, out2, atol=1e-5).item()
|
||||
|
||||
def test_batched_cfg_with_y(self):
|
||||
"""Batched CFG (B=2) with y should work."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C_noise, C_y = 4, 5
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
y = mx.random.normal((C_y, F, H, W))
|
||||
t = mx.array([500.0, 500.0])
|
||||
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
|
||||
mx.eval(out[0], out[1])
|
||||
assert len(out) == 2
|
||||
assert out[0].shape == (config.out_dim, F, H, W)
|
||||
assert out[1].shape == (config.out_dim, F, H, W)
|
||||
|
||||
|
||||
class TestVAEEncoder:
|
||||
"""Test Wan2.1 VAE encoder."""
|
||||
|
||||
def test_encoder3d_instantiation(self):
|
||||
from mlx_video.models.wan.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
||||
assert enc.conv1 is not None
|
||||
assert len(enc.downsamples) > 0
|
||||
assert len(enc.middle) == 3
|
||||
|
||||
def test_encoder3d_output_shape(self):
|
||||
"""Encoder should downsample spatially by 8x and temporally by 4x."""
|
||||
from mlx_video.models.wan.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=32, z_dim=8)
|
||||
# Random input: [B=1, 3, T=5, H=32, W=32]
|
||||
x = mx.random.normal((1, 3, 5, 32, 32))
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
# With default dim_mult=[1,2,4,4] and temporal_downsample=[True,True,False]:
|
||||
# Spatial: 32 -> 16 -> 8 -> 4 (3 spatial downsamples)
|
||||
# Temporal: 5 -> 3 -> 2 (2 temporal downsamples: downsample3d stride 2)
|
||||
assert out.shape[0] == 1
|
||||
assert out.shape[1] == 8 # z_dim
|
||||
assert out.shape[3] == 32 // 8 # spatial /8
|
||||
assert out.shape[4] == 32 // 8
|
||||
|
||||
def test_wan_vae_encode(self):
|
||||
"""WanVAE with encoder=True should produce normalized latents."""
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
# Input: [B=1, 3, T=5, H=32, W=32]
|
||||
x = mx.random.normal((1, 3, 5, 32, 32))
|
||||
z = vae.encode(x)
|
||||
mx.eval(z)
|
||||
assert z.shape[0] == 1
|
||||
assert z.shape[1] == 16 # z_dim
|
||||
|
||||
def test_wan_vae_encoder_flag(self):
|
||||
"""WanVAE without encoder flag should not have encoder attribute."""
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
||||
assert not hasattr(vae_no_enc, 'encoder')
|
||||
|
||||
vae_enc = WanVAE(z_dim=4, encoder=True)
|
||||
assert hasattr(vae_enc, 'encoder')
|
||||
|
||||
|
||||
class TestResampleDownsample:
|
||||
"""Test downsample modes in Resample."""
|
||||
|
||||
def test_downsample2d(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 8, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
# Spatial /2, temporal unchanged, channels same
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_downsample3d(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample3d")
|
||||
x = mx.random.normal((1, 16, 4, 8, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
# Spatial /2, temporal /2, channels same
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_upsample2d_still_works(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 8, 8)
|
||||
|
||||
def test_upsample3d_still_works(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample3d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 4, 8, 8)
|
||||
|
||||
|
||||
class TestI2VMaskConstruction:
|
||||
"""Test mask construction for I2V-14B."""
|
||||
|
||||
def test_mask_shape(self):
|
||||
"""I2V-14B mask should have 4 channels with correct temporal structure."""
|
||||
num_frames = 81
|
||||
h_latent, w_latent = 10, 18 # example latent dims
|
||||
t_latent = (num_frames - 1) // 4 + 1 # = 21
|
||||
|
||||
# Build mask following reference logic
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
assert msk.shape == (4, t_latent, h_latent, w_latent)
|
||||
|
||||
def test_mask_values(self):
|
||||
"""First temporal position should be 1, rest 0."""
|
||||
num_frames = 9
|
||||
h_latent, w_latent = 4, 4
|
||||
t_latent = (num_frames - 1) // 4 + 1 # = 3
|
||||
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0]
|
||||
|
||||
mx.eval(msk)
|
||||
# First temporal position: all 4 channels should be 1
|
||||
assert mx.all(msk[:, 0] == 1.0).item()
|
||||
# Rest: all should be 0
|
||||
assert mx.all(msk[:, 1:] == 0.0).item()
|
||||
|
||||
def test_y_tensor_shape(self):
|
||||
"""y = concat([mask_4ch, encoded_video_16ch]) should be 20 channels."""
|
||||
mask = mx.zeros((4, 5, 10, 18))
|
||||
encoded = mx.zeros((16, 5, 10, 18))
|
||||
y = mx.concatenate([mask, encoded], axis=0)
|
||||
assert y.shape == (20, 5, 10, 18)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: I2V end-to-end pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestI2VEndToEndPipeline:
|
||||
"""Full I2V pipeline: image → preprocess → VAE encode → y tensor → denoise → VAE decode."""
|
||||
|
||||
def test_full_i2v_pipeline(self):
|
||||
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
mx.random.seed(0)
|
||||
|
||||
# --- Tiny I2V model config (z_dim=16 to match VAE normalization stats) ---
|
||||
config = _make_tiny_i2v_config()
|
||||
config.vae_z_dim = 16
|
||||
config.out_dim = 16 # must match VAE z_dim for decode
|
||||
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||
model = WanModel(config)
|
||||
|
||||
# --- Tiny VAE (with encoder) ---
|
||||
vae = WanVAE(z_dim=config.vae_z_dim, encoder=True)
|
||||
|
||||
# --- Synthetic image: [B=1, 3, T=1, H=32, W=32] in [-1, 1] ---
|
||||
height, width = 32, 32
|
||||
num_frames = 5 # small temporal extent
|
||||
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
||||
video = mx.concatenate([
|
||||
img,
|
||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||
], axis=2)
|
||||
|
||||
# --- VAE encode ---
|
||||
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
||||
mx.eval(z_video)
|
||||
assert z_video.ndim == 5
|
||||
assert z_video.shape[1] == config.vae_z_dim
|
||||
|
||||
z_video = z_video[0] # [z_dim, T_lat, H_lat, W_lat]
|
||||
t_latent = z_video.shape[1]
|
||||
h_latent = z_video.shape[2]
|
||||
w_latent = z_video.shape[3]
|
||||
|
||||
# --- Build I2V mask (4 channels) ---
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
# --- Build y tensor: [mask(4ch) + encoded(z_dim ch)] ---
|
||||
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
||||
mx.eval(y_i2v)
|
||||
assert y_i2v.shape[0] == 4 + config.vae_z_dim
|
||||
|
||||
# --- Denoising loop (2 steps) ---
|
||||
C_noise = config.out_dim # noise channels
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 2
|
||||
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
||||
|
||||
latents = mx.random.normal((C_noise, t_latent, h_latent, w_latent))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
for i in range(num_steps):
|
||||
t_val = sched.timesteps[i].item()
|
||||
pred = model(
|
||||
[latents],
|
||||
mx.array([t_val]),
|
||||
[context],
|
||||
seq_len,
|
||||
y=[y_i2v],
|
||||
)[0]
|
||||
latents = sched.step(pred[None], t_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C_noise, t_latent, h_latent, w_latent)
|
||||
assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents"
|
||||
assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents"
|
||||
|
||||
# --- VAE decode ---
|
||||
decoded = vae.decode(latents[None]) # [1, 3, T_out, H_out, W_out]
|
||||
mx.eval(decoded)
|
||||
assert decoded.ndim == 5
|
||||
assert decoded.shape[0] == 1
|
||||
assert decoded.shape[1] == 3 # RGB output
|
||||
assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video"
|
||||
assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video"
|
||||
# VAE decode clips to [-1, 1]
|
||||
assert float(decoded.max()) <= 1.0
|
||||
assert float(decoded.min()) >= -1.0
|
||||
|
||||
|
||||
class TestDualModelSwitching:
|
||||
"""Test dual-model selection logic: high_noise vs low_noise based on boundary."""
|
||||
|
||||
def test_model_selection_by_timestep(self):
|
||||
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(1)
|
||||
config = _make_tiny_i2v_config()
|
||||
assert config.dual_model is True
|
||||
|
||||
high_noise_model = WanModel(config)
|
||||
low_noise_model = WanModel(config)
|
||||
|
||||
boundary = config.boundary * config.num_train_timesteps # 0.9 * 1000 = 900
|
||||
|
||||
C_noise = config.out_dim # 4
|
||||
C_y = config.in_dim - config.out_dim # 9 - 4 = 5
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 5
|
||||
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
||||
|
||||
guide_scale = config.sample_guide_scale # (3.5, 3.5)
|
||||
assert isinstance(guide_scale, tuple) and len(guide_scale) == 2
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
y_i2v = mx.random.normal((C_y, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
high_used_steps = []
|
||||
low_used_steps = []
|
||||
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
for i in range(num_steps):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
gs = guide_scale[1]
|
||||
high_used_steps.append(i)
|
||||
else:
|
||||
model = low_noise_model
|
||||
gs = guide_scale[0]
|
||||
low_used_steps.append(i)
|
||||
|
||||
# CFG pass: cond + uncond
|
||||
preds = model(
|
||||
[latents, latents],
|
||||
mx.array([timestep_val, timestep_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
||||
assert len(high_used_steps) > 0, "High-noise model was never selected"
|
||||
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
||||
# High-noise steps should come before low-noise steps (timesteps decrease)
|
||||
if high_used_steps and low_used_steps:
|
||||
assert max(high_used_steps) < min(low_used_steps) or \
|
||||
min(high_used_steps) < max(low_used_steps), \
|
||||
"Model switching should happen during the loop"
|
||||
|
||||
assert latents.shape == (C_noise, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
|
||||
def test_guide_scale_tuple_applied_per_model(self):
|
||||
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(2)
|
||||
config = _make_tiny_i2v_config()
|
||||
config.sample_guide_scale = (2.0, 5.0) # distinct values
|
||||
|
||||
model = WanModel(config)
|
||||
boundary = config.boundary * config.num_train_timesteps
|
||||
|
||||
C_noise = config.out_dim
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=config.sample_shift)
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
guide_scale = config.sample_guide_scale
|
||||
C_y = config.in_dim - config.out_dim # y channels
|
||||
y_i2v = mx.random.normal((C_y, F, H, W))
|
||||
|
||||
# Track which guide scale was used at each step
|
||||
gs_per_step = []
|
||||
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
for i in range(5):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
if timestep_val >= boundary:
|
||||
gs = guide_scale[1] # high_gs = 5.0
|
||||
else:
|
||||
gs = guide_scale[0] # low_gs = 2.0
|
||||
gs_per_step.append(gs)
|
||||
|
||||
pred = model(
|
||||
[latents, latents],
|
||||
mx.array([timestep_val, timestep_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
# Verify both guide scales were used
|
||||
assert 5.0 in gs_per_step, "High guide scale (5.0) was never used"
|
||||
assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used"
|
||||
# High gs should appear first (high timesteps come first)
|
||||
first_high = gs_per_step.index(5.0)
|
||||
last_low = len(gs_per_step) - 1 - gs_per_step[::-1].index(2.0)
|
||||
assert first_high < last_low, "High gs steps should precede low gs steps"
|
||||
|
||||
def test_single_model_fallback_with_tuple_guide_scale(self):
|
||||
"""When dual_model=False, guide_scale tuple should use first element."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(3)
|
||||
config = _make_tiny_config()
|
||||
config.dual_model = False
|
||||
config.sample_guide_scale = (3.0, 5.0)
|
||||
|
||||
model = WanModel(config)
|
||||
guide_scale = config.sample_guide_scale
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(3, shift=3.0)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
# Mimic generate_wan.py single-model logic:
|
||||
# gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
assert gs == 3.0, "Single model should use first element of guide_scale tuple"
|
||||
|
||||
for i in range(3):
|
||||
t_val = sched.timesteps[i].item()
|
||||
pred = model(
|
||||
[latents, latents],
|
||||
mx.array([t_val, t_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
334
tests/test_wan_lora.py
Normal file
334
tests/test_wan_lora.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""Tests for LoRA loading and application."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLoRATypes:
|
||||
"""Test LoRA data structures."""
|
||||
|
||||
def test_lora_weights_scale(self):
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.zeros((16, 64)),
|
||||
lora_B=mx.zeros((128, 16)),
|
||||
rank=16,
|
||||
alpha=32.0,
|
||||
module_name="test",
|
||||
)
|
||||
assert w.scale == 2.0
|
||||
|
||||
def test_lora_weights_scale_default(self):
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.zeros((16, 64)),
|
||||
lora_B=mx.zeros((128, 16)),
|
||||
rank=16,
|
||||
alpha=16.0,
|
||||
module_name="test",
|
||||
)
|
||||
assert w.scale == 1.0
|
||||
|
||||
def test_applied_lora_delta(self):
|
||||
from mlx_video.lora.types import AppliedLoRA, LoRAWeights
|
||||
|
||||
lora_a = mx.ones((2, 4))
|
||||
lora_b = mx.ones((8, 2))
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
applied = AppliedLoRA(weights=w, strength=0.5)
|
||||
delta = applied.compute_delta()
|
||||
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
|
||||
expected = 0.5 * mx.ones((8, 4)) * 2.0
|
||||
assert mx.allclose(delta, expected).item()
|
||||
|
||||
|
||||
class TestLoRALoader:
|
||||
"""Test LoRA weight loading from safetensors."""
|
||||
|
||||
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
|
||||
"""Helper to create a mock LoRA safetensors file."""
|
||||
weights = {}
|
||||
for name in module_names:
|
||||
if key_format == "AB":
|
||||
weights[f"{name}.lora_A.weight"] = mx.random.normal((rank, in_dim))
|
||||
weights[f"{name}.lora_B.weight"] = mx.random.normal((out_dim, rank))
|
||||
else:
|
||||
weights[f"{name}.lora_down.weight"] = mx.random.normal((rank, in_dim))
|
||||
weights[f"{name}.lora_up.weight"] = mx.random.normal((out_dim, rank))
|
||||
path = Path(tmp_dir) / "test_lora.safetensors"
|
||||
mx.save_safetensors(str(path), weights)
|
||||
return path
|
||||
|
||||
def test_load_lora_a_b_format(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(tmp, ["blocks.0.self_attn.q"], key_format="AB")
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert "blocks.0.self_attn.q" in lora_weights
|
||||
w = lora_weights["blocks.0.self_attn.q"]
|
||||
assert w.rank == 4
|
||||
assert w.alpha == 4.0 # default: alpha == rank
|
||||
assert w.lora_A.shape == (4, 64)
|
||||
assert w.lora_B.shape == (128, 4)
|
||||
|
||||
def test_load_lora_down_up_format(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(
|
||||
tmp, ["blocks.0.self_attn.q"], key_format="down_up"
|
||||
)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert "blocks.0.self_attn.q" in lora_weights
|
||||
|
||||
def test_load_multiple_modules(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
modules = [
|
||||
"blocks.0.self_attn.q",
|
||||
"blocks.0.self_attn.k",
|
||||
"blocks.0.ffn.fc1",
|
||||
]
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(tmp, modules)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert len(lora_weights) == 3
|
||||
for name in modules:
|
||||
assert name in lora_weights
|
||||
|
||||
def test_load_with_alpha(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
weights = {
|
||||
"test.lora_A.weight": mx.random.normal((8, 64)),
|
||||
"test.lora_B.weight": mx.random.normal((128, 8)),
|
||||
"test.alpha": mx.array(16.0),
|
||||
}
|
||||
path = Path(tmp) / "lora.safetensors"
|
||||
mx.save_safetensors(str(path), weights)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert lora_weights["test"].alpha == 16.0
|
||||
assert lora_weights["test"].rank == 8
|
||||
assert lora_weights["test"].scale == 2.0
|
||||
|
||||
def test_file_not_found(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_lora_weights(Path("/nonexistent/lora.safetensors"))
|
||||
|
||||
|
||||
class TestWanKeyNormalization:
|
||||
"""Test Wan2.2 LoRA key normalization."""
|
||||
|
||||
def _wan_model_keys(self):
|
||||
"""Simulate typical Wan2.2 MLX model weight keys."""
|
||||
keys = set()
|
||||
for i in range(2):
|
||||
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
|
||||
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
|
||||
keys.add(f"blocks.{i}.{layer}.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc1.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc2.weight")
|
||||
keys.add("text_embedding_0.weight")
|
||||
keys.add("text_embedding_1.weight")
|
||||
keys.add("time_embedding_0.weight")
|
||||
keys.add("time_embedding_1.weight")
|
||||
keys.add("time_projection.weight")
|
||||
keys.add("patch_embedding_proj.weight")
|
||||
return keys
|
||||
|
||||
def test_direct_match(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
|
||||
|
||||
def test_strip_diffusion_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("diffusion_model.blocks.0.self_attn.q", keys)
|
||||
assert result == "blocks.0.self_attn.q"
|
||||
|
||||
def test_strip_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
|
||||
assert result == "blocks.0.self_attn.k"
|
||||
|
||||
def test_ffn_key_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("blocks.0.ffn.0", keys) == "blocks.0.ffn.fc1"
|
||||
assert _normalize_wan_lora_key("blocks.0.ffn.2", keys) == "blocks.0.ffn.fc2"
|
||||
|
||||
def test_text_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("text_embedding.0", keys) == "text_embedding_0"
|
||||
assert _normalize_wan_lora_key("text_embedding.2", keys) == "text_embedding_1"
|
||||
|
||||
def test_time_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("time_embedding.0", keys) == "time_embedding_0"
|
||||
assert _normalize_wan_lora_key("time_embedding.2", keys) == "time_embedding_1"
|
||||
|
||||
def test_time_projection_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("time_projection.1", keys) == "time_projection"
|
||||
|
||||
def test_patch_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||
|
||||
def test_combined_prefix_and_ffn(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("diffusion_model.blocks.1.ffn.0", keys)
|
||||
assert result == "blocks.1.ffn.fc1"
|
||||
|
||||
|
||||
class TestApplyLoRA:
|
||||
"""Test LoRA delta application to weights."""
|
||||
|
||||
def test_preserves_bfloat16_dtype(self):
|
||||
"""LoRA delta must not promote bfloat16 weights to float32."""
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4), dtype=mx.bfloat16)
|
||||
# LoRA weights in float32 (typical when loaded from safetensors)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
|
||||
|
||||
def test_preserves_float16_dtype(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4), dtype=mx.float16)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
|
||||
|
||||
def test_apply_single_lora(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4))
|
||||
lora_a = mx.ones((2, 4)) * 0.1
|
||||
lora_b = mx.ones((8, 2)) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
|
||||
expected = original + 0.02 * mx.ones((8, 4))
|
||||
assert mx.allclose(result, expected, atol=1e-6).item()
|
||||
|
||||
def test_apply_multiple_loras(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.zeros((8, 4))
|
||||
w1 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)),
|
||||
lora_B=mx.ones((8, 2)),
|
||||
rank=2, alpha=2.0, module_name="a",
|
||||
)
|
||||
w2 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)) * 2,
|
||||
lora_B=mx.ones((8, 2)) * 2,
|
||||
rank=2, alpha=4.0, module_name="b",
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
|
||||
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
|
||||
# w2 delta: 2.0 * 0.5 * (2*ones(8,2) @ 2*ones(2,4)) = 1.0 * 8*ones(8,4) = 8
|
||||
delta1 = mx.ones((8, 4)) * 2.0
|
||||
delta2 = mx.ones((8, 4)) * 8.0
|
||||
expected = delta1 + delta2
|
||||
assert mx.allclose(result, expected, atol=1e-5).item()
|
||||
|
||||
def test_apply_loras_to_weights_dict(self):
|
||||
from mlx_video.lora.apply import apply_loras_to_weights
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
model_weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.ones((256, 64)),
|
||||
}
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.ones((4, 64)) * 0.01,
|
||||
lora_B=mx.ones((128, 4)) * 0.01,
|
||||
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
|
||||
)
|
||||
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
|
||||
result = apply_loras_to_weights(model_weights, module_to_loras)
|
||||
# Only q should be modified
|
||||
assert not mx.array_equal(
|
||||
result["blocks.0.self_attn.q.weight"],
|
||||
model_weights["blocks.0.self_attn.q.weight"],
|
||||
).item()
|
||||
assert mx.array_equal(
|
||||
result["blocks.0.self_attn.k.weight"],
|
||||
model_weights["blocks.0.self_attn.k.weight"],
|
||||
).item()
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end LoRA loading and application."""
|
||||
|
||||
def test_load_and_apply_loras(self):
|
||||
from mlx_video.convert_wan import load_and_apply_loras
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
# Create mock LoRA safetensors
|
||||
rank = 4
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.lora_A.weight": mx.random.normal((rank, 64)),
|
||||
"blocks.0.self_attn.q.lora_B.weight": mx.random.normal((128, rank)),
|
||||
}
|
||||
lora_path = Path(tmp) / "test.safetensors"
|
||||
mx.save_safetensors(str(lora_path), weights)
|
||||
|
||||
# Create mock model weights
|
||||
model_weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
}
|
||||
|
||||
result = load_and_apply_loras(
|
||||
model_weights, [(str(lora_path), 1.0)]
|
||||
)
|
||||
|
||||
# q weight should be modified, k unchanged
|
||||
assert not mx.array_equal(
|
||||
result["blocks.0.self_attn.q.weight"],
|
||||
model_weights["blocks.0.self_attn.q.weight"],
|
||||
).item()
|
||||
assert mx.array_equal(
|
||||
result["blocks.0.self_attn.k.weight"],
|
||||
model_weights["blocks.0.self_attn.k.weight"],
|
||||
).item()
|
||||
332
tests/test_wan_model.py
Normal file
332
tests/test_wan_model.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""Tests for Wan model components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sinusoidal Embedding Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSinusoidalEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
pos = mx.arange(10).astype(mx.float32)
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
mx.eval(emb)
|
||||
assert emb.shape == (10, 256)
|
||||
|
||||
def test_position_zero(self):
|
||||
"""Position 0 should have cos=1 for all dims and sin=0."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
pos = mx.array([0.0])
|
||||
emb = sinusoidal_embedding_1d(64, pos)
|
||||
mx.eval(emb)
|
||||
emb_np = np.array(emb[0])
|
||||
# First half is cos, should be 1 at position 0
|
||||
np.testing.assert_allclose(emb_np[:32], 1.0, atol=1e-5)
|
||||
# Second half is sin, should be 0 at position 0
|
||||
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
|
||||
|
||||
def test_different_positions_differ(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
pos = mx.array([0.0, 100.0, 999.0])
|
||||
emb = sinusoidal_embedding_1d(128, pos)
|
||||
mx.eval(emb)
|
||||
emb_np = np.array(emb)
|
||||
assert not np.allclose(emb_np[0], emb_np[1])
|
||||
assert not np.allclose(emb_np[1], emb_np[2])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Head Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHead:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
B, L = 1, 24
|
||||
x = mx.random.normal((B, L, 64))
|
||||
e = mx.random.normal((B, 64)) # time embedding: [B, dim]
|
||||
out = head(x, e)
|
||||
mx.eval(out)
|
||||
expected_proj_dim = 16 * 1 * 2 * 2 # 64
|
||||
assert out.shape == (B, L, expected_proj_dim)
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
assert head.modulation.shape == (1, 2, 64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WanModel (Tiny) Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWanModel:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
|
||||
assert num_params > 0
|
||||
|
||||
def test_patchify_shape(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
# Input: [C=4, F=1, H=4, W=4]
|
||||
x = mx.random.normal((4, 1, 4, 4))
|
||||
patches, grid_size = model._patchify(x)
|
||||
mx.eval(patches)
|
||||
# Patch size (1,2,2): F'=1, H'=2, W'=2
|
||||
assert grid_size == (1, 2, 2)
|
||||
assert patches.shape == (1, 1 * 2 * 2, config.dim)
|
||||
|
||||
def test_patchify_various_sizes(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
|
||||
x = mx.random.normal((config.in_dim, f, h, w))
|
||||
patches, (gf, gh, gw) = model._patchify(x)
|
||||
mx.eval(patches)
|
||||
pt, ph, pw = config.patch_size
|
||||
assert gf == f // pt
|
||||
assert gh == h // ph
|
||||
assert gw == w // pw
|
||||
assert patches.shape[1] == gf * gh * gw
|
||||
|
||||
def test_unpatchify_inverse(self):
|
||||
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 2, 4, 6
|
||||
pt, ph, pw = config.patch_size
|
||||
F_out, H_out, W_out = F // pt, H // ph, W // pw
|
||||
L = F_out * H_out * W_out
|
||||
proj_dim = config.out_dim * pt * ph * pw
|
||||
# Simulated head output
|
||||
x = mx.random.normal((1, L, proj_dim))
|
||||
out = model.unpatchify(x, [(F_out, H_out, W_out)])
|
||||
mx.eval(out[0])
|
||||
assert out[0].shape == (config.out_dim, F, H, W)
|
||||
|
||||
def test_forward_pass(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
assert len(out) == 1
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_forward_batch(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0, 200.0])
|
||||
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0], out[1])
|
||||
assert len(out) == 2
|
||||
for o in out:
|
||||
assert o.shape == (C, F, H, W)
|
||||
|
||||
def test_output_is_float32(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
|
||||
[mx.random.normal((4, config.text_dim))], seq_len)
|
||||
mx.eval(out[0])
|
||||
assert out[0].dtype == mx.float32
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWan21Model:
|
||||
"""Test tiny Wan2.1-style model (single model mode)."""
|
||||
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def _make_tiny_wan21_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
# Override to tiny values
|
||||
config.dim = 64
|
||||
config.ffn_dim = 128
|
||||
config.num_heads = 4
|
||||
config.num_layers = 2
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.freq_dim = 32
|
||||
config.text_dim = 32
|
||||
config.text_len = 8
|
||||
return config
|
||||
|
||||
def _make_tiny_wan21_1_3b_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
||||
config.dim = 48
|
||||
config.ffn_dim = 96
|
||||
config.num_heads = 4
|
||||
config.num_layers = 2
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.freq_dim = 24
|
||||
config.text_dim = 24
|
||||
config.text_len = 8
|
||||
return config
|
||||
|
||||
def test_wan21_tiny_model_forward(self):
|
||||
"""Forward pass with Wan2.1 tiny config."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = self._make_tiny_wan21_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
t = mx.array([500.0])
|
||||
|
||||
out = model([latents], t, [context], seq_len)
|
||||
mx.eval(out)
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_wan21_1_3b_tiny_model_forward(self):
|
||||
"""Forward pass with Wan2.1 1.3B tiny config."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = self._make_tiny_wan21_1_3b_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
t = mx.array([500.0])
|
||||
|
||||
out = model([latents], t, [context], seq_len)
|
||||
mx.eval(out)
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_wan21_single_model_loop(self):
|
||||
"""Full diffusion loop with single model (Wan2.1 style)."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
config = self._make_tiny_wan21_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(config.sample_steps, shift=config.sample_shift)
|
||||
|
||||
# Use only 3 steps for speed
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
context_null = mx.zeros((4, config.text_dim))
|
||||
gs = config.sample_guide_scale # Should be float for Wan2.1
|
||||
|
||||
assert isinstance(gs, float), "Wan2.1 guide_scale should be float"
|
||||
|
||||
for i in range(3):
|
||||
t = sched.timesteps[i]
|
||||
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
|
||||
pred = pred_uncond + gs * (pred_cond - pred_uncond)
|
||||
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
|
||||
def test_wan21_vs_wan22_config_differences(self):
|
||||
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
c21 = WanModelConfig.wan21_t2v_14b()
|
||||
c22 = WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
# Same architecture
|
||||
assert c21.dim == c22.dim
|
||||
assert c21.num_heads == c22.num_heads
|
||||
assert c21.num_layers == c22.num_layers
|
||||
|
||||
# Different pipeline settings
|
||||
assert c21.dual_model is False
|
||||
assert c22.dual_model is True
|
||||
assert isinstance(c21.sample_guide_scale, float)
|
||||
assert isinstance(c22.sample_guide_scale, tuple)
|
||||
assert c21.sample_shift != c22.sample_shift
|
||||
assert c21.sample_steps != c22.sample_steps
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-Token Timestep Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPerTokenTimestep:
|
||||
"""Tests for per-token sinusoidal embedding."""
|
||||
|
||||
def test_1d_unchanged(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0, 100.0, 500.0])
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
assert emb.shape == (3, 256)
|
||||
|
||||
def test_2d_per_token(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
assert emb.shape == (2, 3, 256)
|
||||
|
||||
def test_consistency(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos_1d = mx.array([0.0, 100.0])
|
||||
emb_1d = sinusoidal_embedding_1d(256, pos_1d)
|
||||
pos_2d = mx.array([[0.0, 100.0]])
|
||||
emb_2d = sinusoidal_embedding_1d(256, pos_2d)
|
||||
assert mx.array_equal(emb_1d[0], emb_2d[0, 0])
|
||||
assert mx.array_equal(emb_1d[1], emb_2d[0, 1])
|
||||
313
tests/test_wan_quantization.py
Normal file
313
tests/test_wan_quantization.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Tests for Wan model quantization pipeline."""
|
||||
|
||||
import json
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantize Predicate Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuantizePredicate:
|
||||
def test_matches_self_attention_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
path = f"blocks.0.self_attn.{suffix}"
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_cross_attention_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
path = f"blocks.0.cross_attn.{suffix}"
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_ffn_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
|
||||
|
||||
def test_rejects_embeddings(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]:
|
||||
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
|
||||
|
||||
def test_rejects_norms(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
|
||||
|
||||
def test_rejects_non_quantizable_modules(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
# Even if path matches, module must have to_quantized
|
||||
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
|
||||
|
||||
def test_all_10_patterns_covered(self):
|
||||
"""Verify exactly 10 layer patterns are targeted."""
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
patterns = [
|
||||
"blocks.0.self_attn.q", "blocks.0.self_attn.k",
|
||||
"blocks.0.self_attn.v", "blocks.0.self_attn.o",
|
||||
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k",
|
||||
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o",
|
||||
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2",
|
||||
]
|
||||
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
|
||||
assert len(matched) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantize Round-Trip Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuantizeRoundTrip:
|
||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||
"""Helper: create model, quantize, save to tmp_path."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
# Write config.json
|
||||
cfg = {"quantization": {"bits": bits, "group_size": group_size}}
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump(cfg, f)
|
||||
|
||||
return model_path, weights_dict
|
||||
|
||||
def test_4bit_roundtrip(self, tmp_path):
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
loaded = load_wan_model(
|
||||
model_path, config,
|
||||
quantization={"bits": 4, "group_size": 64},
|
||||
)
|
||||
|
||||
# Verify quantized layers have scales
|
||||
has_scales = any("scales" in k for k in saved_weights)
|
||||
assert has_scales, "Quantized model should have .scales tensors"
|
||||
|
||||
# Verify a self-attention layer is QuantizedLinear
|
||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
assert isinstance(loaded.blocks[0].ffn.fc1, nn.QuantizedLinear)
|
||||
|
||||
def test_8bit_roundtrip(self, tmp_path):
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
loaded = load_wan_model(
|
||||
model_path, config,
|
||||
quantization={"bits": 8, "group_size": 64},
|
||||
)
|
||||
|
||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
assert isinstance(loaded.blocks[0].cross_attn.k, nn.QuantizedLinear)
|
||||
|
||||
def test_non_quantized_layers_remain_linear(self, tmp_path):
|
||||
config = _make_tiny_config()
|
||||
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
loaded = load_wan_model(
|
||||
model_path, config,
|
||||
quantization={"bits": 4, "group_size": 64},
|
||||
)
|
||||
|
||||
# Head should NOT be quantized (it's not in the predicate patterns)
|
||||
assert not isinstance(loaded.head, nn.QuantizedLinear)
|
||||
|
||||
def test_loading_without_quantization_flag(self, tmp_path):
|
||||
"""Loading a non-quantized model should have standard Linear layers."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
loaded = load_wan_model(model_path, config, quantization=None)
|
||||
|
||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
|
||||
assert not isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantized Inference Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuantizedInference:
|
||||
def _make_quantized_model(self, config, bits=4):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=64,
|
||||
bits=bits,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
def test_forward_pass_4bit(self):
|
||||
config = _make_tiny_config()
|
||||
model = self._make_quantized_model(config, bits=4)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((4, config.text_dim))]
|
||||
|
||||
out = model(x, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
|
||||
assert len(out) == 1
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_forward_pass_8bit(self):
|
||||
config = _make_tiny_config()
|
||||
model = self._make_quantized_model(config, bits=8)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((4, config.text_dim))]
|
||||
|
||||
out = model(x, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
|
||||
assert len(out) == 1
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_quantized_output_differs_from_unquantized(self):
|
||||
"""Sanity check: quantization should change the weights."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
config = _make_tiny_config()
|
||||
mx.random.seed(42)
|
||||
|
||||
# Get unquantized weights
|
||||
model = WanModel(config)
|
||||
mx.eval(model.parameters())
|
||||
orig_weight = np.array(model.blocks[0].self_attn.q.weight)
|
||||
|
||||
# Quantize
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=64,
|
||||
bits=4,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# QuantizedLinear stores weight differently (uint32 packed)
|
||||
assert isinstance(model.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
assert hasattr(model.blocks[0].self_attn.q, "scales")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config Metadata Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQuantizationConfig:
|
||||
def test_config_metadata_written(self, tmp_path):
|
||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
|
||||
# Save unquantized model + config
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump({"dim": config.dim}, f)
|
||||
|
||||
# Run quantization
|
||||
_quantize_saved_model(tmp_path, config, is_dual=False, bits=4, group_size=64)
|
||||
|
||||
# Verify metadata
|
||||
with open(tmp_path / "config.json") as f:
|
||||
cfg = json.load(f)
|
||||
assert "quantization" in cfg
|
||||
assert cfg["quantization"]["bits"] == 4
|
||||
assert cfg["quantization"]["group_size"] == 64
|
||||
|
||||
def test_config_metadata_8bit(self, tmp_path):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump({}, f)
|
||||
|
||||
_quantize_saved_model(tmp_path, config, is_dual=False, bits=8, group_size=32)
|
||||
|
||||
with open(tmp_path / "config.json") as f:
|
||||
cfg = json.load(f)
|
||||
assert cfg["quantization"]["bits"] == 8
|
||||
assert cfg["quantization"]["group_size"] == 32
|
||||
|
||||
def test_dual_model_quantization(self, tmp_path):
|
||||
"""Verify dual-model quantization writes both model files."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
|
||||
config = _make_tiny_config()
|
||||
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
mx.save_safetensors(str(tmp_path / name), weights_dict)
|
||||
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump({}, f)
|
||||
|
||||
_quantize_saved_model(tmp_path, config, is_dual=True, bits=4, group_size=64)
|
||||
|
||||
# Both files should now contain quantized weights (have .scales keys)
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
weights = mx.load(str(tmp_path / name))
|
||||
has_scales = any("scales" in k for k in weights)
|
||||
assert has_scales, f"{name} should have quantized layers"
|
||||
334
tests/test_wan_rope_freqs.py
Normal file
334
tests/test_wan_rope_freqs.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""Tests for Wan RoPE frequency construction (Bug 6 regression tests).
|
||||
|
||||
These tests verify that the RoPE frequency table is built correctly by
|
||||
concatenating three separate rope_params calls with different dimension
|
||||
normalizations, matching the reference implementation.
|
||||
|
||||
Background: The reference Wan model constructs RoPE frequencies as:
|
||||
d = dim // num_heads (128 for all Wan models)
|
||||
freqs = cat([
|
||||
rope_params(1024, d - 4*(d//6)), # temporal (dim=44, 22 freqs)
|
||||
rope_params(1024, 2*(d//6)), # height (dim=42, 21 freqs)
|
||||
rope_params(1024, 2*(d//6)), # width (dim=42, 21 freqs)
|
||||
])
|
||||
|
||||
A previous incorrect fix used a single rope_params(1024, 128) call, which
|
||||
gave height/width axes only medium/high frequencies instead of full-range.
|
||||
This destroyed spatial position encoding and caused grey/artifact output.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestRoPEFrequencyConstruction:
|
||||
"""Verify WanModel builds RoPE frequencies matching the reference."""
|
||||
|
||||
def _get_model_freqs(self, dim=64, num_heads=4):
|
||||
"""Instantiate a tiny WanModel and return its .freqs tensor."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = WanModelConfig()
|
||||
config.dim = dim
|
||||
config.ffn_dim = dim * 2
|
||||
config.num_heads = num_heads
|
||||
config.num_layers = 1
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.freq_dim = 32
|
||||
config.text_dim = 32
|
||||
config.text_len = 8
|
||||
model = WanModel(config)
|
||||
mx.eval(model.freqs)
|
||||
return model.freqs, dim // num_heads
|
||||
|
||||
def test_freqs_shape(self):
|
||||
"""Freqs should be [1024, head_dim//2, 2] regardless of construction."""
|
||||
freqs, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
assert freqs.shape == (1024, head_dim // 2, 2)
|
||||
|
||||
def test_three_call_vs_single_call_differ(self):
|
||||
"""Three separate rope_params calls must differ from single call."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128 # head_dim for all Wan models
|
||||
# Reference: three separate calls
|
||||
correct = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
# Wrong: single call
|
||||
wrong = rope_params(1024, d)
|
||||
mx.eval(correct, wrong)
|
||||
|
||||
assert correct.shape == wrong.shape
|
||||
diff = np.abs(np.array(correct) - np.array(wrong)).max()
|
||||
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
|
||||
|
||||
def test_each_axis_starts_at_frequency_one(self):
|
||||
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
|
||||
|
||||
This verifies each axis gets its own independent frequency range
|
||||
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
half_d = d // 2 # 64
|
||||
d_t = half_d - 2 * (half_d // 3) # 22
|
||||
d_h = half_d // 3 # 21
|
||||
|
||||
# At position 0, cos=1 and sin=0 for ALL frequency components
|
||||
np.testing.assert_allclose(f[0, :, 0], 1.0, atol=1e-6, err_msg="cos at pos 0")
|
||||
np.testing.assert_allclose(f[0, :, 1], 0.0, atol=1e-6, err_msg="sin at pos 0")
|
||||
|
||||
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
|
||||
# Temporal axis first freq
|
||||
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="temporal[0] cos at pos 1")
|
||||
# Height axis first freq (starts at index d_t)
|
||||
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="height[0] cos at pos 1")
|
||||
# Width axis first freq (starts at index d_t + d_h)
|
||||
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="width[0] cos at pos 1")
|
||||
|
||||
def test_height_width_frequencies_identical(self):
|
||||
"""Height and width axes should have identical frequency tables.
|
||||
|
||||
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
d_h_dim = 2 * (d // 6) # 42
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, d_h_dim),
|
||||
rope_params(1024, d_h_dim),
|
||||
], axis=1)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
half_d = d // 2
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
|
||||
height_freqs = f[:, d_t:d_t + d_h]
|
||||
width_freqs = f[:, d_t + d_h:]
|
||||
np.testing.assert_array_equal(height_freqs, width_freqs)
|
||||
|
||||
def test_frequency_range_per_axis(self):
|
||||
"""Each axis should span a full frequency range, not a slice of one range.
|
||||
|
||||
With three-call construction, the inverse frequency at index 0 of each
|
||||
axis should be 1.0 (theta^0). A single-call approach would give height
|
||||
starting at ~0.04 and width at ~0.002 instead of 1.0.
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
half_d = d // 2
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
|
||||
# At position 1, the first frequency component of each axis should
|
||||
# have significant magnitude (cos ≈ 0.54), not near-zero
|
||||
pos1_t = f[1, 0, 0] # temporal first freq
|
||||
pos1_h = f[1, d_t, 0] # height first freq
|
||||
pos1_w = f[1, d_t + d_h, 0] # width first freq
|
||||
|
||||
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
||||
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
|
||||
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
|
||||
|
||||
def test_model_freqs_match_manual_construction(self):
|
||||
"""WanModel.freqs should match manually constructed three-call freqs."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
d = head_dim # 16
|
||||
freqs_manual = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
mx.eval(freqs_model, freqs_manual)
|
||||
np.testing.assert_array_equal(
|
||||
np.array(freqs_model), np.array(freqs_manual),
|
||||
err_msg="WanModel.freqs should use three-call construction"
|
||||
)
|
||||
|
||||
def test_model_freqs_14b_dimensions(self):
|
||||
"""Verify freq dimensions for 14B-scale head_dim=128."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
], axis=1)
|
||||
mx.eval(freqs)
|
||||
|
||||
assert freqs.shape == (1024, 64, 2)
|
||||
# Verify the split dimensions used by rope_apply
|
||||
half_d = 64
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
assert (d_t, d_h, d_w) == (22, 21, 21)
|
||||
assert d_t + d_h + d_w == half_d
|
||||
|
||||
|
||||
class TestRoPEFrequencyMatchesReference:
|
||||
"""Cross-validate MLX RoPE against PyTorch reference implementation."""
|
||||
|
||||
@pytest.fixture
|
||||
def has_torch(self):
|
||||
try:
|
||||
import torch
|
||||
return True
|
||||
except ImportError:
|
||||
pytest.skip("PyTorch not installed")
|
||||
|
||||
def test_freqs_match_pytorch_reference(self, has_torch):
|
||||
"""Numerically compare MLX and PyTorch frequency tables."""
|
||||
import torch
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
|
||||
# PyTorch reference (from wan/modules/model.py)
|
||||
def pt_rope_params(max_seq_len, dim, theta=10000):
|
||||
freqs = torch.outer(
|
||||
torch.arange(max_seq_len),
|
||||
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
ref = torch.cat([
|
||||
pt_rope_params(1024, d - 4 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
], dim=1)
|
||||
|
||||
# MLX
|
||||
ours = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
mx.eval(ours)
|
||||
|
||||
our_cos = np.array(ours[:, :, 0])
|
||||
our_sin = np.array(ours[:, :, 1])
|
||||
ref_cos = ref.real.float().numpy()
|
||||
ref_sin = ref.imag.float().numpy()
|
||||
|
||||
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
|
||||
err_msg="cos mismatch vs PyTorch reference")
|
||||
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
|
||||
err_msg="sin mismatch vs PyTorch reference")
|
||||
|
||||
|
||||
class TestRoPEApplyWithCorrectFreqs:
|
||||
"""Test that rope_apply produces correct rotations with three-call freqs."""
|
||||
|
||||
def test_different_spatial_positions_get_different_rotations(self):
|
||||
"""Adjacent height/width positions must produce different RoPE rotations.
|
||||
|
||||
This is the key property that was broken by the single-call bug:
|
||||
height/width frequencies were too low to distinguish nearby positions.
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
|
||||
B, N = 1, 4
|
||||
F, H, W = 1, 4, 4
|
||||
L = F * H * W
|
||||
# Use a constant input so differences come purely from RoPE
|
||||
x = mx.ones((B, L, N, d))
|
||||
out = rope_apply(x, [(F, H, W)], freqs)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0])
|
||||
|
||||
# Position (0,0,0) vs (0,1,0) — different height
|
||||
pos_00 = out_np[0 * H * W + 0 * W + 0] # (f=0, h=0, w=0)
|
||||
pos_10 = out_np[0 * H * W + 1 * W + 0] # (f=0, h=1, w=0)
|
||||
height_diff = np.abs(pos_00 - pos_10).max()
|
||||
|
||||
# Position (0,0,0) vs (0,0,1) — different width
|
||||
pos_01 = out_np[0 * H * W + 0 * W + 1] # (f=0, h=0, w=1)
|
||||
width_diff = np.abs(pos_00 - pos_01).max()
|
||||
|
||||
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
|
||||
# and width was ~0.002. With correct freqs, both are ~1.3.
|
||||
assert height_diff > 0.5, (
|
||||
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
||||
)
|
||||
assert width_diff > 0.5, (
|
||||
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
||||
)
|
||||
# Height and width should have identical frequency tables → same diffs
|
||||
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
|
||||
err_msg="Height and width should use identical frequency tables")
|
||||
|
||||
def test_precomputed_matches_online(self):
|
||||
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
||||
from mlx_video.models.wan.rope import (
|
||||
rope_apply,
|
||||
rope_params,
|
||||
rope_precompute_cos_sin,
|
||||
)
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
|
||||
B, N = 2, 4
|
||||
F, H, W = 2, 3, 4
|
||||
L = F * H * W
|
||||
grids = [(F, H, W), (F, H, W)]
|
||||
|
||||
x = mx.random.normal((B, L, N, d))
|
||||
|
||||
# Online (no precomputed)
|
||||
out_online = rope_apply(x, grids, freqs)
|
||||
# Precomputed
|
||||
cos_sin = rope_precompute_cos_sin(grids, freqs)
|
||||
out_precomp = rope_apply(x, grids, freqs, precomputed_cos_sin=cos_sin)
|
||||
mx.eval(out_online, out_precomp)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_online), np.array(out_precomp), atol=1e-5,
|
||||
err_msg="Precomputed and online RoPE should match"
|
||||
)
|
||||
917
tests/test_wan_scheduler.py
Normal file
917
tests/test_wan_scheduler.py
Normal file
@@ -0,0 +1,917 @@
|
||||
"""Tests for Wan scheduler components."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Euler Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlowMatchEulerScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.timesteps is None
|
||||
assert sched.sigmas is None
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (40,)
|
||||
assert sched.sigmas.shape == (41,) # 40 steps + terminal
|
||||
|
||||
def test_timesteps_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps)
|
||||
ts = np.array(sched.timesteps)
|
||||
# Timesteps should be monotonically decreasing
|
||||
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
|
||||
|
||||
def test_sigmas_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
sigmas = np.array(sched.sigmas)
|
||||
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
|
||||
|
||||
def test_terminal_sigma_is_zero(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.sigmas)
|
||||
np.testing.assert_allclose(np.array(sched.sigmas[-1]), 0.0, atol=1e-6)
|
||||
|
||||
def test_shift_effect(self):
|
||||
"""Larger shift should push sigmas toward higher values."""
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched1 = FlowMatchEulerScheduler()
|
||||
sched2 = FlowMatchEulerScheduler()
|
||||
sched1.set_timesteps(20, shift=1.0)
|
||||
sched2.set_timesteps(20, shift=12.0)
|
||||
mx.eval(sched1.sigmas, sched2.sigmas)
|
||||
mean1 = np.mean(np.array(sched1.sigmas[:-1]))
|
||||
mean2 = np.mean(np.array(sched2.sigmas[:-1]))
|
||||
assert mean2 > mean1, "Higher shift should push sigmas higher"
|
||||
|
||||
def test_step_euler(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
|
||||
sample = mx.ones((1, 4, 2, 2, 2))
|
||||
velocity = mx.ones((1, 4, 2, 2, 2)) * 0.5
|
||||
timestep = sched.timesteps[0]
|
||||
|
||||
sigma = float(np.array(sched.sigmas[0]))
|
||||
sigma_next = float(np.array(sched.sigmas[1]))
|
||||
|
||||
result = sched.step(velocity, timestep, sample)
|
||||
mx.eval(result)
|
||||
|
||||
# Euler: x_next = x + (sigma_next - sigma) * v
|
||||
expected = 1.0 + (sigma_next - sigma) * 0.5
|
||||
np.testing.assert_allclose(
|
||||
np.array(result).flatten()[0], expected, rtol=1e-4,
|
||||
)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
assert sched._step_index == 0
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
vel = mx.zeros((1, 1, 1, 1, 1))
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
assert sched._step_index == 1
|
||||
sched.step(vel, sched.timesteps[1], sample)
|
||||
assert sched._step_index == 2
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
vel = mx.zeros((1, 1, 1, 1, 1))
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
assert sched._step_index == 1
|
||||
sched.reset()
|
||||
assert sched._step_index == 0
|
||||
|
||||
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(steps, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (steps,)
|
||||
assert sched.sigmas.shape == (steps + 1,)
|
||||
|
||||
def test_full_denoise_loop(self):
|
||||
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
for i in range(5):
|
||||
vel = mx.zeros_like(sample)
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
# With zero velocity, sample should remain unchanged
|
||||
np.testing.assert_allclose(np.array(sample), 1.0, atol=1e-5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared Sigma Schedule Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeSigmas:
|
||||
"""Tests for the shared _compute_sigmas helper."""
|
||||
|
||||
def test_length(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert len(sigmas) == 21 # num_steps + terminal
|
||||
|
||||
def test_terminal_zero(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
assert sigmas[-1] == 0.0
|
||||
|
||||
def test_starts_near_one(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
||||
|
||||
def test_decreasing(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert np.all(np.diff(sigmas) <= 0)
|
||||
|
||||
def test_matches_official_wan22(self):
|
||||
"""Sigma schedule should match the official Wan2.2 FlowUniPCMultistepScheduler.
|
||||
|
||||
The reference creates the scheduler with shift=1 (identity) in the
|
||||
constructor, then passes the actual shift to set_timesteps. This means
|
||||
sigma_max/sigma_min come from the *unshifted* training schedule, and the
|
||||
shift is applied only once (single-shift).
|
||||
"""
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
steps, shift, N = 50, 5.0, 1000
|
||||
sigmas = _compute_sigmas(steps, shift, N)
|
||||
# Official single-shift: unshifted bounds, then shift once
|
||||
alphas = np.linspace(1.0, 1.0 / N, N)[::-1]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # 0.999
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
official = np.linspace(sigma_max, sigma_min, steps + 1)[:-1]
|
||||
official = shift * official / (1.0 + (shift - 1.0) * official)
|
||||
official = np.append(official, 0.0).astype(np.float32)
|
||||
np.testing.assert_allclose(sigmas, official, atol=1e-6)
|
||||
|
||||
def test_shift_one_is_near_linear(self):
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
||||
# so schedule is nearly linear from ~0.999 to 0
|
||||
expected = np.linspace(1, 0, 11).astype(np.float32)
|
||||
np.testing.assert_allclose(sigmas, expected, atol=2e-3)
|
||||
|
||||
def test_all_schedulers_same_sigmas(self):
|
||||
"""All three schedulers should produce identical sigma schedules."""
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
FlowUniPCScheduler(1000),
|
||||
]
|
||||
for s in scheds:
|
||||
s.set_timesteps(20, shift=5.0)
|
||||
mx.eval(*[s.sigmas for s in scheds])
|
||||
ref = np.array(scheds[0].sigmas)
|
||||
for s in scheds[1:]:
|
||||
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
|
||||
|
||||
def test_all_schedulers_same_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
FlowUniPCScheduler(1000),
|
||||
]
|
||||
for s in scheds:
|
||||
s.set_timesteps(30, shift=12.0)
|
||||
mx.eval(*[s.timesteps for s in scheds])
|
||||
ref = np.array(scheds[0].timesteps)
|
||||
for s in scheds[1:]:
|
||||
np.testing.assert_allclose(np.array(s.timesteps), ref, atol=1e-3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DPM++ 2M Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlowDPMPP2MScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (20,)
|
||||
assert sched.sigmas.shape == (21,)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 4, 1, 2, 2))
|
||||
vel = mx.zeros_like(sample)
|
||||
assert sched._step_index == 0
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
assert sched._step_index == 1
|
||||
sched.step(vel, sched.timesteps[1], sample)
|
||||
assert sched._step_index == 2
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
sched.step(mx.zeros_like(sample), 0, sample)
|
||||
sched.reset()
|
||||
assert sched._step_index == 0
|
||||
assert sched._prev_x0 is None
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
"""Full loop with constant velocity should produce finite output."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
for i in range(10):
|
||||
vel = mx.ones_like(sample) * 0.1
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_first_step_is_first_order(self):
|
||||
"""First step should use 1st-order (no prev_x0 available)."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 2, 4, 4))
|
||||
vel = mx.random.normal(sample.shape)
|
||||
# Before first step, no prev_x0
|
||||
assert sched._prev_x0 is None
|
||||
result = sched.step(vel, sched.timesteps[0], sample)
|
||||
mx.eval(result)
|
||||
# After first step, prev_x0 should be set
|
||||
assert sched._prev_x0 is not None
|
||||
|
||||
def test_second_step_uses_correction(self):
|
||||
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
vel = mx.random.normal(sample.shape)
|
||||
# Step 1
|
||||
sample = sched.step(vel, sched.timesteps[0], sample)
|
||||
mx.eval(sample)
|
||||
x0_after_first = sched._prev_x0
|
||||
# Step 2
|
||||
vel = mx.random.normal(sample.shape)
|
||||
sample = sched.step(vel, sched.timesteps[1], sample)
|
||||
mx.eval(sample)
|
||||
# prev_x0 should have been updated
|
||||
x0_after_second = sched._prev_x0
|
||||
assert x0_after_second is not None
|
||||
# The stored x0 should differ from the first step's
|
||||
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
"""Perfect oracle should denoise to target with any solver."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
latents = mx.random.normal(target.shape)
|
||||
for i in range(20):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-6) # perfect velocity for target=0
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (steps,)
|
||||
assert sched.sigmas.shape == (steps + 1,)
|
||||
|
||||
def test_terminal_sigma_produces_x0(self):
|
||||
"""When sigma_next=0 the scheduler should return x0 directly."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
|
||||
vel = mx.ones_like(sample) * 2.0
|
||||
# Run through all steps; the last step has sigma_next=0
|
||||
for i in range(5):
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
# Final value should be finite
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UniPC Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlowUniPCScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.solver_order == 2
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(30, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (30,)
|
||||
assert sched.sigmas.shape == (31,)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
vel = mx.zeros_like(sample)
|
||||
assert sched._step_index == 0
|
||||
sched.step(vel, 0, sample)
|
||||
assert sched._step_index == 1
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
sched.step(mx.zeros_like(sample), 0, sample)
|
||||
sched.reset()
|
||||
assert sched._step_index == 0
|
||||
assert sched._lower_order_nums == 0
|
||||
assert sched._last_sample is None
|
||||
assert all(m is None for m in sched._model_outputs)
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
for i in range(10):
|
||||
vel = mx.ones_like(sample) * 0.1
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_corrector_not_applied_first_step(self):
|
||||
"""First step should skip the corrector (no history)."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
vel = mx.random.normal(sample.shape)
|
||||
# Before step 0: no last_sample
|
||||
assert sched._last_sample is None
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
# After step 0: last_sample should be set for corrector on step 1
|
||||
assert sched._last_sample is not None
|
||||
|
||||
def test_corrector_applied_after_first_step(self):
|
||||
"""Steps after the first should use the corrector when enabled."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 4, 4))
|
||||
for i in range(3):
|
||||
vel = mx.random.normal(sample.shape)
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
# lower_order_nums should have increased
|
||||
assert sched._lower_order_nums >= 2
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
latents = mx.random.normal(target.shape)
|
||||
for i in range(20):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-6)
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (steps,)
|
||||
assert sched.sigmas.shape == (steps + 1,)
|
||||
|
||||
def test_disable_corrector(self):
|
||||
"""Disabling corrector on step 0 should still work without error."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 2, 2))
|
||||
for i in range(5):
|
||||
vel = mx.ones_like(sample) * 0.1
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_solver_order_3(self):
|
||||
"""Order 3 should work without error."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 2, 2))
|
||||
for i in range(10):
|
||||
vel = mx.random.normal(sample.shape)
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_corrector_rhos_c_not_hardcoded(self):
|
||||
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
|
||||
import math
|
||||
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
|
||||
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
||||
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
||||
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(50, shift=5.0)
|
||||
|
||||
def _lambda(sigma):
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log(1 - sigma) - math.log(sigma)
|
||||
|
||||
for step_idx in [5, 10, 25, 45]:
|
||||
sigma_s0 = sigmas[step_idx - 1]
|
||||
sigma_t = sigmas[step_idx]
|
||||
lambda_s0 = _lambda(sigma_s0)
|
||||
lambda_t = _lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h
|
||||
|
||||
sigma_sk = sigmas[step_idx - 2]
|
||||
lambda_sk = _lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
rks = np.array([rk, 1.0])
|
||||
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows, b_vals = [], []
|
||||
for j in range(1, 3):
|
||||
R_rows.append(rks ** (j - 1))
|
||||
b_vals.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_c = np.linalg.solve(R, b)
|
||||
|
||||
# History weight should be small (~0.07-0.09), not 0.5
|
||||
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||
# D1_t weight should be ~0.42-0.45, not 0.5
|
||||
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler Coherence Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSchedulerCoherence:
|
||||
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
|
||||
|
||||
All three schedulers should agree on shared structure (sigma schedules,
|
||||
first-step behavior) and converge to the same result given perfect
|
||||
velocity oracles, even though they use different update rules.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_schedulers(steps=10, shift=5.0):
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = {
|
||||
"euler": FlowMatchEulerScheduler(),
|
||||
"dpm++": FlowDPMPP2MScheduler(),
|
||||
"unipc": FlowUniPCScheduler(),
|
||||
}
|
||||
for s in scheds.values():
|
||||
s.set_timesteps(steps, shift=shift)
|
||||
return scheds
|
||||
|
||||
def test_identical_sigma_schedules(self):
|
||||
"""All schedulers must use the same sigma schedule."""
|
||||
scheds = self._make_schedulers(20, shift=5.0)
|
||||
ref = np.array(scheds["euler"].sigmas)
|
||||
for name in ("dpm++", "unipc"):
|
||||
np.testing.assert_allclose(
|
||||
np.array(scheds[name].sigmas),
|
||||
ref,
|
||||
atol=1e-6,
|
||||
err_msg=f"{name} sigma schedule differs from Euler",
|
||||
)
|
||||
|
||||
def test_identical_timesteps(self):
|
||||
"""All schedulers must produce the same timestep sequence."""
|
||||
scheds = self._make_schedulers(20, shift=5.0)
|
||||
ref = np.array(scheds["euler"].timesteps)
|
||||
for name in ("dpm++", "unipc"):
|
||||
np.testing.assert_allclose(
|
||||
np.array(scheds[name].timesteps),
|
||||
ref,
|
||||
atol=1e-6,
|
||||
err_msg=f"{name} timesteps differ from Euler",
|
||||
)
|
||||
|
||||
def test_first_step_matches_euler(self):
|
||||
"""Step 0 (1st-order for all solvers) should match Euler exactly."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
scheds = self._make_schedulers(10, shift=5.0)
|
||||
results = {}
|
||||
for name, sched in scheds.items():
|
||||
r = sched.step(vel, sched.timesteps[0], noise)
|
||||
mx.eval(r)
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"], results["euler"], atol=1e-5,
|
||||
err_msg="DPM++ step 0 should match Euler",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"], results["euler"], atol=1e-5,
|
||||
err_msg="UniPC step 0 should match Euler",
|
||||
)
|
||||
|
||||
def test_first_step_matches_across_shifts(self):
|
||||
"""Step 0 should match Euler for different shift values."""
|
||||
mx.random.seed(99)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
for shift in (1.0, 5.0, 12.0):
|
||||
scheds = self._make_schedulers(10, shift=shift)
|
||||
euler_r = scheds["euler"].step(vel, scheds["euler"].timesteps[0], noise)
|
||||
dpm_r = scheds["dpm++"].step(vel, scheds["dpm++"].timesteps[0], noise)
|
||||
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
|
||||
mx.eval(euler_r, dpm_r, unipc_r)
|
||||
np.testing.assert_allclose(
|
||||
np.array(dpm_r), np.array(euler_r), atol=1e-5,
|
||||
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
np.array(unipc_r), np.array(euler_r), atol=1e-5,
|
||||
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
|
||||
def test_oracle_all_converge_to_target(self):
|
||||
"""Given a perfect velocity oracle v=x/sigma, all solvers should
|
||||
denoise to approximately zero (the target)."""
|
||||
mx.random.seed(7)
|
||||
shape = (1, 2, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
for name, sched in self._make_schedulers(20, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(20):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-8)
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(
|
||||
np.array(latents), 0.0, atol=1e-3,
|
||||
err_msg=f"{name} did not converge to target with oracle",
|
||||
)
|
||||
|
||||
def test_oracle_higher_order_closer_to_target(self):
|
||||
"""With few steps and a perfect oracle, higher-order solvers should
|
||||
be at least as accurate as Euler."""
|
||||
mx.random.seed(12)
|
||||
shape = (1, 2, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
steps = 5
|
||||
|
||||
errors = {}
|
||||
for name, sched in self._make_schedulers(steps, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(steps):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-8)
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
errors[name] = float(mx.mean(mx.abs(latents)).item())
|
||||
|
||||
# Higher-order solvers should not be significantly worse than Euler
|
||||
# (add small epsilon to handle near-zero errors from floating point noise)
|
||||
eps = 1e-6
|
||||
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
|
||||
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
)
|
||||
|
||||
def test_multistep_trajectory_similar_magnitude(self):
|
||||
"""Over a full denoising loop with constant velocity, all solvers
|
||||
should produce outputs of similar magnitude (not diverging)."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
steps = 20
|
||||
|
||||
final_means = {}
|
||||
for name, sched in self._make_schedulers(steps, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(steps):
|
||||
vel = latents * 0.1
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
final_means[name] = float(mx.mean(mx.abs(latents)).item())
|
||||
|
||||
# All solvers should produce results within the same order of magnitude
|
||||
vals = list(final_means.values())
|
||||
ratio = max(vals) / max(min(vals), 1e-10)
|
||||
assert ratio < 10.0, (
|
||||
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||
)
|
||||
|
||||
def test_intermediate_values_finite(self):
|
||||
"""Every intermediate latent value must be finite for all solvers."""
|
||||
mx.random.seed(0)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
for name, sched in self._make_schedulers(15, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(15):
|
||||
vel = mx.random.normal(shape)
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
assert np.isfinite(np.array(latents)).all(), (
|
||||
f"{name} produced non-finite values at step {i}"
|
||||
)
|
||||
|
||||
def test_lambda_boundary_values(self):
|
||||
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
assert cls._lambda(1.0) == -math.inf, (
|
||||
f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||
)
|
||||
assert cls._lambda(0.0) == math.inf, (
|
||||
f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||
)
|
||||
# Interior values should be finite
|
||||
lam = cls._lambda(0.5)
|
||||
assert math.isfinite(lam) and lam == 0.0, (
|
||||
f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||
)
|
||||
|
||||
def test_lambda_monotonically_decreasing(self):
|
||||
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
|
||||
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
|
||||
for i in range(len(lambdas) - 1):
|
||||
assert lambdas[i] > lambdas[i + 1], (
|
||||
f"_lambda not decreasing: _lambda({sigmas[i]})={lambdas[i]} "
|
||||
f"vs _lambda({sigmas[i+1]})={lambdas[i+1]}"
|
||||
)
|
||||
|
||||
def test_step0_is_ddim_formula(self):
|
||||
"""At sigma=1.0, the DPM++/UniPC first step should reduce to the
|
||||
DDIM formula: x_next = sigma_next * x + (1 - sigma_next) * x0."""
|
||||
mx.random.seed(55)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
sample = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
for steps, shift in [(10, 5.0), (20, 12.0)]:
|
||||
scheds = self._make_schedulers(steps, shift=shift)
|
||||
sigma_next = float(scheds["euler"].sigmas[1].item())
|
||||
sigma_cur = float(scheds["euler"].sigmas[0].item())
|
||||
assert abs(sigma_cur - 1.0) < 1e-3, "First sigma should be ~1.0"
|
||||
|
||||
x0 = sample - sigma_cur * vel
|
||||
expected = sigma_next * sample + (1.0 - sigma_next) * x0
|
||||
mx.eval(expected)
|
||||
|
||||
for name in ("dpm++", "unipc"):
|
||||
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
||||
mx.eval(result)
|
||||
np.testing.assert_allclose(
|
||||
np.array(result), np.array(expected), atol=5e-4,
|
||||
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_coherent_across_step_counts(self, steps):
|
||||
"""All solvers should agree on step 0 regardless of total step count."""
|
||||
mx.random.seed(77)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
scheds = self._make_schedulers(steps, shift=5.0)
|
||||
results = {}
|
||||
for name, sched in scheds.items():
|
||||
r = sched.step(vel, sched.timesteps[0], noise)
|
||||
mx.eval(r)
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"], results["euler"], atol=1e-5,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"], results["euler"], atol=1e-5,
|
||||
)
|
||||
|
||||
def test_dpmpp_unipc_agree_on_step1(self):
|
||||
"""After warmup, DPM++ and UniPC step 1 should be similar
|
||||
(both use 2nd-order corrections based on the same model outputs)."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
scheds = self._make_schedulers(10, shift=5.0)
|
||||
# Run step 0 with same velocity
|
||||
vel0 = mx.random.normal(shape)
|
||||
for sched in scheds.values():
|
||||
sched.step(vel0, sched.timesteps[0], noise)
|
||||
|
||||
# Run step 1 from same sample with same velocity
|
||||
sample1 = scheds["euler"].step(vel0, scheds["euler"].timesteps[0], noise)
|
||||
mx.eval(sample1)
|
||||
vel1 = mx.random.normal(shape)
|
||||
|
||||
r_dpm = scheds["dpm++"].step(vel1, scheds["dpm++"].timesteps[1], sample1)
|
||||
r_unipc = scheds["unipc"].step(vel1, scheds["unipc"].timesteps[1], sample1)
|
||||
mx.eval(r_dpm, r_unipc)
|
||||
|
||||
# They won't be identical (different correction formulas) but should
|
||||
# be in the same ballpark (within 50% of each other's magnitude)
|
||||
mean_dpm = float(mx.mean(mx.abs(r_dpm)).item())
|
||||
mean_unipc = float(mx.mean(mx.abs(r_unipc)).item())
|
||||
ratio = max(mean_dpm, mean_unipc) / max(min(mean_dpm, mean_unipc), 1e-10)
|
||||
assert ratio < 2.0, (
|
||||
f"DPM++ and UniPC step 1 differ too much: "
|
||||
f"DPM++={mean_dpm:.4f}, UniPC={mean_unipc:.4f}"
|
||||
)
|
||||
|
||||
def test_reset_makes_solvers_reproducible(self):
|
||||
"""After reset(), running the same loop should produce identical output."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
sched = cls()
|
||||
sched.set_timesteps(5, shift=5.0)
|
||||
|
||||
# First run
|
||||
latents = noise
|
||||
for i in range(5):
|
||||
vel = latents * 0.1
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
result1 = np.array(latents)
|
||||
|
||||
# Reset and run again
|
||||
sched.reset()
|
||||
latents = noise
|
||||
for i in range(5):
|
||||
vel = latents * 0.1
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
result2 = np.array(latents)
|
||||
|
||||
np.testing.assert_allclose(result1, result2, atol=1e-5,
|
||||
err_msg=f"{cls.__name__} not reproducible after reset()")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UniPC Corrector Default Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUniPCCorrectorDefault:
|
||||
"""Tests that the UniPC corrector is enabled by default,
|
||||
matching official FlowUniPCMultistepScheduler behavior."""
|
||||
|
||||
def test_corrector_enabled_by_default(self):
|
||||
"""Default construction should have corrector enabled."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched._use_corrector is True
|
||||
|
||||
def test_corrector_affects_output(self):
|
||||
"""Corrector should produce different results than no corrector after step 1."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
sched_corr = FlowUniPCScheduler(use_corrector=True)
|
||||
sched_corr.set_timesteps(10, shift=5.0)
|
||||
sched_no = FlowUniPCScheduler(use_corrector=False)
|
||||
sched_no.set_timesteps(10, shift=5.0)
|
||||
|
||||
latent_corr = noise
|
||||
latent_no = noise
|
||||
for i in range(3):
|
||||
vel = mx.random.normal(shape) * 0.1
|
||||
latent_corr = sched_corr.step(vel, sched_corr.timesteps[i], latent_corr)
|
||||
latent_no = sched_no.step(vel, sched_no.timesteps[i], latent_no)
|
||||
mx.eval(latent_corr, latent_no)
|
||||
|
||||
diff = float(mx.abs(latent_corr - latent_no).max())
|
||||
assert diff > 1e-6, f"Corrector had no effect (max diff={diff})"
|
||||
|
||||
def test_corrector_does_not_affect_first_step(self):
|
||||
"""Step 0 should be identical regardless of corrector setting."""
|
||||
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
sched_corr = FlowUniPCScheduler(use_corrector=True)
|
||||
sched_corr.set_timesteps(10, shift=5.0)
|
||||
sched_no = FlowUniPCScheduler(use_corrector=False)
|
||||
sched_no.set_timesteps(10, shift=5.0)
|
||||
|
||||
r1 = sched_corr.step(vel, sched_corr.timesteps[0], noise)
|
||||
r2 = sched_no.step(vel, sched_no.timesteps[0], noise)
|
||||
mx.eval(r1, r2)
|
||||
np.testing.assert_allclose(np.array(r1), np.array(r2), atol=1e-6)
|
||||
173
tests/test_wan_t5.py
Normal file
173
tests/test_wan_t5.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Tests for T5 encoder components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T5 Encoder Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestT5LayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||
norm = T5LayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_rms_normalization(self):
|
||||
"""After T5LayerNorm with weight=1, RMS should be ~1."""
|
||||
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||
norm = T5LayerNorm(128)
|
||||
x = mx.random.normal((1, 5, 128)) * 5.0
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0])
|
||||
for i in range(5):
|
||||
rms = np.sqrt(np.mean(out_np[i] ** 2))
|
||||
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
|
||||
|
||||
|
||||
class TestT5RelativeEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(10, 10)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
|
||||
|
||||
def test_asymmetric_lengths(self):
|
||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(8, 12)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 8, 12)
|
||||
|
||||
def test_symmetry(self):
|
||||
"""Position bias should have structure (not all zeros/random)."""
|
||||
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
|
||||
out = rel_emb(6, 6)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0]) # [N, lq, lk]
|
||||
# Diagonal elements (position i attending to position i) should be consistent
|
||||
# (same relative distance = 0 for all diagonal elements)
|
||||
for h in range(2):
|
||||
diag = np.diag(out_np[h])
|
||||
np.testing.assert_allclose(diag, diag[0], atol=1e-5)
|
||||
|
||||
|
||||
class TestT5Attention:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Attention
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
out = attn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
def test_no_scaling(self):
|
||||
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
|
||||
from mlx_video.models.wan.text_encoder import T5Attention
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
# No scale attribute (unlike standard attention)
|
||||
assert not hasattr(attn, "scale")
|
||||
|
||||
def test_with_position_bias(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
rel_emb = T5RelativeEmbedding(32, 4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
pos_bias = rel_emb(10, 10)
|
||||
out = attn(x, pos_bias=pos_bias)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
def test_with_mask(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Attention
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
mask = mx.ones((1, 10))
|
||||
mask = mx.concatenate([mask[:, :7], mx.zeros((1, 3))], axis=1)
|
||||
out = attn(x, mask=mask)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
|
||||
class TestT5FeedForward:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||
ffn = T5FeedForward(64, 256)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
out = ffn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
def test_gated_structure(self):
|
||||
"""T5 FFN is gated: gate(x) * fc1(x)."""
|
||||
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||
ffn = T5FeedForward(32, 64)
|
||||
assert hasattr(ffn, "gate_proj")
|
||||
assert hasattr(ffn, "fc1")
|
||||
assert hasattr(ffn, "fc2")
|
||||
|
||||
|
||||
class TestT5Encoder:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
)
|
||||
ids = mx.array([[1, 5, 10, 0, 0]])
|
||||
mask = mx.array([[1, 1, 1, 0, 0]])
|
||||
out = encoder(ids, mask=mask)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 5, 64)
|
||||
|
||||
def test_shared_pos(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
|
||||
)
|
||||
assert encoder.pos_embedding is not None
|
||||
for block in encoder.blocks:
|
||||
assert block.pos_embedding is None
|
||||
|
||||
def test_per_layer_pos(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
)
|
||||
assert encoder.pos_embedding is None
|
||||
for block in encoder.blocks:
|
||||
assert block.pos_embedding is not None
|
||||
|
||||
def test_param_count(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
)
|
||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
|
||||
assert num_params > 0
|
||||
|
||||
def test_without_mask(self):
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||
)
|
||||
ids = mx.array([[1, 5, 10]])
|
||||
out = encoder(ids)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 3, 64)
|
||||
198
tests/test_wan_tiling.py
Normal file
198
tests/test_wan_tiling.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Tests for Wan VAE tiled decoding."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
decode_with_tiling,
|
||||
split_in_spatial,
|
||||
split_in_temporal,
|
||||
)
|
||||
|
||||
|
||||
class TestNonCausalTemporal:
|
||||
"""Tests for the causal_temporal=False path in decode_with_tiling."""
|
||||
|
||||
def test_split_spatial_for_temporal(self):
|
||||
"""Non-causal temporal should use split_in_spatial (no causal shift)."""
|
||||
intervals = split_in_spatial(8, 2, 20)
|
||||
# No causal adjustment: starts should be evenly spaced
|
||||
assert intervals.starts[0] == 0
|
||||
for i in range(1, len(intervals.starts)):
|
||||
assert intervals.starts[i] == intervals.starts[i - 1] + (8 - 2)
|
||||
|
||||
def test_causal_vs_noncausal_output_size(self):
|
||||
"""Causal temporal gives 1+(T-1)*S frames, non-causal gives T*S."""
|
||||
mx.random.seed(42)
|
||||
b, c, t, h, w = 1, 4, 4, 4, 4
|
||||
latents = mx.random.normal((b, c, t, h, w))
|
||||
scale = 4
|
||||
|
||||
# Simple passthrough decoder: just repeat along dimensions
|
||||
def dummy_decoder_causal(x, **kwargs):
|
||||
b, c, t, h, w = x.shape
|
||||
out_t = 1 + (t - 1) * scale
|
||||
out_h = h * scale
|
||||
out_w = w * scale
|
||||
return mx.ones((b, 3, out_t, out_h, out_w))
|
||||
|
||||
def dummy_decoder_noncausal(x, **kwargs):
|
||||
b, c, t, h, w = x.shape
|
||||
out_t = t * scale
|
||||
out_h = h * scale
|
||||
out_w = w * scale
|
||||
return mx.ones((b, 3, out_t, out_h, out_w))
|
||||
|
||||
config = TilingConfig.spatial_only(tile_size=128, overlap=64)
|
||||
|
||||
# Causal: 1 + (4-1)*4 = 13
|
||||
out_causal = decode_with_tiling(
|
||||
dummy_decoder_causal, latents, config,
|
||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
|
||||
)
|
||||
mx.eval(out_causal)
|
||||
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
|
||||
|
||||
# Non-causal: 4*4 = 16
|
||||
out_noncausal = decode_with_tiling(
|
||||
dummy_decoder_noncausal, latents, config,
|
||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
|
||||
)
|
||||
mx.eval(out_noncausal)
|
||||
assert out_noncausal.shape[2] == t * scale # 16
|
||||
|
||||
|
||||
class TestWan22TiledDecoding:
|
||||
"""Tests for Wan2.2 VAE tiled decoding."""
|
||||
|
||||
def _make_small_wan22_decoder(self):
|
||||
"""Create a small Wan2.2 decoder for testing."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
|
||||
# Use very small dimensions for fast testing
|
||||
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
def test_decode_tiled_output_shape(self):
|
||||
"""Tiled decode should produce same shape as non-tiled."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan22_decoder()
|
||||
|
||||
# Small input: [B=1, T=3, H=2, W=2, C=48]
|
||||
z = mx.random.normal((1, 3, 2, 2, 48))
|
||||
mx.eval(z)
|
||||
|
||||
# Non-tiled
|
||||
out_regular = vae(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
# Tiled (force tiling with very small tile sizes)
|
||||
# Use spatial tile=32px (2 latent @ scale 16) and temporal=8 frames (2 latent @ scale 4)
|
||||
config = TilingConfig(
|
||||
spatial_config=None, # Don't tile spatially (input is tiny)
|
||||
temporal_config=None, # Don't tile temporally (input is tiny)
|
||||
)
|
||||
# With no tiling config, decode_tiled should fall through to regular decode
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
# Both should produce the same shape
|
||||
assert out_regular.shape == out_tiled.shape, (
|
||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
)
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan22_decoder()
|
||||
|
||||
# Input smaller than any tile size
|
||||
z = mx.random.normal((1, 2, 2, 2, 48))
|
||||
mx.eval(z)
|
||||
|
||||
out_regular = vae(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular), np.array(out_tiled),
|
||||
rtol=1e-4, atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
|
||||
class TestWan21TiledDecoding:
|
||||
"""Tests for Wan2.1 VAE tiled decoding."""
|
||||
|
||||
def _make_small_wan21_vae(self):
|
||||
"""Create a small Wan2.1 VAE for testing."""
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
def test_decode_tiled_output_shape(self):
|
||||
"""Tiled decode should produce correct output shape."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan21_vae()
|
||||
|
||||
# [B=1, C=16, T=3, H=4, W=4]
|
||||
z = mx.random.normal((1, 16, 3, 4, 4))
|
||||
mx.eval(z)
|
||||
|
||||
out_regular = vae.decode(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
assert out_regular.shape == out_tiled.shape, (
|
||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
)
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan21_vae()
|
||||
|
||||
z = mx.random.normal((1, 16, 2, 4, 4))
|
||||
mx.eval(z)
|
||||
|
||||
out_regular = vae.decode(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular), np.array(out_tiled),
|
||||
rtol=1e-4, atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
|
||||
class TestWan21TemporalScale:
|
||||
"""Verify Wan2.1 decoder temporal output is T*4 (non-causal)."""
|
||||
|
||||
def test_wan21_decoder_temporal_output(self):
|
||||
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
|
||||
from mlx_video.models.wan.vae import Decoder3d
|
||||
|
||||
# Small decoder for fast test
|
||||
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
|
||||
temporal_upsample=[True, True, False])
|
||||
mx.eval(dec.parameters())
|
||||
|
||||
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
|
||||
mx.eval(x)
|
||||
out = dec(x)
|
||||
mx.eval(out)
|
||||
|
||||
# With two temporal 2× upsamples: T=3 → 6 → 12
|
||||
assert out.shape[2] == 3 * 4, f"Expected T=12, got T={out.shape[2]}"
|
||||
160
tests/test_wan_transformer.py
Normal file
160
tests/test_wan_transformer.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Tests for Wan transformer block components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transformer Block Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWanFFN:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.transformer import WanFFN
|
||||
ffn = WanFFN(64, 256)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = ffn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_gelu_activation(self):
|
||||
"""FFN should use GELU activation (non-linearity)."""
|
||||
from mlx_video.models.wan.transformer import WanFFN
|
||||
ffn = WanFFN(32, 128)
|
||||
x = mx.ones((1, 1, 32)) * 2.0
|
||||
out1 = ffn(x)
|
||||
x2 = mx.ones((1, 1, 32)) * 4.0
|
||||
out2 = ffn(x2)
|
||||
mx.eval(out1, out2)
|
||||
# Non-linear: 2x input should not give 2x output
|
||||
assert not np.allclose(np.array(out2), np.array(out1) * 2.0, rtol=0.1)
|
||||
|
||||
|
||||
class TestWanAttentionBlock:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.ffn_dim = 128
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
block = WanAttentionBlock(
|
||||
self.dim, self.ffn_dim, self.num_heads,
|
||||
cross_attn_norm=True,
|
||||
)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
e = mx.random.normal((B, L, 6, self.dim))
|
||||
context = mx.random.normal((B, 16, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
out = block(
|
||||
x, e, seq_lens=[L], grid_sizes=[(F, H, W)],
|
||||
freqs=freqs, context=context,
|
||||
)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, self.dim)
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
assert block.modulation.shape == (1, 6, self.dim)
|
||||
|
||||
def test_with_cross_attn_norm(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
block = WanAttentionBlock(
|
||||
self.dim, self.ffn_dim, self.num_heads,
|
||||
cross_attn_norm=True,
|
||||
)
|
||||
assert block.norm3 is not None
|
||||
|
||||
def test_without_cross_attn_norm(self):
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
block = WanAttentionBlock(
|
||||
self.dim, self.ffn_dim, self.num_heads,
|
||||
cross_attn_norm=False,
|
||||
)
|
||||
assert block.norm3 is None
|
||||
|
||||
def test_residual_connection(self):
|
||||
"""Output should differ from zero even with small random init."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
B, L = 1, 8
|
||||
F, H, W = 2, 2, 2
|
||||
x = mx.ones((B, L, self.dim))
|
||||
e = mx.zeros((B, L, 6, self.dim))
|
||||
context = mx.random.normal((B, 4, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
out = block(x, e, [L], [(F, H, W)], freqs, context)
|
||||
mx.eval(out)
|
||||
# With residual connections, output should be close to input + corrections
|
||||
assert not np.allclose(np.array(out), 0.0, atol=1e-3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Float32 Modulation Precision Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFloat32Modulation:
|
||||
"""Tests that modulation/gate operations are computed in float32,
|
||||
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
|
||||
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
|
||||
def test_block_modulation_in_float32(self):
|
||||
"""Modulation param starts random but should be usable as float32."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
|
||||
assert block.modulation.dtype == mx.float32
|
||||
|
||||
def test_block_output_float32_with_bf16_modulation_input(self):
|
||||
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
block = WanAttentionBlock(self.dim, 128, 4)
|
||||
B, L = 1, 8
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
e = mx.random.normal((B, L, 6, self.dim)).astype(mx.bfloat16)
|
||||
ctx = mx.random.normal((B, 4, self.dim))
|
||||
freqs = rope_params(1024, self.dim // 4)
|
||||
|
||||
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
|
||||
mx.eval(out)
|
||||
assert out.dtype == mx.float32
|
||||
assert np.isfinite(np.array(out)).all()
|
||||
|
||||
def test_head_modulation_float32(self):
|
||||
"""Head modulation should be float32 even with bf16 e input."""
|
||||
from mlx_video.models.wan.model import Head
|
||||
head = Head(self.dim, 4, (1, 2, 2))
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
|
||||
out = head(x, e)
|
||||
mx.eval(out)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_model_time_embedding_float32(self):
|
||||
"""sinusoidal_embedding_1d output must be float32."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
t = mx.array([500.0])
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
mx.eval(emb)
|
||||
assert emb.dtype == mx.float32
|
||||
|
||||
def test_model_per_token_time_embedding_float32(self):
|
||||
"""Per-token time embeddings (I2V) should also be float32."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
mx.eval(emb)
|
||||
assert emb.dtype == mx.float32
|
||||
assert emb.shape == (1, 4, 256)
|
||||
951
tests/test_wan_vae.py
Normal file
951
tests/test_wan_vae.py
Normal file
@@ -0,0 +1,951 @@
|
||||
"""Tests for Wan VAE 2.1 and 2.2 components."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VAE 2.1 Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCausalConv3d:
|
||||
def test_output_shape_stride1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||
# Initialize weights
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
x = mx.random.normal((1, 4, 3, 8, 8)) # [B, C, T, H, W]
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
# With causal padding and padding=1 on spatial, dims should be preserved
|
||||
assert out.shape[0] == 1
|
||||
assert out.shape[1] == 8 # out_channels
|
||||
assert out.shape[2] == 3 # T preserved
|
||||
assert out.shape[3] == 8 # H preserved
|
||||
assert out.shape[4] == 8 # W preserved
|
||||
|
||||
def test_output_shape_kernel1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
x = mx.random.normal((1, 4, 2, 4, 4))
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_causal_padding(self):
|
||||
"""Causal conv should only use past/current frames, not future."""
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros((2,))
|
||||
# Create input where only the first frame has signal
|
||||
x = mx.zeros((1, 2, 4, 4, 4))
|
||||
x_np = np.zeros((1, 2, 4, 4, 4), dtype=np.float32)
|
||||
x_np[:, :, 0, :, :] = 1.0
|
||||
x = mx.array(x_np)
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
# Due to causal padding, the output at t=0 should only depend on t=0
|
||||
|
||||
|
||||
class TestResidualBlock:
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_shortcut_exists_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_when_dims_same(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
|
||||
class TestAttentionBlock:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||
out = block(x)
|
||||
mx.eval(x, out)
|
||||
# Residual: output should not be zero even with random init
|
||||
assert np.abs(np.array(out)).max() > 0
|
||||
|
||||
|
||||
class TestWanVAE:
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
vae = WanVAE(z_dim=16)
|
||||
assert vae.z_dim == 16
|
||||
assert vae.mean.shape == (16,)
|
||||
assert vae.std.shape == (16,)
|
||||
|
||||
def test_normalization_stats(self):
|
||||
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
|
||||
assert len(VAE_MEAN) == 16
|
||||
assert len(VAE_STD) == 16
|
||||
assert all(s > 0 for s in VAE_STD)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.2 VAE Component Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVAE22CausalConv3d:
|
||||
"""Tests for vae22.CausalConv3d (channels-last)."""
|
||||
|
||||
def test_output_shape_k3(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 8, 8, 16)
|
||||
|
||||
def test_output_shape_k1(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_temporal_causal(self):
|
||||
"""Output at t=0 should not depend on t>0."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros(conv.bias.shape)
|
||||
|
||||
x = mx.zeros((1, 4, 4, 4, 2))
|
||||
out_zero = conv(x)
|
||||
mx.eval(out_zero)
|
||||
t0_ref = np.array(out_zero[0, 0])
|
||||
|
||||
# Modify t=2..3; output at t=0 should be unchanged
|
||||
x_mod = mx.concatenate([
|
||||
x[:, :2],
|
||||
mx.ones((1, 2, 4, 4, 2)),
|
||||
], axis=1)
|
||||
out_mod = conv(x_mod)
|
||||
mx.eval(out_mod)
|
||||
t0_mod = np.array(out_mod[0, 0])
|
||||
np.testing.assert_allclose(t0_ref, t0_mod, atol=1e-5)
|
||||
|
||||
def test_channels_last_format(self):
|
||||
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape[-1] == 8 # last dim = out_channels
|
||||
|
||||
|
||||
class TestRMSNorm:
|
||||
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
norm = RMS_norm(16)
|
||||
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == x.shape
|
||||
|
||||
def test_l2_normalization(self):
|
||||
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
dim = 32
|
||||
norm = RMS_norm(dim)
|
||||
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# After L2 norm * scale(=sqrt(dim)) * gamma(=1): ||out|| = sqrt(dim)
|
||||
out_np = np.array(out).flatten()
|
||||
l2 = np.linalg.norm(out_np)
|
||||
np.testing.assert_allclose(l2, math.sqrt(dim), rtol=1e-3)
|
||||
|
||||
def test_scale_invariant(self):
|
||||
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
norm = RMS_norm(8)
|
||||
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||
out1 = norm(x)
|
||||
out2 = norm(x * 10.0)
|
||||
mx.eval(out1, out2)
|
||||
np.testing.assert_allclose(np.array(out1), np.array(out2), atol=1e-4)
|
||||
|
||||
def test_gamma_effect(self):
|
||||
"""Non-unit gamma should scale output."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
norm = RMS_norm(4)
|
||||
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||
x = mx.ones((1, 1, 1, 1, 4))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# With gamma=2, each component is 2 * sqrt(4) * x/||x|| = 2 * 2 * 1/2 = 2
|
||||
np.testing.assert_allclose(np.array(out).flatten(), 2.0, atol=1e-4)
|
||||
|
||||
|
||||
class TestDupUp3D:
|
||||
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out = up(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 3, 8, 8, 4)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||
out = up(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 6, 8, 8, 8)
|
||||
|
||||
def test_first_chunk_trims(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
out_trimmed = up(x, first_chunk=True)
|
||||
mx.eval(out_normal, out_trimmed)
|
||||
# first_chunk removes factor_t-1=1 temporal frame
|
||||
assert out_normal.shape[1] == 6
|
||||
assert out_trimmed.shape[1] == 5
|
||||
|
||||
def test_no_temporal_first_chunk_noop(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
out_trimmed = up(x, first_chunk=True)
|
||||
mx.eval(out_normal, out_trimmed)
|
||||
# factor_t=1, so first_chunk removes 0 frames
|
||||
assert out_normal.shape == out_trimmed.shape
|
||||
|
||||
|
||||
class TestVAE22Resample:
|
||||
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
|
||||
|
||||
def test_upsample2d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample2d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
|
||||
|
||||
def test_upsample3d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
|
||||
|
||||
def test_upsample3d_first_chunk(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = r(x, first_chunk=True)
|
||||
mx.eval(out)
|
||||
# first_chunk: 1 (bypass) + 2*(T-1) (interleaved) = 2T-1 = 3
|
||||
assert out.shape == (1, 3, 8, 8, 8)
|
||||
|
||||
def test_upsample3d_first_chunk_single_frame(self):
|
||||
"""Single-frame input with first_chunk: no temporal upsample."""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 1, 4, 4, 8))
|
||||
out = r(x, first_chunk=True)
|
||||
mx.eval(out)
|
||||
# Single frame with first_chunk: falls through to non-first path
|
||||
# time_conv on 1 frame → 2 interleaved
|
||||
assert out.shape == (1, 2, 8, 8, 8)
|
||||
|
||||
def test_upsample3d_first_frame_bypasses_time_conv(self):
|
||||
"""First frame of first_chunk should NOT go through time_conv.
|
||||
|
||||
Official Wan2.2 skips time_conv for the very first frame entirely.
|
||||
We verify this by checking that the first output frame depends only on
|
||||
the first input frame (not on time_conv parameters).
|
||||
"""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
C = 8
|
||||
r = Resample(C, "upsample3d")
|
||||
# Set time_conv weights to large values so its effect is detectable
|
||||
r.time_conv.weight = mx.ones(r.time_conv.weight.shape) * 10.0
|
||||
r.time_conv.bias = mx.zeros(r.time_conv.bias.shape)
|
||||
# Set spatial conv to identity-like
|
||||
r.resample_weight = mx.zeros(r.resample_weight.shape)
|
||||
r.resample_bias = mx.zeros(r.resample_bias.shape)
|
||||
|
||||
x = mx.random.normal((1, 3, 2, 2, C))
|
||||
out = r(x, first_chunk=True)
|
||||
mx.eval(out)
|
||||
# Output: 5 frames (1 bypass + 4 interleaved from 2 remaining)
|
||||
assert out.shape[1] == 5
|
||||
|
||||
# First frame should be spatial upsample of x[:, 0:1] only.
|
||||
# Run just the first frame through spatial upsample for reference
|
||||
first_only = x[:, 0:1]
|
||||
ref = r._upsample2x(first_only.reshape(1, 2, 2, C))
|
||||
ref = mx.pad(ref, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
||||
ref = mx.conv_general(ref, r.resample_weight) + r.resample_bias
|
||||
mx.eval(ref)
|
||||
|
||||
# Compare first output frame to reference
|
||||
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
|
||||
mx.eval(first_out)
|
||||
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
|
||||
"First frame should bypass time_conv and match spatial-only upsample"
|
||||
|
||||
|
||||
class TestVAE22ResidualBlock:
|
||||
"""Tests for vae22.ResidualBlock."""
|
||||
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_shortcut_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
|
||||
class TestResidualBlockLayers:
|
||||
"""Tests for vae22.ResidualBlockLayers naming convention."""
|
||||
|
||||
def test_layer_names_no_underscore_prefix(self):
|
||||
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
block = ResidualBlockLayers(8, 8)
|
||||
params = dict(block.parameters())
|
||||
# All param keys should use layer_N, not _layer_N
|
||||
for key in params:
|
||||
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
|
||||
|
||||
def test_has_expected_layers(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
assert hasattr(block, "layer_0") # first RMS_norm
|
||||
assert hasattr(block, "layer_2") # first CausalConv3d
|
||||
assert hasattr(block, "layer_3") # second RMS_norm
|
||||
assert hasattr(block, "layer_6") # second CausalConv3d
|
||||
|
||||
def test_forward_shape(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
|
||||
class TestVAE22AttentionBlock:
|
||||
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
block = AttentionBlock(16)
|
||||
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
block = AttentionBlock(8)
|
||||
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||
block.proj_weight = mx.zeros(block.proj_weight.shape)
|
||||
x = mx.ones((1, 1, 2, 2, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# With zero weights, attention output is 0 → residual is identity
|
||||
np.testing.assert_allclose(np.array(out), np.array(x), atol=1e-5)
|
||||
|
||||
|
||||
class TestHead22:
|
||||
"""Tests for vae22.Head22 output head."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
head = Head22(16, out_channels=12)
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
out = head(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 12)
|
||||
|
||||
def test_layer_names_no_underscore(self):
|
||||
"""Head layers must not use underscore prefix."""
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
head = Head22(8)
|
||||
assert hasattr(head, "layer_0") # RMS_norm
|
||||
assert hasattr(head, "layer_2") # CausalConv3d
|
||||
params = dict(head.parameters())
|
||||
for key in params:
|
||||
assert not key.startswith("_"), f"Head param {key} starts with underscore"
|
||||
|
||||
|
||||
class TestUnpatchify:
|
||||
"""Tests for vae22._unpatchify."""
|
||||
|
||||
def test_basic_shape(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 8, 8, 3)
|
||||
|
||||
def test_patch_size_1_noop(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||
out = _unpatchify(x, patch_size=1)
|
||||
mx.eval(out)
|
||||
np.testing.assert_array_equal(np.array(out), np.array(x))
|
||||
|
||||
def test_preserves_content(self):
|
||||
"""Unpatchify should be a lossless rearrangement."""
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
# All elements should be preserved
|
||||
assert np.array(out).size == 48
|
||||
assert set(np.array(out).flatten().tolist()) == set(range(48))
|
||||
|
||||
|
||||
class TestDenormalizeLatents:
|
||||
"""Tests for vae22.denormalize_latents."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
out = denormalize_latents(z)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 48)
|
||||
|
||||
def test_custom_mean_std(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
z = mx.ones((1, 1, 1, 1, 4))
|
||||
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
std = mx.array([0.5, 0.5, 0.5, 0.5])
|
||||
out = denormalize_latents(z, mean=mean, std=std)
|
||||
mx.eval(out)
|
||||
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
|
||||
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
|
||||
|
||||
def test_uses_default_constants(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
|
||||
# Should not raise with default constants
|
||||
z = mx.zeros((1, 1, 1, 1, 48))
|
||||
out = denormalize_latents(z)
|
||||
mx.eval(out)
|
||||
# z=0 → result = 0 * std + mean = mean
|
||||
np.testing.assert_allclose(
|
||||
np.array(out).flatten(),
|
||||
np.array(VAE22_MEAN).flatten(),
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
|
||||
class TestVAE22NormConstants:
|
||||
"""Tests for VAE22_MEAN and VAE22_STD constants."""
|
||||
|
||||
def test_dimensions(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
||||
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||
assert VAE22_MEAN.shape == (48,)
|
||||
assert VAE22_STD.shape == (48,)
|
||||
|
||||
def test_std_positive(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_STD
|
||||
mx.eval(VAE22_STD)
|
||||
assert (np.array(VAE22_STD) > 0).all()
|
||||
|
||||
|
||||
class TestWan22VAEDecoder:
|
||||
"""Tests for the full Wan22VAEDecoder (tiny configuration)."""
|
||||
|
||||
def test_output_shape_small(self):
|
||||
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
# Use very small dims to keep test fast
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
# Latent: [B=1, T=3, H=2, W=2, C=4]
|
||||
# Expected: temporal 3→5→9→9→9 (two temporal upsamples), spatial 2→4→8→16
|
||||
z = mx.random.normal((1, 3, 2, 2, 4)) * 0.1
|
||||
out = dec(z)
|
||||
mx.eval(out)
|
||||
# Output should have 3 RGB channels and be clipped to [-1, 1]
|
||||
assert out.shape[-1] == 3
|
||||
assert out.ndim == 5
|
||||
assert np.array(out).min() >= -1.0
|
||||
assert np.array(out).max() <= 1.0
|
||||
|
||||
def test_output_clipped(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||
out = dec(z)
|
||||
mx.eval(out)
|
||||
assert np.array(out).min() >= -1.0 - 1e-6
|
||||
assert np.array(out).max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
class TestSanitizeWan22VAEWeights:
|
||||
"""Tests for vae22.sanitize_wan22_vae_weights."""
|
||||
|
||||
def test_skip_encoder(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"encoder.layer.weight": mx.zeros((4,)),
|
||||
"conv1.weight": mx.zeros((4,)),
|
||||
"decoder.conv1.bias": mx.zeros((4,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "encoder.layer.weight" not in out
|
||||
assert "conv1.weight" not in out
|
||||
assert "decoder.conv1.bias" in out
|
||||
|
||||
def test_sequential_index_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
|
||||
"decoder.head.0.gamma": mx.ones((4,)),
|
||||
"decoder.head.2.bias": mx.zeros((12,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma" in out
|
||||
assert "decoder.upsamples.0.upsamples.0.residual.layer_6.bias" in out
|
||||
assert "decoder.head.layer_0.gamma" in out
|
||||
assert "decoder.head.layer_2.bias" in out
|
||||
|
||||
def test_resample_conv_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "decoder.upsamples.1.upsamples.3.resample_weight" in out
|
||||
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
|
||||
|
||||
def test_attention_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
|
||||
"decoder.middle.1.proj.weight": mx.zeros((8, 8, 1, 1)),
|
||||
"decoder.middle.1.proj.bias": mx.zeros((8,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "decoder.middle.1.to_qkv_weight" in out
|
||||
assert "decoder.middle.1.to_qkv_bias" in out
|
||||
assert "decoder.middle.1.proj_weight" in out
|
||||
assert "decoder.middle.1.proj_bias" in out
|
||||
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||
w = mx.zeros((16, 8, 3, 3, 3))
|
||||
weights = {"decoder.conv1.weight": w}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||
w = mx.zeros((8, 8, 3, 3))
|
||||
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
key = "decoder.upsamples.0.upsamples.2.resample_weight"
|
||||
assert out[key].shape == (8, 3, 3, 8)
|
||||
|
||||
def test_gamma_squeeze(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||
w = mx.ones((16, 1, 1, 1))
|
||||
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
key = "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma"
|
||||
assert out[key].shape == (16,)
|
||||
|
||||
|
||||
class TestUpResidualBlock:
|
||||
"""Tests for vae22.Up_ResidualBlock."""
|
||||
|
||||
def test_no_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# No upsample: same shape
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_spatial_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# 2x spatial upsample, no temporal
|
||||
assert out.shape == (1, 2, 8, 8, 4)
|
||||
|
||||
def test_spatial_temporal_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# 2x spatial + 2x temporal
|
||||
assert out.shape == (1, 4, 8, 8, 4)
|
||||
|
||||
|
||||
class TestPatchify:
|
||||
"""Tests for _patchify and _unpatchify round-trip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 1, 64, 64, 3))
|
||||
p = _patchify(x, patch_size=2)
|
||||
assert p.shape == (1, 1, 32, 32, 12)
|
||||
back = _unpatchify(p, patch_size=2)
|
||||
assert back.shape == x.shape
|
||||
assert float(mx.abs(x - back).max()) == 0.0
|
||||
|
||||
def test_identity_patch_1(self):
|
||||
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 8, 8, 3))
|
||||
assert _patchify(x, patch_size=1).shape == x.shape
|
||||
assert _unpatchify(x, patch_size=1).shape == x.shape
|
||||
|
||||
|
||||
class TestAvgDown3D:
|
||||
"""Tests for AvgDown3D downsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = down(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
out = down(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_single_frame(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 1, 8, 8, 8))
|
||||
out = down(x)
|
||||
mx.eval(out)
|
||||
# T=1 with factor_t=2: pads to T=2 then averages → T=1
|
||||
assert out.shape == (1, 1, 4, 4, 8)
|
||||
|
||||
|
||||
class TestDownResidualBlock:
|
||||
"""Tests for Down_ResidualBlock."""
|
||||
|
||||
def test_no_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 8, 8, 8)
|
||||
|
||||
def test_spatial_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_spatial_temporal_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
|
||||
class TestEncoder3d:
|
||||
"""Tests for Encoder3d."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8)
|
||||
x = mx.random.normal((1, 1, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
# 3 spatial downsamples ÷8: 16→2
|
||||
assert out.shape == (1, 1, 2, 2, 8)
|
||||
|
||||
def test_multi_frame(self):
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
# T: 5→3 (1st t_down) →2 (2nd t_down), spatial ÷8
|
||||
assert out.shape[2:] == (2, 2, 8)
|
||||
|
||||
|
||||
class TestWan22VAEEncoder:
|
||||
"""Tests for Wan22VAEEncoder wrapper."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
mx.eval(enc.parameters())
|
||||
z = enc(img)
|
||||
mx.eval(z)
|
||||
assert z.shape == (1, 1, 2, 2, 48)
|
||||
|
||||
def test_full_dim(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=160)
|
||||
img = mx.random.normal((1, 1, 64, 64, 3))
|
||||
mx.eval(enc.parameters())
|
||||
z = enc(img)
|
||||
mx.eval(z)
|
||||
# 64 / 16 = 4 (vae stride 16×)
|
||||
assert z.shape == (1, 1, 4, 4, 48)
|
||||
|
||||
|
||||
class TestNormalizeLatents:
|
||||
"""Tests for normalize/denormalize latent roundtrip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
z_norm = normalize_latents(z)
|
||||
z_back = denormalize_latents(z_norm)
|
||||
mx.eval(z_back)
|
||||
assert float(mx.abs(z - z_back).max()) < 1e-4
|
||||
|
||||
|
||||
class TestVAEEncoderTemporalOrder:
|
||||
"""Tests that VAE encoder uses (False, True, True) temporal downsample order,
|
||||
matching official Wan2.2 vae2_2.py."""
|
||||
|
||||
def test_encoder_temporal_downsample_pattern(self):
|
||||
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
assert out.shape[1] == 2
|
||||
|
||||
def test_wrapper_uses_correct_pattern(self):
|
||||
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
down_blocks = enc.encoder.downsamples
|
||||
found_modes = []
|
||||
for block in down_blocks:
|
||||
for layer in block.downsamples:
|
||||
if isinstance(layer, Resample):
|
||||
found_modes.append(layer.mode)
|
||||
# First spatial-only, then two with temporal
|
||||
assert found_modes[0] == "downsample2d"
|
||||
assert any("3d" in m for m in found_modes)
|
||||
|
||||
def test_single_frame_encoder(self):
|
||||
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
mx.eval(enc.parameters())
|
||||
z = enc(img)
|
||||
mx.eval(z)
|
||||
assert z.shape[1] == 1
|
||||
assert z.shape[-1] == 48
|
||||
|
||||
def test_wrong_order_gives_different_result(self):
|
||||
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc_correct.parameters())
|
||||
mx.eval(enc_wrong.parameters())
|
||||
|
||||
out_correct = enc_correct(x)
|
||||
out_wrong = enc_wrong(x)
|
||||
mx.eval(out_correct, out_wrong)
|
||||
|
||||
# Both give T=2 but spatial processing path differs
|
||||
assert out_correct.shape[1] == 2
|
||||
assert out_wrong.shape[1] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VAE Encode → Decode Round-Trip Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVAE21RoundTrip:
|
||||
"""Encode→decode round-trip for Wan 2.1 VAE (channels-first)."""
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
|
||||
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
|
||||
|
||||
z_dim = 4
|
||||
dim = 8
|
||||
# No temporal up/downsampling to keep the test simple
|
||||
enc = Encoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
|
||||
)
|
||||
dec = Decoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
|
||||
)
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, C=3, T=1, H=8, W=8]
|
||||
x = mx.random.normal((1, 3, 1, 8, 8)) * 0.5
|
||||
|
||||
z = enc(x)
|
||||
mx.eval(z)
|
||||
# 3 spatial downsamples (÷8): H=1, W=1
|
||||
assert z.shape == (1, z_dim, 1, 1, 1)
|
||||
|
||||
x_hat = dec(z)
|
||||
mx.eval(x_hat)
|
||||
# 3 spatial upsamples (×8): should recover original shape
|
||||
assert x_hat.shape == x.shape
|
||||
|
||||
out_np = np.array(x_hat)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert np.abs(out_np).max() < 1000
|
||||
|
||||
|
||||
class TestVAE22RoundTrip:
|
||||
"""Encode→decode round-trip for Wan 2.2 VAE (channels-last)."""
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
|
||||
from mlx_video.models.wan.vae22 import (
|
||||
Wan22VAEDecoder,
|
||||
Wan22VAEEncoder,
|
||||
denormalize_latents,
|
||||
)
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
dec = Wan22VAEDecoder(z_dim=48, dec_dim=8)
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, T=1, H=32, W=32, C=3]
|
||||
img = mx.random.normal((1, 1, 32, 32, 3)) * 0.5
|
||||
|
||||
z_norm = enc(img)
|
||||
mx.eval(z_norm)
|
||||
# patchify(÷2) + 3 spatial downsamples(÷8) = ÷16
|
||||
assert z_norm.shape == (1, 1, 2, 2, 48)
|
||||
|
||||
z = denormalize_latents(z_norm)
|
||||
out = dec(z)
|
||||
mx.eval(out)
|
||||
|
||||
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
|
||||
assert out.shape[0] == 1 # batch
|
||||
assert out.shape[2] == 32 # H recovered
|
||||
assert out.shape[3] == 32 # W recovered
|
||||
assert out.shape[-1] == 3 # RGB
|
||||
|
||||
out_np = np.array(out)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert out_np.min() >= -1.0 - 1e-6
|
||||
assert out_np.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
|
||||
19
tests/wan_test_helpers.py
Normal file
19
tests/wan_test_helpers.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Shared test helpers for Wan test modules."""
|
||||
|
||||
|
||||
def _make_tiny_config():
|
||||
"""Create a tiny WanModelConfig for testing."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
config = WanModelConfig()
|
||||
# Override to tiny values
|
||||
config.dim = 64
|
||||
config.ffn_dim = 128
|
||||
config.num_heads = 4
|
||||
config.num_layers = 2
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.patch_size = (1, 2, 2)
|
||||
config.freq_dim = 32
|
||||
config.text_dim = 32
|
||||
config.text_len = 8
|
||||
return config
|
||||
Reference in New Issue
Block a user