"""Tests for Wan model components.""" import mlx.core as mx import mlx.nn as nn import numpy as np from wan_test_helpers import _make_tiny_config # --------------------------------------------------------------------------- # Sinusoidal Embedding Tests # --------------------------------------------------------------------------- class TestSinusoidalEmbedding: def test_output_shape(self): from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.arange(10).astype(mx.float32) emb = sinusoidal_embedding_1d(256, pos) mx.eval(emb) assert emb.shape == (10, 256) def test_position_zero(self): """Position 0 should have cos=1 for all dims and sin=0.""" from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([0.0]) emb = sinusoidal_embedding_1d(64, pos) mx.eval(emb) emb_np = np.array(emb[0]) # First half is cos, should be 1 at position 0 np.testing.assert_allclose(emb_np[:32], 1.0, atol=1e-5) # Second half is sin, should be 0 at position 0 np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5) def test_different_positions_differ(self): from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 999.0]) emb = sinusoidal_embedding_1d(128, pos) mx.eval(emb) emb_np = np.array(emb) assert not np.allclose(emb_np[0], emb_np[1]) assert not np.allclose(emb_np[1], emb_np[2]) # --------------------------------------------------------------------------- # Head Tests # --------------------------------------------------------------------------- class TestHead: def test_output_shape(self): from mlx_video.models.wan_2.wan_2 import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) B, L = 1, 24 x = mx.random.normal((B, L, 64)) e = mx.random.normal((B, 64)) # time embedding: [B, dim] out = head(x, e) mx.eval(out) expected_proj_dim = 16 * 1 * 2 * 2 # 64 assert out.shape == (B, L, expected_proj_dim) def test_modulation_shape(self): from mlx_video.models.wan_2.wan_2 import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) assert head.modulation.shape == (1, 2, 64) # --------------------------------------------------------------------------- # WanModel (Tiny) Tests # --------------------------------------------------------------------------- class TestWanModel: def setup_method(self): mx.random.seed(42) def test_instantiation(self): from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters())) assert num_params > 0 def test_patchify_shape(self): from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) # Input: [C=4, F=1, H=4, W=4] x = mx.random.normal((4, 1, 4, 4)) patches, grid_size = model._patchify(x) mx.eval(patches) # Patch size (1,2,2): F'=1, H'=2, W'=2 assert grid_size == (1, 2, 2) assert patches.shape == (1, 1 * 2 * 2, config.dim) def test_patchify_various_sizes(self): from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]: x = mx.random.normal((config.in_dim, f, h, w)) patches, (gf, gh, gw) = model._patchify(x) mx.eval(patches) pt, ph, pw = config.patch_size assert gf == f // pt assert gh == h // ph assert gw == w // pw assert patches.shape[1] == gf * gh * gw def test_unpatchify_inverse(self): """Patchify then unpatchify should reconstruct original spatial dims.""" from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 2, 4, 6 pt, ph, pw = config.patch_size F_out, H_out, W_out = F // pt, H // ph, W // pw L = F_out * H_out * W_out proj_dim = config.out_dim * pt * ph * pw # Simulated head output x = mx.random.normal((1, L, proj_dim)) out = model.unpatchify(x, [(F_out, H_out, W_out)]) mx.eval(out[0]) assert out[0].shape == (config.out_dim, F, H, W) def test_forward_pass(self): from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 pt, ph, pw = config.patch_size seq_len = (F // pt) * (H // ph) * (W // pw) x_list = [mx.random.normal((C, F, H, W))] t = mx.array([500.0]) context = [mx.random.normal((6, config.text_dim))] out = model(x_list, t, context, seq_len) mx.eval(out[0]) assert len(out) == 1 assert out[0].shape == (C, F, H, W) def test_forward_batch(self): from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 pt, ph, pw = config.patch_size seq_len = (F // pt) * (H // ph) * (W // pw) x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))] t = mx.array([500.0, 200.0]) context = [ mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim)), ] out = model(x_list, t, context, seq_len) mx.eval(out[0], out[1]) assert len(out) == 2 for o in out: assert o.shape == (C, F, H, W) def test_output_is_float32(self): from mlx_video.models.wan_2.wan_2 import WanModel config = _make_tiny_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 seq_len = (F // 1) * (H // 2) * (W // 2) out = model( [mx.random.normal((C, F, H, W))], mx.array([100.0]), [mx.random.normal((4, config.text_dim))], seq_len, ) mx.eval(out[0]) assert out[0].dtype == mx.float32 # --------------------------------------------------------------------------- # Wan2.1 Model Tests # --------------------------------------------------------------------------- class TestWan21Model: """Test tiny Wan2.1-style model (single model mode).""" def setup_method(self): mx.random.seed(42) def _make_tiny_wan21_config(self): """Create a tiny config mimicking Wan2.1 (single model).""" from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_14b() # Override to tiny values config.dim = 64 config.ffn_dim = 128 config.num_heads = 4 config.num_layers = 2 config.in_dim = 4 config.out_dim = 4 config.freq_dim = 32 config.text_dim = 32 config.text_len = 8 return config def _make_tiny_wan21_1_3b_config(self): """Create a tiny config mimicking Wan2.1 1.3B.""" from mlx_video.models.wan_2.config import WanModelConfig config = WanModelConfig.wan21_t2v_1_3b() # Override to tiny values (preserve 1.3B head structure: 12 heads) config.dim = 48 config.ffn_dim = 96 config.num_heads = 4 config.num_layers = 2 config.in_dim = 4 config.out_dim = 4 config.freq_dim = 24 config.text_dim = 24 config.text_len = 8 return config def test_wan21_tiny_model_forward(self): """Forward pass with Wan2.1 tiny config.""" from mlx_video.models.wan_2.wan_2 import WanModel config = self._make_tiny_wan21_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 seq_len = (F // 1) * (H // 2) * (W // 2) latents = mx.random.normal((C, F, H, W)) context = mx.random.normal((4, config.text_dim)) t = mx.array([500.0]) out = model([latents], t, [context], seq_len) mx.eval(out) assert out[0].shape == (C, F, H, W) def test_wan21_1_3b_tiny_model_forward(self): """Forward pass with Wan2.1 1.3B tiny config.""" from mlx_video.models.wan_2.wan_2 import WanModel config = self._make_tiny_wan21_1_3b_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 seq_len = (F // 1) * (H // 2) * (W // 2) latents = mx.random.normal((C, F, H, W)) context = mx.random.normal((4, config.text_dim)) t = mx.array([500.0]) out = model([latents], t, [context], seq_len) mx.eval(out) assert out[0].shape == (C, F, H, W) def test_wan21_single_model_loop(self): """Full diffusion loop with single model (Wan2.1 style).""" from mlx_video.models.wan_2.wan_2 import WanModel from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler config = self._make_tiny_wan21_config() model = WanModel(config) C, F, H, W = config.in_dim, 1, 4, 4 seq_len = (F // 1) * (H // 2) * (W // 2) sched = FlowMatchEulerScheduler() sched.set_timesteps(config.sample_steps, shift=config.sample_shift) # Use only 3 steps for speed latents = mx.random.normal((C, F, H, W)) context = mx.random.normal((4, config.text_dim)) context_null = mx.zeros((4, config.text_dim)) gs = config.sample_guide_scale # Should be float for Wan2.1 assert isinstance(gs, float), "Wan2.1 guide_scale should be float" for i in range(3): t = sched.timesteps[i] pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0] pred_uncond = model( [latents], mx.array([t.item()]), [context_null], seq_len )[0] pred = pred_uncond + gs * (pred_cond - pred_uncond) latents = sched.step(pred[None], t, latents[None]).squeeze(0) mx.eval(latents) assert latents.shape == (C, F, H, W) assert not mx.any(mx.isnan(latents)).item() def test_wan21_vs_wan22_config_differences(self): """Verify key differences between Wan2.1 and Wan2.2 configs.""" from mlx_video.models.wan_2.config import WanModelConfig c21 = WanModelConfig.wan21_t2v_14b() c22 = WanModelConfig.wan22_t2v_14b() # Same architecture assert c21.dim == c22.dim assert c21.num_heads == c22.num_heads assert c21.num_layers == c22.num_layers # Different pipeline settings assert c21.dual_model is False assert c22.dual_model is True assert isinstance(c21.sample_guide_scale, float) assert isinstance(c22.sample_guide_scale, tuple) assert c21.sample_shift != c22.sample_shift assert c21.sample_steps != c22.sample_steps # --------------------------------------------------------------------------- # Per-Token Timestep Tests # --------------------------------------------------------------------------- class TestPerTokenTimestep: """Tests for per-token sinusoidal embedding.""" def test_1d_unchanged(self): from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 500.0]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (3, 256) def test_2d_per_token(self): from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (2, 3, 256) def test_consistency(self): from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d pos_1d = mx.array([0.0, 100.0]) emb_1d = sinusoidal_embedding_1d(256, pos_1d) pos_2d = mx.array([[0.0, 100.0]]) emb_2d = sinusoidal_embedding_1d(256, pos_2d) assert mx.array_equal(emb_1d[0], emb_2d[0, 0]) assert mx.array_equal(emb_1d[1], emb_2d[0, 1])