Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -12,7 +12,7 @@ import numpy as np
class TestCausalConv3d:
def test_output_shape_stride1(self):
from mlx_video.models.wan.vae import CausalConv3d
from mlx_video.models.wan2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights
@@ -28,7 +28,7 @@ class TestCausalConv3d:
assert out.shape[4] == 8 # W preserved
def test_output_shape_kernel1(self):
from mlx_video.models.wan.vae import CausalConv3d
from mlx_video.models.wan2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
@@ -39,7 +39,7 @@ class TestCausalConv3d:
def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan.vae import CausalConv3d
from mlx_video.models.wan2.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -56,7 +56,7 @@ class TestCausalConv3d:
class TestResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -65,7 +65,7 @@ class TestResidualBlock:
assert out.shape == (1, 8, 2, 4, 4)
def test_different_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -74,13 +74,13 @@ class TestResidualBlock:
assert out.shape == (1, 16, 2, 4, 4)
def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan.vae import ResidualBlock
from mlx_video.models.wan2.vae import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -88,7 +88,7 @@ class TestResidualBlock:
class TestAttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan.vae import AttentionBlock
from mlx_video.models.wan2.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -97,7 +97,7 @@ class TestAttentionBlock:
assert out.shape == (1, 8, 2, 4, 4)
def test_residual_connection(self):
from mlx_video.models.wan.vae import AttentionBlock
from mlx_video.models.wan2.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3))
@@ -109,7 +109,7 @@ class TestAttentionBlock:
class TestWanVAE:
def test_instantiation(self):
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae = WanVAE(z_dim=16)
assert vae.z_dim == 16
@@ -117,7 +117,7 @@ class TestWanVAE:
assert vae.std.shape == (16,)
def test_normalization_stats(self):
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16
@@ -133,7 +133,7 @@ class TestVAE22CausalConv3d:
"""Tests for vae22.CausalConv3d (channels-last)."""
def test_output_shape_k3(self):
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
@@ -142,7 +142,7 @@ class TestVAE22CausalConv3d:
assert out.shape == (1, 4, 8, 8, 16)
def test_output_shape_k1(self):
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -152,7 +152,7 @@ class TestVAE22CausalConv3d:
def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0."""
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -178,7 +178,7 @@ class TestVAE22CausalConv3d:
def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan.vae22 import CausalConv3d
from mlx_video.models.wan2.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4))
@@ -191,7 +191,7 @@ class TestRMSNorm:
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16))
@@ -201,7 +201,7 @@ class TestRMSNorm:
def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
dim = 32
norm = RMS_norm(dim)
@@ -215,7 +215,7 @@ class TestRMSNorm:
def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8))
@@ -226,7 +226,7 @@ class TestRMSNorm:
def test_gamma_effect(self):
"""Non-unit gamma should scale output."""
from mlx_video.models.wan.vae22 import RMS_norm
from mlx_video.models.wan2.vae22 import RMS_norm
norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
@@ -241,7 +241,7 @@ class TestDupUp3D:
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -250,7 +250,7 @@ class TestDupUp3D:
assert out.shape == (1, 3, 8, 8, 4)
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16))
@@ -259,7 +259,7 @@ class TestDupUp3D:
assert out.shape == (1, 6, 8, 8, 8)
def test_first_chunk_trims(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -271,7 +271,7 @@ class TestDupUp3D:
assert out_trimmed.shape[1] == 5
def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan.vae22 import DupUp3D
from mlx_video.models.wan2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -286,7 +286,7 @@ class TestVAE22Resample:
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
def test_upsample2d_shape(self):
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -296,7 +296,7 @@ class TestVAE22Resample:
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
def test_upsample3d_shape(self):
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -306,7 +306,7 @@ class TestVAE22Resample:
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
def test_upsample3d_first_chunk(self):
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -318,7 +318,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -336,7 +336,7 @@ class TestVAE22Resample:
We verify this by checking that the first output frame depends only on
the first input frame (not on time_conv parameters).
"""
from mlx_video.models.wan.vae22 import Resample
from mlx_video.models.wan2.vae22 import Resample
C = 8
r = Resample(C, "upsample3d")
@@ -373,7 +373,7 @@ class TestVAE22ResidualBlock:
"""Tests for vae22.ResidualBlock."""
def test_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -382,7 +382,7 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 8)
def test_different_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -391,13 +391,13 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
from mlx_video.models.wan2.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -408,7 +408,7 @@ class TestResidualBlockLayers:
def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan.vae22 import ResidualBlockLayers
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8)
params = dict(block.parameters())
@@ -417,7 +417,7 @@ class TestResidualBlockLayers:
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
def test_has_expected_layers(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm
@@ -426,7 +426,7 @@ class TestResidualBlockLayers:
assert hasattr(block, "layer_6") # second CausalConv3d
def test_forward_shape(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -439,7 +439,7 @@ class TestVAE22AttentionBlock:
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import AttentionBlock
from mlx_video.models.wan2.vae22 import AttentionBlock
block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
@@ -450,7 +450,7 @@ class TestVAE22AttentionBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_residual_connection(self):
from mlx_video.models.wan.vae22 import AttentionBlock
from mlx_video.models.wan2.vae22 import AttentionBlock
block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
@@ -466,7 +466,7 @@ class TestHead22:
"""Tests for vae22.Head22 output head."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Head22
from mlx_video.models.wan2.vae22 import Head22
head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16))
@@ -476,7 +476,7 @@ class TestHead22:
def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix."""
from mlx_video.models.wan.vae22 import Head22
from mlx_video.models.wan2.vae22 import Head22
head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm
@@ -490,7 +490,7 @@ class TestUnpatchify:
"""Tests for vae22._unpatchify."""
def test_basic_shape(self):
from mlx_video.models.wan.vae22 import _unpatchify
from mlx_video.models.wan2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2)
@@ -498,7 +498,7 @@ class TestUnpatchify:
assert out.shape == (1, 2, 8, 8, 3)
def test_patch_size_1_noop(self):
from mlx_video.models.wan.vae22 import _unpatchify
from mlx_video.models.wan2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1)
@@ -507,7 +507,7 @@ class TestUnpatchify:
def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan.vae22 import _unpatchify
from mlx_video.models.wan2.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2)
@@ -521,7 +521,7 @@ class TestDenormalizeLatents:
"""Tests for vae22.denormalize_latents."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import denormalize_latents
from mlx_video.models.wan2.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z)
@@ -529,7 +529,7 @@ class TestDenormalizeLatents:
assert out.shape == (1, 2, 4, 4, 48)
def test_custom_mean_std(self):
from mlx_video.models.wan.vae22 import denormalize_latents
from mlx_video.models.wan2.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0])
@@ -542,7 +542,7 @@ class TestDenormalizeLatents:
)
def test_uses_default_constants(self):
from mlx_video.models.wan.vae22 import (
from mlx_video.models.wan2.vae22 import (
VAE22_MEAN,
denormalize_latents,
)
@@ -563,14 +563,14 @@ class TestVAE22NormConstants:
"""Tests for VAE22_MEAN and VAE22_STD constants."""
def test_dimensions(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,)
def test_std_positive(self):
from mlx_video.models.wan.vae22 import VAE22_STD
from mlx_video.models.wan2.vae22 import VAE22_STD
mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all()
@@ -581,7 +581,7 @@ class TestWan22VAEDecoder:
def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
@@ -597,7 +597,7 @@ class TestWan22VAEDecoder:
assert np.array(out).max() <= 1.0
def test_output_clipped(self):
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
@@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights:
"""Tests for vae22.sanitize_wan22_vae_weights."""
def test_skip_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.layer.weight": mx.zeros((4,)),
@@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.conv1.bias" in out
def test_sequential_index_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
@@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.head.layer_2.bias" in out
def test_resample_conv_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
@@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
def test_attention_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
@@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.middle.1.proj_bias" in out
def test_conv3d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3))
@@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights:
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
def test_conv2d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3))
@@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights:
assert out[key].shape == (8, 3, 3, 8)
def test_gamma_squeeze(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1))
@@ -698,7 +698,7 @@ class TestUpResidualBlock:
"""Tests for vae22.Up_ResidualBlock."""
def test_no_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
@@ -710,7 +710,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 4, 4, 8)
def test_spatial_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
@@ -722,7 +722,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 8, 8, 4)
def test_spatial_temporal_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
@@ -738,7 +738,7 @@ class TestPatchify:
"""Tests for _patchify and _unpatchify round-trip."""
def test_roundtrip(self):
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 1, 64, 64, 3))
p = _patchify(x, patch_size=2)
@@ -748,7 +748,7 @@ class TestPatchify:
assert float(mx.abs(x - back).max()) == 0.0
def test_identity_patch_1(self):
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 2, 8, 8, 3))
assert _patchify(x, patch_size=1).shape == x.shape
@@ -759,7 +759,7 @@ class TestAvgDown3D:
"""Tests for AvgDown3D downsampling."""
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import AvgDown3D
from mlx_video.models.wan2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
x = mx.random.normal((1, 2, 8, 8, 8))
@@ -768,7 +768,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16)
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import AvgDown3D
from mlx_video.models.wan2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
x = mx.random.normal((1, 4, 8, 8, 8))
@@ -777,7 +777,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16)
def test_single_frame(self):
from mlx_video.models.wan.vae22 import AvgDown3D
from mlx_video.models.wan2.vae22 import AvgDown3D
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 1, 8, 8, 8))
@@ -791,7 +791,7 @@ class TestDownResidualBlock:
"""Tests for Down_ResidualBlock."""
def test_no_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
@@ -802,7 +802,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 8, 8, 8)
def test_spatial_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
@@ -813,7 +813,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_spatial_temporal_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
@@ -828,7 +828,7 @@ class TestEncoder3d:
"""Tests for Encoder3d."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8)
x = mx.random.normal((1, 1, 16, 16, 12))
@@ -839,7 +839,7 @@ class TestEncoder3d:
assert out.shape == (1, 1, 2, 2, 8)
def test_multi_frame(self):
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -854,7 +854,7 @@ class TestWan22VAEEncoder:
"""Tests for Wan22VAEEncoder wrapper."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
@@ -865,7 +865,7 @@ class TestWan22VAEEncoder:
assert z.shape == (1, 1, 2, 2, 48)
def test_full_dim(self):
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=160)
img = mx.random.normal((1, 1, 64, 64, 3))
@@ -880,7 +880,7 @@ class TestNormalizeLatents:
"""Tests for normalize/denormalize latent roundtrip."""
def test_roundtrip(self):
from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents
from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
z_norm = normalize_latents(z)
@@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder:
def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples
@@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder:
def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3))
@@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan.vae22 import Encoder3d
from mlx_video.models.wan2.vae22 import Encoder3d
enc_correct = Encoder3d(
dim=16, z_dim=8, temperal_downsample=(False, True, True)
@@ -963,7 +963,7 @@ class TestVAE21RoundTrip:
def test_encode_decode_shape_and_values(self):
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
from mlx_video.models.wan2.vae import Decoder3d, Encoder3d
z_dim = 4
dim = 8
@@ -995,7 +995,7 @@ class TestVAE22RoundTrip:
def test_encode_decode_shape_and_values(self):
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
from mlx_video.models.wan.vae22 import (
from mlx_video.models.wan2.vae22 import (
Wan22VAEDecoder,
Wan22VAEEncoder,
denormalize_latents,