Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. The test imports shown in this diff change from `mlx_video.models.wan2` to `mlx_video.models.wan_2`. This cleanup improves maintainability and keeps the focus on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -12,7 +12,7 @@ import numpy as np
class TestCausalConv3d:
def test_output_shape_stride1(self):
from mlx_video.models.wan2.vae import CausalConv3d
from mlx_video.models.wan_2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights
@@ -28,7 +28,7 @@ class TestCausalConv3d:
assert out.shape[4] == 8 # W preserved
def test_output_shape_kernel1(self):
from mlx_video.models.wan2.vae import CausalConv3d
from mlx_video.models.wan_2.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
@@ -39,7 +39,7 @@ class TestCausalConv3d:
def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan2.vae import CausalConv3d
from mlx_video.models.wan_2.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -56,7 +56,7 @@ class TestCausalConv3d:
class TestResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan2.vae import ResidualBlock
from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -65,7 +65,7 @@ class TestResidualBlock:
assert out.shape == (1, 8, 2, 4, 4)
def test_different_dim(self):
from mlx_video.models.wan2.vae import ResidualBlock
from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -74,13 +74,13 @@ class TestResidualBlock:
assert out.shape == (1, 16, 2, 4, 4)
def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan2.vae import ResidualBlock
from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan2.vae import ResidualBlock
from mlx_video.models.wan_2.vae import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -88,7 +88,7 @@ class TestResidualBlock:
class TestAttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan2.vae import AttentionBlock
from mlx_video.models.wan_2.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4))
@@ -97,7 +97,7 @@ class TestAttentionBlock:
assert out.shape == (1, 8, 2, 4, 4)
def test_residual_connection(self):
from mlx_video.models.wan2.vae import AttentionBlock
from mlx_video.models.wan_2.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3))
@@ -109,7 +109,7 @@ class TestAttentionBlock:
class TestWanVAE:
def test_instantiation(self):
from mlx_video.models.wan2.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16)
assert vae.z_dim == 16
@@ -117,7 +117,7 @@ class TestWanVAE:
assert vae.std.shape == (16,)
def test_normalization_stats(self):
from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD
from mlx_video.models.wan_2.vae import VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16
@@ -133,7 +133,7 @@ class TestVAE22CausalConv3d:
"""Tests for vae22.CausalConv3d (channels-last)."""
def test_output_shape_k3(self):
from mlx_video.models.wan2.vae22 import CausalConv3d
from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
@@ -142,7 +142,7 @@ class TestVAE22CausalConv3d:
assert out.shape == (1, 4, 8, 8, 16)
def test_output_shape_k1(self):
from mlx_video.models.wan2.vae22 import CausalConv3d
from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -152,7 +152,7 @@ class TestVAE22CausalConv3d:
def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0."""
from mlx_video.models.wan2.vae22 import CausalConv3d
from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
@@ -178,7 +178,7 @@ class TestVAE22CausalConv3d:
def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan2.vae22 import CausalConv3d
from mlx_video.models.wan_2.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4))
@@ -191,7 +191,7 @@ class TestRMSNorm:
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
def test_output_shape(self):
from mlx_video.models.wan2.vae22 import RMS_norm
from mlx_video.models.wan_2.vae22 import RMS_norm
norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16))
@@ -201,7 +201,7 @@ class TestRMSNorm:
def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan2.vae22 import RMS_norm
from mlx_video.models.wan_2.vae22 import RMS_norm
dim = 32
norm = RMS_norm(dim)
@@ -215,7 +215,7 @@ class TestRMSNorm:
def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan2.vae22 import RMS_norm
from mlx_video.models.wan_2.vae22 import RMS_norm
norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8))
@@ -226,7 +226,7 @@ class TestRMSNorm:
def test_gamma_effect(self):
"""Non-unit gamma should scale output."""
from mlx_video.models.wan2.vae22 import RMS_norm
from mlx_video.models.wan_2.vae22 import RMS_norm
norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
@@ -241,7 +241,7 @@ class TestDupUp3D:
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
def test_spatial_only(self):
from mlx_video.models.wan2.vae22 import DupUp3D
from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -250,7 +250,7 @@ class TestDupUp3D:
assert out.shape == (1, 3, 8, 8, 4)
def test_temporal_and_spatial(self):
from mlx_video.models.wan2.vae22 import DupUp3D
from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16))
@@ -259,7 +259,7 @@ class TestDupUp3D:
assert out.shape == (1, 6, 8, 8, 8)
def test_first_chunk_trims(self):
from mlx_video.models.wan2.vae22 import DupUp3D
from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -271,7 +271,7 @@ class TestDupUp3D:
assert out_trimmed.shape[1] == 5
def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan2.vae22 import DupUp3D
from mlx_video.models.wan_2.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
@@ -286,7 +286,7 @@ class TestVAE22Resample:
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
def test_upsample2d_shape(self):
from mlx_video.models.wan2.vae22 import Resample
from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -296,7 +296,7 @@ class TestVAE22Resample:
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
def test_upsample3d_shape(self):
from mlx_video.models.wan2.vae22 import Resample
from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -306,7 +306,7 @@ class TestVAE22Resample:
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
def test_upsample3d_first_chunk(self):
from mlx_video.models.wan2.vae22 import Resample
from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -318,7 +318,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan2.vae22 import Resample
from mlx_video.models.wan_2.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
@@ -336,7 +336,7 @@ class TestVAE22Resample:
We verify this by checking that the first output frame depends only on
the first input frame (not on time_conv parameters).
"""
from mlx_video.models.wan2.vae22 import Resample
from mlx_video.models.wan_2.vae22 import Resample
C = 8
r = Resample(C, "upsample3d")
@@ -373,7 +373,7 @@ class TestVAE22ResidualBlock:
"""Tests for vae22.ResidualBlock."""
def test_same_dim(self):
from mlx_video.models.wan2.vae22 import ResidualBlock
from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -382,7 +382,7 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 8)
def test_different_dim(self):
from mlx_video.models.wan2.vae22 import ResidualBlock
from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -391,13 +391,13 @@ class TestVAE22ResidualBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan2.vae22 import ResidualBlock
from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_same_dim(self):
from mlx_video.models.wan2.vae22 import ResidualBlock
from mlx_video.models.wan_2.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -408,7 +408,7 @@ class TestResidualBlockLayers:
def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8)
params = dict(block.parameters())
@@ -417,7 +417,7 @@ class TestResidualBlockLayers:
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
def test_has_expected_layers(self):
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm
@@ -426,7 +426,7 @@ class TestResidualBlockLayers:
assert hasattr(block, "layer_6") # second CausalConv3d
def test_forward_shape(self):
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -439,7 +439,7 @@ class TestVAE22AttentionBlock:
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
def test_output_shape(self):
from mlx_video.models.wan2.vae22 import AttentionBlock
from mlx_video.models.wan_2.vae22 import AttentionBlock
block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
@@ -450,7 +450,7 @@ class TestVAE22AttentionBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_residual_connection(self):
from mlx_video.models.wan2.vae22 import AttentionBlock
from mlx_video.models.wan_2.vae22 import AttentionBlock
block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
@@ -466,7 +466,7 @@ class TestHead22:
"""Tests for vae22.Head22 output head."""
def test_output_shape(self):
from mlx_video.models.wan2.vae22 import Head22
from mlx_video.models.wan_2.vae22 import Head22
head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16))
@@ -476,7 +476,7 @@ class TestHead22:
def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix."""
from mlx_video.models.wan2.vae22 import Head22
from mlx_video.models.wan_2.vae22 import Head22
head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm
@@ -490,7 +490,7 @@ class TestUnpatchify:
"""Tests for vae22._unpatchify."""
def test_basic_shape(self):
from mlx_video.models.wan2.vae22 import _unpatchify
from mlx_video.models.wan_2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2)
@@ -498,7 +498,7 @@ class TestUnpatchify:
assert out.shape == (1, 2, 8, 8, 3)
def test_patch_size_1_noop(self):
from mlx_video.models.wan2.vae22 import _unpatchify
from mlx_video.models.wan_2.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1)
@@ -507,7 +507,7 @@ class TestUnpatchify:
def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan2.vae22 import _unpatchify
from mlx_video.models.wan_2.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2)
@@ -521,7 +521,7 @@ class TestDenormalizeLatents:
"""Tests for vae22.denormalize_latents."""
def test_output_shape(self):
from mlx_video.models.wan2.vae22 import denormalize_latents
from mlx_video.models.wan_2.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z)
@@ -529,7 +529,7 @@ class TestDenormalizeLatents:
assert out.shape == (1, 2, 4, 4, 48)
def test_custom_mean_std(self):
from mlx_video.models.wan2.vae22 import denormalize_latents
from mlx_video.models.wan_2.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0])
@@ -542,7 +542,7 @@ class TestDenormalizeLatents:
)
def test_uses_default_constants(self):
from mlx_video.models.wan2.vae22 import (
from mlx_video.models.wan_2.vae22 import (
VAE22_MEAN,
denormalize_latents,
)
@@ -563,14 +563,14 @@ class TestVAE22NormConstants:
"""Tests for VAE22_MEAN and VAE22_STD constants."""
def test_dimensions(self):
from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD
from mlx_video.models.wan_2.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,)
def test_std_positive(self):
from mlx_video.models.wan2.vae22 import VAE22_STD
from mlx_video.models.wan_2.vae22 import VAE22_STD
mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all()
@@ -581,7 +581,7 @@ class TestWan22VAEDecoder:
def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
@@ -597,7 +597,7 @@ class TestWan22VAEDecoder:
assert np.array(out).max() <= 1.0
def test_output_clipped(self):
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
@@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights:
"""Tests for vae22.sanitize_wan22_vae_weights."""
def test_skip_encoder(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.layer.weight": mx.zeros((4,)),
@@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.conv1.bias" in out
def test_sequential_index_remapping(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
@@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.head.layer_2.bias" in out
def test_resample_conv_remapping(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
@@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
def test_attention_remapping(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
@@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights:
assert "decoder.middle.1.proj_bias" in out
def test_conv3d_transpose(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3))
@@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights:
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
def test_conv2d_transpose(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3))
@@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights:
assert out[key].shape == (8, 3, 3, 8)
def test_gamma_squeeze(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1))
@@ -698,7 +698,7 @@ class TestUpResidualBlock:
"""Tests for vae22.Up_ResidualBlock."""
def test_no_upsample(self):
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
@@ -710,7 +710,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 4, 4, 8)
def test_spatial_upsample(self):
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
@@ -722,7 +722,7 @@ class TestUpResidualBlock:
assert out.shape == (1, 2, 8, 8, 4)
def test_spatial_temporal_upsample(self):
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
@@ -738,7 +738,7 @@ class TestPatchify:
"""Tests for _patchify and _unpatchify round-trip."""
def test_roundtrip(self):
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 1, 64, 64, 3))
p = _patchify(x, patch_size=2)
@@ -748,7 +748,7 @@ class TestPatchify:
assert float(mx.abs(x - back).max()) == 0.0
def test_identity_patch_1(self):
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 2, 8, 8, 3))
assert _patchify(x, patch_size=1).shape == x.shape
@@ -759,7 +759,7 @@ class TestAvgDown3D:
"""Tests for AvgDown3D downsampling."""
def test_spatial_only(self):
from mlx_video.models.wan2.vae22 import AvgDown3D
from mlx_video.models.wan_2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
x = mx.random.normal((1, 2, 8, 8, 8))
@@ -768,7 +768,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16)
def test_temporal_and_spatial(self):
from mlx_video.models.wan2.vae22 import AvgDown3D
from mlx_video.models.wan_2.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
x = mx.random.normal((1, 4, 8, 8, 8))
@@ -777,7 +777,7 @@ class TestAvgDown3D:
assert out.shape == (1, 2, 4, 4, 16)
def test_single_frame(self):
from mlx_video.models.wan2.vae22 import AvgDown3D
from mlx_video.models.wan_2.vae22 import AvgDown3D
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 1, 8, 8, 8))
@@ -791,7 +791,7 @@ class TestDownResidualBlock:
"""Tests for Down_ResidualBlock."""
def test_no_downsample(self):
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
@@ -802,7 +802,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 8, 8, 8)
def test_spatial_downsample(self):
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
@@ -813,7 +813,7 @@ class TestDownResidualBlock:
assert out.shape == (1, 2, 4, 4, 16)
def test_spatial_temporal_downsample(self):
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
@@ -828,7 +828,7 @@ class TestEncoder3d:
"""Tests for Encoder3d."""
def test_output_shape(self):
from mlx_video.models.wan2.vae22 import Encoder3d
from mlx_video.models.wan_2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8)
x = mx.random.normal((1, 1, 16, 16, 12))
@@ -839,7 +839,7 @@ class TestEncoder3d:
assert out.shape == (1, 1, 2, 2, 8)
def test_multi_frame(self):
from mlx_video.models.wan2.vae22 import Encoder3d
from mlx_video.models.wan_2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -854,7 +854,7 @@ class TestWan22VAEEncoder:
"""Tests for Wan22VAEEncoder wrapper."""
def test_output_shape(self):
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
@@ -865,7 +865,7 @@ class TestWan22VAEEncoder:
assert z.shape == (1, 1, 2, 2, 48)
def test_full_dim(self):
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=160)
img = mx.random.normal((1, 1, 64, 64, 3))
@@ -880,7 +880,7 @@ class TestNormalizeLatents:
"""Tests for normalize/denormalize latent roundtrip."""
def test_roundtrip(self):
from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents
from mlx_video.models.wan_2.vae22 import denormalize_latents, normalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
z_norm = normalize_latents(z)
@@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder:
def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan2.vae22 import Encoder3d
from mlx_video.models.wan_2.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder
from mlx_video.models.wan_2.vae22 import Resample, Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples
@@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder:
def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3))
@@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder:
def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan2.vae22 import Encoder3d
from mlx_video.models.wan_2.vae22 import Encoder3d
enc_correct = Encoder3d(
dim=16, z_dim=8, temperal_downsample=(False, True, True)
@@ -963,7 +963,7 @@ class TestVAE21RoundTrip:
def test_encode_decode_shape_and_values(self):
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
from mlx_video.models.wan2.vae import Decoder3d, Encoder3d
from mlx_video.models.wan_2.vae import Decoder3d, Encoder3d
z_dim = 4
dim = 8
@@ -995,7 +995,7 @@ class TestVAE22RoundTrip:
def test_encode_decode_shape_and_values(self):
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
from mlx_video.models.wan2.vae22 import (
from mlx_video.models.wan_2.vae22 import (
Wan22VAEDecoder,
Wan22VAEEncoder,
denormalize_latents,