Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -12,14 +12,14 @@ class TestRoPE:
"""Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs = rope_params(1024, 64)
mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
for dim in [32, 64, 128]:
freqs = rope_params(512, dim)
@@ -27,7 +27,7 @@ class TestRoPE:
assert freqs.shape == (512, dim // 2, 2)
def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs = rope_params(256, 64)
mx.eval(freqs)
@@ -38,7 +38,7 @@ class TestRoPE:
def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs = rope_params(10, 64)
mx.eval(freqs)
@@ -46,7 +46,7 @@ class TestRoPE:
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
def test_rope_apply_output_shape(self):
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D))
@@ -58,7 +58,7 @@ class TestRoPE:
def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms."""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 3, 4
@@ -79,7 +79,7 @@ class TestRoPE:
def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 2, 2
@@ -100,7 +100,7 @@ class TestRoPE:
def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)]
@@ -132,7 +132,7 @@ class TestRoPE:
class TestWanRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -142,7 +142,7 @@ class TestWanRMSNorm:
def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0
@@ -156,7 +156,7 @@ class TestWanRMSNorm:
def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan2.attention import WanRMSNorm
norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
@@ -168,7 +168,7 @@ class TestWanRMSNorm:
class TestWanLayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan2.attention import WanLayerNorm
norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -177,7 +177,7 @@ class TestWanLayerNorm:
assert out.shape == (2, 10, 64)
def test_without_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan2.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64))
@@ -190,7 +190,7 @@ class TestWanLayerNorm:
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
def test_with_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan2.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight")
@@ -208,8 +208,8 @@ class TestWanSelfAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24
@@ -221,14 +221,14 @@ class TestWanSelfAttention:
assert out.shape == (B, L, self.dim)
def test_with_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None
assert attn.norm_k is not None
def test_without_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None
@@ -236,8 +236,8 @@ class TestWanSelfAttention:
def test_masking(self):
"""Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24
@@ -262,7 +262,7 @@ class TestWanCrossAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16
@@ -273,7 +273,7 @@ class TestWanCrossAttention:
assert out.shape == (B, L_q, self.dim)
def test_with_context_mask(self):
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16
@@ -311,8 +311,8 @@ class TestBFloat16Autocast:
def test_self_attn_casts_to_weight_dtype(self):
"""Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -326,7 +326,7 @@ class TestBFloat16Autocast:
def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -340,7 +340,7 @@ class TestBFloat16Autocast:
def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -353,7 +353,7 @@ class TestBFloat16Autocast:
def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan.transformer import WanFFN
from mlx_video.models.wan2.transformer import WanFFN
ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters()))
@@ -366,8 +366,8 @@ class TestBFloat16Autocast:
def test_self_attn_rope_in_float32(self):
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.attention import WanSelfAttention
from mlx_video.models.wan2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -381,8 +381,8 @@ class TestBFloat16Autocast:
def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters()))

View File

@@ -10,7 +10,7 @@ class TestWanModelConfig:
"""Tests for WanModelConfig dataclass."""
def test_default_values(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.dim == 5120
@@ -32,13 +32,13 @@ class TestWanModelConfig:
assert config.text_len == 512
def test_head_dim_property(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
d = config.to_dict()
@@ -48,7 +48,7 @@ class TestWanModelConfig:
assert d["boundary"] == 0.875
def test_t5_config_values(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.t5_vocab_size == 256384
@@ -69,7 +69,7 @@ class TestWan21Config:
"""Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1"
@@ -85,7 +85,7 @@ class TestWan21Config:
assert config.boundary == 0.0
def test_wan21_1_3b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1"
@@ -98,7 +98,7 @@ class TestWan21Config:
assert config.sample_guide_scale == 5.0
def test_wan22_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2"
@@ -110,7 +110,7 @@ class TestWan21Config:
assert config.boundary == 0.875
def test_wan21_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
@@ -119,7 +119,7 @@ class TestWan21Config:
assert d["sample_guide_scale"] == 5.0
def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict()
@@ -128,7 +128,7 @@ class TestWan21Config:
def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.model_version == "2.2"

View File

@@ -11,7 +11,7 @@ import mlx.core as mx
class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights:
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
def test_text_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"text_embedding.0.weight": mx.zeros((64, 32)),
@@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights:
assert "text_embedding_1.bias" in out
def test_time_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"time_embedding.0.weight": mx.zeros((64, 32)),
@@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights:
assert "time_embedding_1.weight" in out
def test_time_projection_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"time_projection.1.weight": mx.zeros((384, 64)),
@@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights:
assert "time_projection.bias" in out
def test_ffn_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
@@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.ffn.fc2.bias" in out
def test_freqs_skipped(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"freqs": mx.zeros((1024, 64, 2)),
@@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.norm1.weight" in out
def test_passthrough_keys(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
@@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights:
"head.head.weight": mx.zeros((64, 64)),
"freqs": mx.zeros((1024, 64, 2)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
sanitize_wan_transformer_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeT5Weights:
def test_gate_rename(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -139,7 +139,7 @@ class TestSanitizeT5Weights:
assert "blocks.0.ffn.fc2.weight" in out
def test_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
@@ -151,7 +151,7 @@ class TestSanitizeT5Weights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
@@ -160,14 +160,14 @@ class TestSanitizeT5Weights:
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
"norm.weight": mx.zeros((64,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
sanitize_wan_t5_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeVAEWeights:
def test_conv3d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
@@ -176,7 +176,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
def test_conv2d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
@@ -185,7 +185,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
def test_non_conv_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
@@ -196,7 +196,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.bias"].shape == (16,)
def test_mixed_weights(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
@@ -211,7 +211,7 @@ class TestSanitizeVAEWeights:
assert out["norm.weight"].shape == (8,)
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
@@ -219,7 +219,7 @@ class TestSanitizeVAEWeights:
"decoder.norm.weight": mx.zeros((64,)),
"decoder.bias": mx.zeros((16,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
sanitize_wan_vae_weights(weights)
assert "Unconsumed" not in caplog.text
@@ -256,7 +256,7 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
@@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
def test_exclude_encoder_by_default(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights:
assert not any("encoder" in k or k.startswith("conv1") for k in out)
def test_include_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights:
assert "conv2.weight" in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "Unconsumed" not in caplog.text
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "Unconsumed" not in caplog.text

View File

@@ -14,8 +14,8 @@ class TestEndToEnd:
def test_tiny_model_denoise_step(self):
"""Simulate one denoising step with tiny model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(42)
config = _make_tiny_config()
@@ -43,8 +43,8 @@ class TestEndToEnd:
def test_tiny_model_full_loop(self):
"""Run a complete (tiny) diffusion loop."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(123)
config = _make_tiny_config()
@@ -81,7 +81,7 @@ class TestI2VMask:
"""Tests for _build_i2v_mask."""
def test_mask_shapes(self):
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4) # C, T, H, W
patch_size = (1, 2, 2)
@@ -91,7 +91,7 @@ class TestI2VMask:
assert mask_tokens.shape == (1, 20)
def test_first_frame_zero(self):
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4)
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
@@ -111,7 +111,7 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions."""
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
@@ -132,7 +132,7 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2)
@@ -201,7 +201,7 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -235,7 +235,7 @@ class TestDimensionAlignment:
def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1]

View File

@@ -23,7 +23,7 @@ class TestI2VConfig:
"""Test I2V-14B config preset."""
def test_wan22_i2v_14b_preset(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
assert config.model_type == "i2v"
@@ -39,7 +39,7 @@ class TestI2VConfig:
assert config.vae_z_dim == 16
def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
i2v = WanModelConfig.wan22_i2v_14b()
t2v = WanModelConfig.wan22_t2v_14b()
@@ -51,7 +51,7 @@ class TestI2VConfig:
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
def test_i2v_serialization_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
d = config.to_dict()
@@ -66,7 +66,7 @@ class TestModelYParameter:
def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -85,7 +85,7 @@ class TestModelYParameter:
def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -108,7 +108,7 @@ class TestModelYParameter:
def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -129,7 +129,7 @@ class TestModelYParameter:
def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -158,7 +158,7 @@ class TestVAEEncoder:
"""Test Wan2.1 VAE encoder."""
def test_encoder3d_instantiation(self):
from mlx_video.models.wan.vae import Encoder3d
from mlx_video.models.wan2.vae import Encoder3d
enc = Encoder3d(
dim=32, z_dim=8
@@ -169,7 +169,7 @@ class TestVAEEncoder:
def test_encoder3d_output_shape(self):
"""Encoder should downsample spatially by 8x and temporally by 4x."""
from mlx_video.models.wan.vae import Encoder3d
from mlx_video.models.wan2.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8)
# Random input: [B=1, 3, T=5, H=32, W=32]
@@ -186,7 +186,7 @@ class TestVAEEncoder:
def test_wan_vae_encode(self):
"""WanVAE with encoder=True should produce normalized latents."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
# Input: [B=1, 3, T=5, H=32, W=32]
@@ -198,7 +198,7 @@ class TestVAEEncoder:
def test_wan_vae_encoder_flag(self):
"""WanVAE without encoder flag should not have encoder attribute."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, "encoder")
@@ -211,7 +211,7 @@ class TestResampleDownsample:
"""Test downsample modes in Resample."""
def test_downsample2d(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="downsample2d")
x = mx.random.normal((1, 16, 2, 8, 8))
@@ -221,7 +221,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_downsample3d(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="downsample3d")
x = mx.random.normal((1, 16, 4, 8, 8))
@@ -231,7 +231,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_upsample2d_still_works(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="upsample2d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -240,7 +240,7 @@ class TestResampleDownsample:
assert out.shape == (1, 8, 2, 8, 8)
def test_upsample3d_still_works(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="upsample3d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline:
def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.vae import WanVAE
mx.random.seed(0)
@@ -410,8 +410,8 @@ class TestDualModelSwitching:
def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(1)
config = _make_tiny_i2v_config()
@@ -485,8 +485,8 @@ class TestDualModelSwitching:
def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(2)
config = _make_tiny_i2v_config()
@@ -545,8 +545,8 @@ class TestDualModelSwitching:
def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(3)
config = _make_tiny_config()

View File

@@ -331,7 +331,7 @@ class TestEndToEnd:
"""End-to-end LoRA loading and application."""
def test_load_and_apply_loras(self):
from mlx_video.convert_wan import load_and_apply_loras
from mlx_video.models.wan2.convert import load_and_apply_loras
with tempfile.TemporaryDirectory() as tmp:
# Create mock LoRA safetensors

View File

@@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config
class TestSinusoidalEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
pos = mx.arange(10).astype(mx.float32)
emb = sinusoidal_embedding_1d(256, pos)
@@ -21,7 +21,7 @@ class TestSinusoidalEmbedding:
def test_position_zero(self):
"""Position 0 should have cos=1 for all dims and sin=0."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
pos = mx.array([0.0])
emb = sinusoidal_embedding_1d(64, pos)
@@ -33,7 +33,7 @@ class TestSinusoidalEmbedding:
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
def test_different_positions_differ(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 999.0])
emb = sinusoidal_embedding_1d(128, pos)
@@ -50,7 +50,7 @@ class TestSinusoidalEmbedding:
class TestHead:
def test_output_shape(self):
from mlx_video.models.wan.model import Head
from mlx_video.models.wan2.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
B, L = 1, 24
@@ -62,7 +62,7 @@ class TestHead:
assert out.shape == (B, L, expected_proj_dim)
def test_modulation_shape(self):
from mlx_video.models.wan.model import Head
from mlx_video.models.wan2.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
assert head.modulation.shape == (1, 2, 64)
@@ -78,7 +78,7 @@ class TestWanModel:
mx.random.seed(42)
def test_instantiation(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -86,7 +86,7 @@ class TestWanModel:
assert num_params > 0
def test_patchify_shape(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -99,7 +99,7 @@ class TestWanModel:
assert patches.shape == (1, 1 * 2 * 2, config.dim)
def test_patchify_various_sizes(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -115,7 +115,7 @@ class TestWanModel:
def test_unpatchify_inverse(self):
"""Patchify then unpatchify should reconstruct original spatial dims."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -131,7 +131,7 @@ class TestWanModel:
assert out[0].shape == (config.out_dim, F, H, W)
def test_forward_pass(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -149,7 +149,7 @@ class TestWanModel:
assert out[0].shape == (C, F, H, W)
def test_forward_batch(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -171,7 +171,7 @@ class TestWanModel:
assert o.shape == (C, F, H, W)
def test_output_is_float32(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -200,7 +200,7 @@ class TestWan21Model:
def _make_tiny_wan21_config(self):
"""Create a tiny config mimicking Wan2.1 (single model)."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
# Override to tiny values
@@ -217,7 +217,7 @@ class TestWan21Model:
def _make_tiny_wan21_1_3b_config(self):
"""Create a tiny config mimicking Wan2.1 1.3B."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
# Override to tiny values (preserve 1.3B head structure: 12 heads)
@@ -234,7 +234,7 @@ class TestWan21Model:
def test_wan21_tiny_model_forward(self):
"""Forward pass with Wan2.1 tiny config."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = self._make_tiny_wan21_config()
model = WanModel(config)
@@ -252,7 +252,7 @@ class TestWan21Model:
def test_wan21_1_3b_tiny_model_forward(self):
"""Forward pass with Wan2.1 1.3B tiny config."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = self._make_tiny_wan21_1_3b_config()
model = WanModel(config)
@@ -270,8 +270,8 @@ class TestWan21Model:
def test_wan21_single_model_loop(self):
"""Full diffusion loop with single model (Wan2.1 style)."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
config = self._make_tiny_wan21_config()
model = WanModel(config)
@@ -305,7 +305,7 @@ class TestWan21Model:
def test_wan21_vs_wan22_config_differences(self):
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
c21 = WanModelConfig.wan21_t2v_14b()
c22 = WanModelConfig.wan22_t2v_14b()
@@ -333,21 +333,21 @@ class TestPerTokenTimestep:
"""Tests for per-token sinusoidal embedding."""
def test_1d_unchanged(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 500.0])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (3, 256)
def test_2d_per_token(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (2, 3, 256)
def test_consistency(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
pos_1d = mx.array([0.0, 100.0])
emb_1d = sinusoidal_embedding_1d(256, pos_1d)

View File

@@ -15,7 +15,7 @@ from wan_test_helpers import _make_tiny_config
class TestQuantizePredicate:
def test_matches_self_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
@@ -23,7 +23,7 @@ class TestQuantizePredicate:
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_cross_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
@@ -31,14 +31,14 @@ class TestQuantizePredicate:
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_ffn_layers(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for path in [
@@ -49,13 +49,13 @@ class TestQuantizePredicate:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized
@@ -63,7 +63,7 @@ class TestQuantizePredicate:
def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted."""
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
patterns = [
@@ -90,8 +90,8 @@ class TestQuantizePredicate:
class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path."""
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan2.model import WanModel
model = WanModel(config)
nn.quantize(
@@ -116,7 +116,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan2.utils import load_wan_model
loaded = load_wan_model(
model_path,
@@ -136,7 +136,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan2.utils import load_wan_model
loaded = load_wan_model(
model_path,
@@ -151,7 +151,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan2.utils import load_wan_model
loaded = load_wan_model(
model_path,
@@ -164,7 +164,7 @@ class TestQuantizeRoundTrip:
def test_loading_without_quantization_flag(self, tmp_path):
"""Loading a non-quantized model should have standard Linear layers."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -172,7 +172,7 @@ class TestQuantizeRoundTrip:
model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan2.utils import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None)
@@ -187,8 +187,8 @@ class TestQuantizeRoundTrip:
class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan2.model import WanModel
model = WanModel(config)
nn.quantize(
@@ -238,8 +238,8 @@ class TestQuantizedInference:
def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights."""
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
mx.random.seed(42)
@@ -271,8 +271,8 @@ class TestQuantizedInference:
class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.convert import _quantize_saved_model
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -295,8 +295,8 @@ class TestQuantizationConfig:
assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path):
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.convert import _quantize_saved_model
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -316,8 +316,8 @@ class TestQuantizationConfig:
def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files."""
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.convert import _quantize_saved_model
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()

View File

@@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction:
def _get_model_freqs(self, dim=64, num_heads=4):
"""Instantiate a tiny WanModel and return its .freqs tensor."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan2.model import WanModel
config = WanModelConfig()
config.dim = dim
@@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction:
def test_three_call_vs_single_call_differ(self):
"""Three separate rope_params calls must differ from single call."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128 # head_dim for all Wan models
# Reference: three separate calls
@@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction:
This verifies each axis gets its own independent frequency range
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction:
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
d_h_dim = 2 * (d // 6) # 42
@@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction:
axis should be 1.0 (theta^0). A single-call approach would give height
starting at ~0.04 and width at ~0.002 instead of 1.0.
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_match_manual_construction(self):
"""WanModel.freqs should match manually constructed three-call freqs."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16
@@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_14b_dimensions(self):
"""Verify freq dimensions for 14B-scale head_dim=128."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference:
"""Numerically compare MLX and PyTorch frequency tables."""
import torch
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
@@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions.
"""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
d = 128
freqs = mx.concatenate(
@@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs:
def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
from mlx_video.models.wan.rope import (
from mlx_video.models.wan2.rope import (
rope_apply,
rope_params,
rope_precompute_cos_sin,

View File

@@ -13,7 +13,7 @@ import pytest
class TestFlowMatchEulerScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000
@@ -21,7 +21,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas is None
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
@@ -30,7 +30,7 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas.shape == (41,) # 40 steps + terminal
def test_timesteps_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
@@ -40,7 +40,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
def test_sigmas_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0)
@@ -49,7 +49,7 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -58,7 +58,7 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler()
@@ -70,7 +70,7 @@ class TestFlowMatchEulerScheduler:
assert mean2 > mean1, "Higher shift should push sigmas higher"
def test_step_euler(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -95,7 +95,7 @@ class TestFlowMatchEulerScheduler:
)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -108,7 +108,7 @@ class TestFlowMatchEulerScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -121,7 +121,7 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0)
@@ -131,7 +131,7 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -153,26 +153,26 @@ class TestComputeSigmas:
"""Tests for the shared _compute_sigmas helper."""
def test_length(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0)
@@ -185,7 +185,7 @@ class TestComputeSigmas:
sigma_max/sigma_min come from the *unshifted* training schedule, and the
shift is applied only once (single-shift).
"""
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
@@ -200,7 +200,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(sigmas, official, atol=1e-6)
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
@@ -210,7 +210,7 @@ class TestComputeSigmas:
def test_all_schedulers_same_sigmas(self):
"""All three schedulers should produce identical sigma schedules."""
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -229,7 +229,7 @@ class TestComputeSigmas:
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
def test_all_schedulers_same_timesteps(self):
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -255,14 +255,14 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -271,7 +271,7 @@ class TestFlowDPMPP2MScheduler:
assert sched.sigmas.shape == (21,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -284,7 +284,7 @@ class TestFlowDPMPP2MScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -296,7 +296,7 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -309,7 +309,7 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
@@ -324,7 +324,7 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
@@ -348,7 +348,7 @@ class TestFlowDPMPP2MScheduler:
def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -363,7 +363,7 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0)
@@ -373,7 +373,7 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -394,7 +394,7 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000
@@ -402,7 +402,7 @@ class TestFlowUniPCScheduler:
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0)
@@ -411,7 +411,7 @@ class TestFlowUniPCScheduler:
assert sched.sigmas.shape == (31,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -422,7 +422,7 @@ class TestFlowUniPCScheduler:
assert sched._step_index == 1
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
@@ -435,7 +435,7 @@ class TestFlowUniPCScheduler:
assert all(m is None for m in sched._model_outputs)
def test_full_loop_finite(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0)
@@ -448,7 +448,7 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history)."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -462,7 +462,7 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -475,7 +475,7 @@ class TestFlowUniPCScheduler:
assert sched._lower_order_nums >= 2
def test_denoise_to_target(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0)
@@ -490,7 +490,7 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0)
@@ -500,7 +500,7 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0)
@@ -513,7 +513,7 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self):
"""Order 3 should work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0)
@@ -531,7 +531,7 @@ class TestFlowUniPCScheduler:
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(50, shift=5.0)
@@ -597,7 +597,7 @@ class TestSchedulerCoherence:
@staticmethod
def _make_schedulers(steps=10, shift=5.0):
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -780,7 +780,7 @@ class TestSchedulerCoherence:
def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
@@ -800,7 +800,7 @@ class TestSchedulerCoherence:
def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan2.scheduler import FlowDPMPP2MScheduler
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
@@ -902,7 +902,7 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
@@ -947,14 +947,14 @@ class TestUniPCCorrectorDefault:
def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched._use_corrector is True
def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
@@ -978,7 +978,7 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)

View File

@@ -11,7 +11,7 @@ import numpy as np
class TestT5LayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5LayerNorm
from mlx_video.models.wan2.text_encoder import T5LayerNorm
norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64))
@@ -21,7 +21,7 @@ class TestT5LayerNorm:
def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan.text_encoder import T5LayerNorm
from mlx_video.models.wan2.text_encoder import T5LayerNorm
norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0
@@ -35,7 +35,7 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10)
@@ -43,7 +43,7 @@ class TestT5RelativeEmbedding:
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
def test_asymmetric_lengths(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12)
@@ -52,7 +52,7 @@ class TestT5RelativeEmbedding:
def test_symmetry(self):
"""Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6)
@@ -67,7 +67,7 @@ class TestT5RelativeEmbedding:
class TestT5Attention:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
@@ -77,14 +77,14 @@ class TestT5Attention:
def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale")
def test_with_position_bias(self):
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
from mlx_video.models.wan2.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4)
@@ -95,7 +95,7 @@ class TestT5Attention:
assert out.shape == (1, 10, 64)
def test_with_mask(self):
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
@@ -108,7 +108,7 @@ class TestT5Attention:
class TestT5FeedForward:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5FeedForward
from mlx_video.models.wan2.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64))
@@ -118,7 +118,7 @@ class TestT5FeedForward:
def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan.text_encoder import T5FeedForward
from mlx_video.models.wan2.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj")
@@ -131,7 +131,7 @@ class TestT5Encoder:
mx.random.seed(42)
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -150,7 +150,7 @@ class TestT5Encoder:
assert out.shape == (1, 5, 64)
def test_shared_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -167,7 +167,7 @@ class TestT5Encoder:
assert block.pos_embedding is None
def test_per_layer_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -184,7 +184,7 @@ class TestT5Encoder:
assert block.pos_embedding is not None
def test_param_count(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,
@@ -200,7 +200,7 @@ class TestT5Encoder:
assert num_params > 0
def test_without_mask(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100,

View File

@@ -3,7 +3,7 @@
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.video_vae.tiling import (
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
decode_with_tiling,
split_in_spatial,
@@ -75,7 +75,7 @@ class TestWan22TiledDecoding:
def _make_small_wan22_decoder(self):
"""Create a small Wan2.2 decoder for testing."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
# Use very small dimensions for fast testing
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
@@ -139,7 +139,7 @@ class TestWan21TiledDecoding:
def _make_small_wan21_vae(self):
"""Create a small Wan2.1 VAE for testing."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae = WanVAE(z_dim=16)
mx.eval(vae.parameters())
@@ -192,7 +192,7 @@ class TestWan21TemporalScale:
def test_wan21_decoder_temporal_output(self):
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
from mlx_video.models.wan.vae import Decoder3d
from mlx_video.models.wan2.vae import Decoder3d
# Small decoder for fast test
dec = Decoder3d(

View File

@@ -10,7 +10,7 @@ import numpy as np
class TestWanFFN:
def test_output_shape(self):
from mlx_video.models.wan.transformer import WanFFN
from mlx_video.models.wan2.transformer import WanFFN
ffn = WanFFN(64, 256)
x = mx.random.normal((2, 10, 64))
@@ -20,7 +20,7 @@ class TestWanFFN:
def test_gelu_activation(self):
"""FFN should use GELU activation (non-linearity)."""
from mlx_video.models.wan.transformer import WanFFN
from mlx_video.models.wan2.transformer import WanFFN
ffn = WanFFN(32, 128)
x = mx.ones((1, 1, 32)) * 2.0
@@ -40,8 +40,8 @@ class TestWanAttentionBlock:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -68,13 +68,13 @@ class TestWanAttentionBlock:
assert out.shape == (B, L, self.dim)
def test_modulation_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
assert block.modulation.shape == (1, 6, self.dim)
def test_with_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -85,7 +85,7 @@ class TestWanAttentionBlock:
assert block.norm3 is not None
def test_without_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim,
@@ -97,8 +97,8 @@ class TestWanAttentionBlock:
def test_residual_connection(self):
"""Output should differ from zero even with small random init."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
B, L = 1, 8
@@ -129,15 +129,15 @@ class TestFloat32Modulation:
def test_block_modulation_in_float32(self):
"""Modulation param starts random but should be usable as float32."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
assert block.modulation.dtype == mx.float32
def test_block_output_float32_with_bf16_modulation_input(self):
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4)
B, L = 1, 8
@@ -153,7 +153,7 @@ class TestFloat32Modulation:
def test_head_modulation_float32(self):
"""Head modulation should be float32 even with bf16 e input."""
from mlx_video.models.wan.model import Head
from mlx_video.models.wan2.model import Head
head = Head(self.dim, 4, (1, 2, 2))
x = mx.random.normal((1, 8, self.dim))
@@ -164,7 +164,7 @@ class TestFloat32Modulation:
def test_model_time_embedding_float32(self):
"""sinusoidal_embedding_1d output must be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
t = mx.array([500.0])
emb = sinusoidal_embedding_1d(256, t)
@@ -173,7 +173,7 @@ class TestFloat32Modulation:
def test_model_per_token_time_embedding_float32(self):
"""Per-token time embeddings (I2V) should also be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
emb = sinusoidal_embedding_1d(256, t)

View File

@@ -12,7 +12,7 @@ import numpy as np
class TestCausalConv3d:
def test_output_shape_stride1(self):
from mlx_video.models.wan.vae import CausalConv3d
from mlx_video.models.wan2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights
@@ -28,7 +28,7 @@ class TestCausalConv3d:
assert out.shape[4] == 8 # W preserved
def test_output_shape_kernel1(self):
from mlx_video.models.wan.vae import CausalConv3d
from mlx_video.models.wan2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
@@ -39,7 +39,7 @@ class TestCausalConv3d:
def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan.vae import CausalConv3d
from mlx_video.models.wan2.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -56,7 +56,7 @@ class TestCausalConv3d:
class TestResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -65,7 +65,7 @@ class TestResidualBlock:
assert out.shape == (1, 8, 2, 4, 4)
def test_different_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -74,13 +74,13 @@ class TestResidualBlock:
assert out.shape == (1, 16, 2, 4, 4)
def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -88,7 +88,7 @@ class TestResidualBlock:
class TestAttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan.vae import AttentionBlock
from mlx_video.models.wan2.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -97,7 +97,7 @@ class TestAttentionBlock:
assert out.shape == (1, 8, 2, 4, 4)
def test_residual_connection(self):
from mlx_video.models.wan.vae import AttentionBlock
from mlx_video.models.wan2.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3))
@@ -109,7 +109,7 @@ class TestAttentionBlock:
class TestWanVAE:
def test_instantiation(self):
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae = WanVAE(z_dim=16)
assert vae.z_dim == 16
@@ -117,7 +117,7 @@ class TestWanVAE:
assert vae.std.shape == (16,)
def test_normalization_stats(self):
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16
@@ -133,7 +133,7 @@ class TestVAE22CausalConv3d:
"""Tests for vae22.CausalConv3d (channels-last)."""
def test_output_shape_k3(self):
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
@@ -142,7 +142,7 @@ class TestVAE22CausalConv3d:
assert out.shape == (1, 4, 8, 8, 16)
def test_output_shape_k1(self):
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -152,7 +152,7 @@ class TestVAE22CausalConv3d:
def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0."""
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -178,7 +178,7 @@ class TestVAE22CausalConv3d:
def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4))
@@ -191,7 +191,7 @@ class TestRMSNorm:
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16))
@@ -201,7 +201,7 @@ class TestRMSNorm:
def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
dim = 32
norm = RMS_norm(dim)
@@ -215,7 +215,7 @@ class TestRMSNorm:
def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8))
@@ -226,7 +226,7 @@ class TestRMSNorm:
def test_gamma_effect(self):
"""Non-unit gamma should scale output."""
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
@@ -241,7 +241,7 @@ class TestDupUp3D:
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -250,7 +250,7 @@ class TestDupUp3D:
assert out.shape == (1, 3, 8, 8, 4)
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16))
@@ -259,7 +259,7 @@ class TestDupUp3D:
assert out.shape == (1, 6, 8, 8, 8)
def test_first_chunk_trims(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -271,7 +271,7 @@ class TestDupUp3D:
assert out_trimmed.shape[1] == 5
def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -286,7 +286,7 @@ class TestVAE22Resample:
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
def test_upsample2d_shape(self):
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -296,7 +296,7 @@ class TestVAE22Resample:
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
def test_upsample3d_shape(self):
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -306,7 +306,7 @@ class TestVAE22Resample:
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
def test_upsample3d_first_chunk(self):
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -318,7 +318,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -336,7 +336,7 @@ class TestVAE22Resample:
We verify this by checking that the first output frame depends only on
the first input frame (not on time_conv parameters).
"""
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
C = 8
r = Resample(C, "upsample3d")
@@ -373,7 +373,7 @@ class TestVAE22ResidualBlock:
"""Tests for vae22.ResidualBlock."""
def test_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -382,7 +382,7 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 8)
def test_different_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -391,13 +391,13 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -408,7 +408,7 @@ class TestResidualBlockLayers:
def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan.vae22 import ResidualBlockLayers
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8)
params = dict(block.parameters())
@@ -417,7 +417,7 @@ class TestResidualBlockLayers:
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
def test_has_expected_layers(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm
@@ -426,7 +426,7 @@ class TestResidualBlockLayers:
assert hasattr(block, "layer_6") # second CausalConv3d
def test_forward_shape(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -439,7 +439,7 @@ class TestVAE22AttentionBlock:
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import AttentionBlock
from mlx_video.models.wan2.vae22 import AttentionBlock
block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
@@ -450,7 +450,7 @@ class TestVAE22AttentionBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_residual_connection(self):
from mlx_video.models.wan.vae22 import AttentionBlock
from mlx_video.models.wan2.vae22 import AttentionBlock
block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
@@ -466,7 +466,7 @@ class TestHead22:
"""Tests for vae22.Head22 output head."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Head22
from mlx_video.models.wan2.vae22 import Head22
head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16))
@@ -476,7 +476,7 @@ class TestHead22:
def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix."""
from mlx_video.models.wan.vae22 import Head22
from mlx_video.models.wan2.vae22 import Head22
head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm
@@ -490,7 +490,7 @@ class TestUnpatchify:
"""Tests for vae22._unpatchify."""
def test_basic_shape(self):
from mlx_video.models.wan.vae22 import _unpatchify
from mlx_video.models.wan2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2)
@@ -498,7 +498,7 @@ class TestUnpatchify:
assert out.shape == (1, 2, 8, 8, 3)
def test_patch_size_1_noop(self):
from mlx_video.models.wan.vae22 import _unpatchify
from mlx_video.models.wan2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1)
@@ -507,7 +507,7 @@ class TestUnpatchify:
def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan.vae22 import _unpatchify
from mlx_video.models.wan2.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2)
@@ -521,7 +521,7 @@ class TestDenormalizeLatents:
"""Tests for vae22.denormalize_latents."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import denormalize_latents
from mlx_video.models.wan2.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z)
@@ -529,7 +529,7 @@ class TestDenormalizeLatents:
assert out.shape == (1, 2, 4, 4, 48)
def test_custom_mean_std(self):
from mlx_video.models.wan.vae22 import denormalize_latents
from mlx_video.models.wan2.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0])
@@ -542,7 +542,7 @@ class TestDenormalizeLatents:
)
def test_uses_default_constants(self):
from mlx_video.models.wan.vae22 import (
from mlx_video.models.wan2.vae22 import (
VAE22_MEAN,
denormalize_latents,
)
@@ -563,14 +563,14 @@ class TestVAE22NormConstants:
"""Tests for VAE22_MEAN and VAE22_STD constants."""
def test_dimensions(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,)
def test_std_positive(self):
from mlx_video.models.wan.vae22 import VAE22_STD
from mlx_video.models.wan2.vae22 import VAE22_STD
mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all()
@@ -581,7 +581,7 @@ class TestWan22VAEDecoder:
def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
@@ -597,7 +597,7 @@ class TestWan22VAEDecoder:
assert np.array(out).max() <= 1.0
def test_output_clipped(self):
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
@@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights:
"""Tests for vae22.sanitize_wan22_vae_weights."""
def test_skip_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.layer.weight": mx.zeros((4,)),
@@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.conv1.bias" in out
def test_sequential_index_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
@@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.head.layer_2.bias" in out
def test_resample_conv_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
@@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
def test_attention_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
@@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.middle.1.proj_bias" in out
def test_conv3d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3))
@@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights:
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
def test_conv2d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3))
@@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights:
assert out[key].shape == (8, 3, 3, 8)
def test_gamma_squeeze(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1))
@@ -698,7 +698,7 @@ class TestUpResidualBlock:
"""Tests for vae22.Up_ResidualBlock."""
def test_no_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
@@ -710,7 +710,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 4, 4, 8)
def test_spatial_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
@@ -722,7 +722,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 8, 8, 4)
def test_spatial_temporal_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
@@ -738,7 +738,7 @@ class TestPatchify:
"""Tests for _patchify and _unpatchify round-trip."""
def test_roundtrip(self):
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 1, 64, 64, 3))
p = _patchify(x, patch_size=2)
@@ -748,7 +748,7 @@ class TestPatchify:
assert float(mx.abs(x - back).max()) == 0.0
def test_identity_patch_1(self):
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 2, 8, 8, 3))
assert _patchify(x, patch_size=1).shape == x.shape
@@ -759,7 +759,7 @@ class TestAvgDown3D:
"""Tests for AvgDown3D downsampling."""
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import AvgDown3D
from mlx_video.models.wan2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
x = mx.random.normal((1, 2, 8, 8, 8))
@@ -768,7 +768,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16)
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import AvgDown3D
from mlx_video.models.wan2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
x = mx.random.normal((1, 4, 8, 8, 8))
@@ -777,7 +777,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16)
def test_single_frame(self):
from mlx_video.models.wan.vae22 import AvgDown3D
from mlx_video.models.wan2.vae22 import AvgDown3D
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 1, 8, 8, 8))
@@ -791,7 +791,7 @@ class TestDownResidualBlock:
"""Tests for Down_ResidualBlock."""
def test_no_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
@@ -802,7 +802,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 8, 8, 8)
def test_spatial_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
@@ -813,7 +813,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_spatial_temporal_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
@@ -828,7 +828,7 @@ class TestEncoder3d:
"""Tests for Encoder3d."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8)
x = mx.random.normal((1, 1, 16, 16, 12))
@@ -839,7 +839,7 @@ class TestEncoder3d:
assert out.shape == (1, 1, 2, 2, 8)
def test_multi_frame(self):
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -854,7 +854,7 @@ class TestWan22VAEEncoder:
"""Tests for Wan22VAEEncoder wrapper."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
@@ -865,7 +865,7 @@ class TestWan22VAEEncoder:
assert z.shape == (1, 1, 2, 2, 48)
def test_full_dim(self):
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=160)
img = mx.random.normal((1, 1, 64, 64, 3))
@@ -880,7 +880,7 @@ class TestNormalizeLatents:
"""Tests for normalize/denormalize latent roundtrip."""
def test_roundtrip(self):
from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents
from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
z_norm = normalize_latents(z)
@@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder:
def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples
@@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder:
def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3))
@@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc_correct = Encoder3d(
dim=16, z_dim=8, temperal_downsample=(False, True, True)
@@ -963,7 +963,7 @@ class TestVAE21RoundTrip:
def test_encode_decode_shape_and_values(self):
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
from mlx_video.models.wan2.vae import Decoder3d, Encoder3d
z_dim = 4
dim = 8
@@ -995,7 +995,7 @@ class TestVAE22RoundTrip:
def test_encode_decode_shape_and_values(self):
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
from mlx_video.models.wan.vae22 import (
from mlx_video.models.wan2.vae22 import (
Wan22VAEDecoder,
Wan22VAEEncoder,
denormalize_latents,

View File

@@ -3,7 +3,7 @@
def _make_tiny_config():
"""Create a tiny WanModelConfig for testing."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
# Override to tiny values