# Repository file-view header for mlx-video/tests/test_wan_vae.py
# (952 lines, 35 KiB, Python).
# The viewer warned that this file contains Unicode characters that might be
# confused with other characters (e.g. "→", "×", "÷" used in comments and
# docstrings). If this is intentional, the warning can safely be ignored.
"""Tests for Wan VAE 2.1 and 2.2 components."""
import math
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# VAE 2.1 Tests
# ---------------------------------------------------------------------------
class TestCausalConv3d:
    """Tests for vae.CausalConv3d (Wan 2.1, channels-first [B, C, T, H, W])."""

    def test_output_shape_stride1(self):
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
        # Initialize weights
        conv.weight = mx.random.normal(conv.weight.shape) * 0.02
        x = mx.random.normal((1, 4, 3, 8, 8))  # [B, C, T, H, W]
        out = conv(x)
        mx.eval(out)
        # With causal padding and padding=1 on spatial, dims should be preserved
        assert out.shape[0] == 1
        assert out.shape[1] == 8  # out_channels
        assert out.shape[2] == 3  # T preserved
        assert out.shape[3] == 8  # H preserved
        assert out.shape[4] == 8  # W preserved

    def test_output_shape_kernel1(self):
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.02
        x = mx.random.normal((1, 4, 2, 4, 4))
        out = conv(x)
        mx.eval(out)
        assert out.shape == (1, 8, 2, 4, 4)

    def test_causal_padding(self):
        """Causal conv should only use past/current frames, not future.

        The output at t=0 is computed for a reference input and for a copy
        whose future frames (t=1..3) are changed; the two t=0 slices must
        match exactly.
        """
        from mlx_video.models.wan.vae import CausalConv3d

        conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.1
        conv.bias = mx.zeros((2,))

        # Reference input: only the first frame carries signal.
        x_np = np.zeros((1, 2, 4, 4, 4), dtype=np.float32)
        x_np[:, :, 0, :, :] = 1.0
        # Perturb every future frame, including t=1: a symmetric (non-causal)
        # temporal padding of 1 would let t=1 leak into the t=0 output.
        x_mod_np = x_np.copy()
        x_mod_np[:, :, 1:, :, :] = 2.0

        out = conv(mx.array(x_np))
        out_mod = conv(mx.array(x_mod_np))
        mx.eval(out, out_mod)

        assert out.shape == (1, 2, 4, 4, 4)
        # Due to causal padding, the output at t=0 should only depend on t=0.
        np.testing.assert_allclose(
            np.array(out[:, :, 0]), np.array(out_mod[:, :, 0]), atol=1e-5
        )
        # Sanity: the perturbation must reach later outputs, otherwise the
        # equality above would hold trivially.
        assert not np.allclose(
            np.array(out[:, :, -1]), np.array(out_mod[:, :, -1]), atol=1e-5
        )
class TestResidualBlock:
    """Wan 2.1 ResidualBlock: output shapes and when a shortcut is created."""

    def test_same_dim(self):
        from mlx_video.models.wan.vae import ResidualBlock

        res = ResidualBlock(8, 8)
        inp = mx.random.normal((1, 8, 2, 4, 4))
        y = res(inp)
        mx.eval(y)
        assert y.shape == (1, 8, 2, 4, 4)

    def test_different_dim(self):
        from mlx_video.models.wan.vae import ResidualBlock

        res = ResidualBlock(8, 16)
        inp = mx.random.normal((1, 8, 2, 4, 4))
        y = res(inp)
        mx.eval(y)
        # Channel count follows the output dimension.
        assert y.shape == (1, 16, 2, 4, 4)

    def test_shortcut_exists_when_dims_differ(self):
        from mlx_video.models.wan.vae import ResidualBlock

        # Mismatched in/out channels require a projection on the skip path.
        assert ResidualBlock(8, 16).shortcut is not None

    def test_no_shortcut_when_dims_same(self):
        from mlx_video.models.wan.vae import ResidualBlock

        # Matching channels: the skip path is a plain identity.
        assert ResidualBlock(8, 8).shortcut is None
class TestAttentionBlock:
    """Wan 2.1 AttentionBlock operating on channels-first tensors."""

    def test_output_shape(self):
        from mlx_video.models.wan.vae import AttentionBlock

        attn = AttentionBlock(8)
        feats = mx.random.normal((1, 8, 2, 4, 4))
        y = attn(feats)
        mx.eval(y)
        assert y.shape == (1, 8, 2, 4, 4)

    def test_residual_connection(self):
        from mlx_video.models.wan.vae import AttentionBlock

        attn = AttentionBlock(8)
        feats = mx.random.normal((1, 8, 1, 3, 3))
        y = attn(feats)
        mx.eval(feats, y)
        # The skip path carries the random input through, so even with a
        # randomly initialised attention branch the result is not all zeros.
        peak = np.abs(np.array(y)).max()
        assert peak > 0
class TestWanVAE:
    """Construction of WanVAE and its module-level normalization constants."""

    def test_instantiation(self):
        from mlx_video.models.wan.vae import WanVAE

        model = WanVAE(z_dim=16)
        assert model.z_dim == 16
        # Per-channel latent statistics, one entry per latent channel.
        for stat in (model.mean, model.std):
            assert stat.shape == (16,)

    def test_normalization_stats(self):
        from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD  # noqa: F401

        for table in (VAE_MEAN, VAE_STD):
            assert len(table) == 16
        # Standard deviations are used as scale factors and must be positive.
        assert all(value > 0 for value in VAE_STD)
# ---------------------------------------------------------------------------
# Wan2.2 VAE Component Tests
# ---------------------------------------------------------------------------
class TestVAE22CausalConv3d:
    """Tests for vae22.CausalConv3d (channels-last)."""

    def test_output_shape_k3(self):
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
        x = mx.random.normal((1, 4, 8, 8, 8))  # [B, T, H, W, C]
        out = conv(x)
        mx.eval(out)
        assert out.shape == (1, 4, 8, 8, 16)

    def test_output_shape_k1(self):
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(8, 16, kernel_size=1)
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = conv(x)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 16)

    def test_temporal_causal(self):
        """Output at t=0 should not depend on any frame t>0."""
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
        conv.weight = mx.random.normal(conv.weight.shape) * 0.1
        conv.bias = mx.zeros(conv.bias.shape)
        x = mx.zeros((1, 4, 4, 4, 2))
        out_zero = conv(x)
        mx.eval(out_zero)
        t0_ref = np.array(out_zero[0, 0])
        # Modify every future frame t=1..3. Perturbing t=1 is essential: a
        # symmetric (non-causal) temporal padding of 1 lets t=1 leak into the
        # t=0 output while leaving it independent of t=2..3.
        x_mod = mx.concatenate(
            [
                x[:, :1],
                mx.ones((1, 3, 4, 4, 2)),
            ],
            axis=1,
        )
        out_mod = conv(x_mod)
        mx.eval(out_mod)
        t0_mod = np.array(out_mod[0, 0])
        np.testing.assert_allclose(t0_ref, t0_mod, atol=1e-5)
        # Sanity: the perturbation must reach later frames, otherwise the
        # equality above would hold trivially.
        assert not np.allclose(
            np.array(out_zero[0, -1]), np.array(out_mod[0, -1]), atol=1e-5
        )

    def test_channels_last_format(self):
        """Verify input/output are channels-last [B, T, H, W, C]."""
        from mlx_video.models.wan.vae22 import CausalConv3d

        conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
        x = mx.random.normal((2, 3, 6, 6, 4))
        out = conv(x)
        mx.eval(out)
        assert out.shape[-1] == 8  # last dim = out_channels
class TestRMSNorm:
    """vae22.RMS_norm, which behaves as an L2 normalisation scaled by sqrt(dim)."""

    def test_output_shape(self):
        from mlx_video.models.wan.vae22 import RMS_norm

        layer = RMS_norm(16)
        inp = mx.random.normal((2, 4, 4, 4, 16))
        res = layer(inp)
        mx.eval(res)
        assert res.shape == inp.shape

    def test_l2_normalization(self):
        """With unit gamma the output vector has Euclidean length sqrt(dim)."""
        from mlx_video.models.wan.vae22 import RMS_norm

        n_feat = 32
        layer = RMS_norm(n_feat)
        # Deliberately large magnitudes: the normalisation must remove them.
        inp = mx.random.normal((1, 1, 1, 1, n_feat)) * 5.0
        res = layer(inp)
        mx.eval(res)
        length = np.linalg.norm(np.array(res).ravel())
        np.testing.assert_allclose(length, math.sqrt(n_feat), rtol=1e-3)

    def test_scale_invariant(self):
        """Multiplying the input by a constant leaves the output unchanged."""
        from mlx_video.models.wan.vae22 import RMS_norm

        layer = RMS_norm(8)
        inp = mx.random.normal((1, 1, 1, 1, 8))
        base = layer(inp)
        scaled = layer(inp * 10.0)
        mx.eval(base, scaled)
        np.testing.assert_allclose(np.array(base), np.array(scaled), atol=1e-4)

    def test_gamma_effect(self):
        """A gamma of 2 doubles every output component."""
        from mlx_video.models.wan.vae22 import RMS_norm

        layer = RMS_norm(4)
        layer.gamma = mx.array([2.0] * 4)
        inp = mx.ones((1, 1, 1, 1, 4))
        res = layer(inp)
        mx.eval(res)
        # x/||x|| = 1/2 per component; * sqrt(4) = 1; * gamma (2) = 2.
        np.testing.assert_allclose(np.array(res).ravel(), 2.0, atol=1e-4)
class TestDupUp3D:
    """vae22.DupUp3D: duplication-based spatial/temporal upsampling."""

    @staticmethod
    def _make_input(channels):
        """Random channels-last clip with T=3, H=W=4."""
        return mx.random.normal((1, 3, 4, 4, channels))

    def test_spatial_only(self):
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(8, 4, factor_t=1, factor_s=2)
        y = layer(self._make_input(8))
        mx.eval(y)
        assert y.shape == (1, 3, 8, 8, 4)

    def test_temporal_and_spatial(self):
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(16, 8, factor_t=2, factor_s=2)
        y = layer(self._make_input(16))
        mx.eval(y)
        assert y.shape == (1, 6, 8, 8, 8)

    def test_first_chunk_trims(self):
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(8, 4, factor_t=2, factor_s=2)
        inp = self._make_input(8)
        full = layer(inp, first_chunk=False)
        trimmed = layer(inp, first_chunk=True)
        mx.eval(full, trimmed)
        # The first chunk drops factor_t - 1 = 1 leading frame.
        assert (full.shape[1], trimmed.shape[1]) == (6, 5)

    def test_no_temporal_first_chunk_noop(self):
        from mlx_video.models.wan.vae22 import DupUp3D

        layer = DupUp3D(8, 4, factor_t=1, factor_s=2)
        inp = self._make_input(8)
        full = layer(inp, first_chunk=False)
        trimmed = layer(inp, first_chunk=True)
        mx.eval(full, trimmed)
        # factor_t = 1 means there is nothing to trim.
        assert full.shape == trimmed.shape
class TestVAE22Resample:
    """Tests for vae22.Resample (spatial/temporal upsampling)."""

    def test_upsample2d_shape(self):
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample2d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = r(x)
        mx.eval(out)
        assert out.shape == (1, 2, 8, 8, 8)  # 2x spatial, same temporal

    def test_upsample3d_shape(self):
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample3d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = r(x)
        mx.eval(out)
        assert out.shape == (1, 4, 8, 8, 8)  # 2x spatial + 2x temporal

    def test_upsample3d_first_chunk(self):
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample3d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = r(x, first_chunk=True)
        mx.eval(out)
        # first_chunk: 1 (bypass) + 2*(T-1) (interleaved) = 2T-1 = 3
        assert out.shape == (1, 3, 8, 8, 8)

    def test_upsample3d_first_chunk_single_frame(self):
        """Single-frame input with first_chunk: no temporal upsample."""
        from mlx_video.models.wan.vae22 import Resample

        r = Resample(8, "upsample3d")
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
        x = mx.random.normal((1, 1, 4, 4, 8))
        out = r(x, first_chunk=True)
        mx.eval(out)
        # Single frame with first_chunk: falls through to non-first path
        # time_conv on 1 frame → 2 interleaved
        assert out.shape == (1, 2, 8, 8, 8)

    def test_upsample3d_first_frame_bypasses_time_conv(self):
        """First frame of first_chunk should NOT go through time_conv.

        Official Wan2.2 skips time_conv for the very first frame entirely.
        Verified without relying on module internals:
          * two modules differing ONLY in time_conv weights must produce the
            same first output frame;
          * changing later input frames must not change the first output frame;
          * time_conv must change the remaining frames, so the first check is
            not vacuous.
        """
        from mlx_video.models.wan.vae22 import Resample

        C = 8
        r = Resample(C, "upsample3d")
        # Non-zero spatial conv weights. With all-zero weight and bias every
        # output frame is exactly zero, and any frame comparison passes trivially.
        r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.1
        r.resample_bias = mx.zeros(r.resample_bias.shape)
        # Large time_conv weights so its effect is clearly detectable.
        r.time_conv.weight = mx.ones(r.time_conv.weight.shape) * 10.0
        r.time_conv.bias = mx.zeros(r.time_conv.bias.shape)

        # Fresh instances sharing r's parameters, so that no state carried over
        # from an earlier forward call can influence the comparison.
        r_flip = Resample(C, "upsample3d")
        r_flip.update(r.parameters())
        # Identical except for the sign of the time_conv weights.
        r_flip.time_conv.weight = mx.ones(r.time_conv.weight.shape) * -10.0
        r_future = Resample(C, "upsample3d")
        r_future.update(r.parameters())

        x = mx.random.normal((1, 3, 2, 2, C))
        # Same first frame as x, different later frames.
        x_future = mx.concatenate([x[:, :1], x[:, 1:] + 1.0], axis=1)

        out = r(x, first_chunk=True)
        out_flip = r_flip(x, first_chunk=True)
        out_future = r_future(x_future, first_chunk=True)
        mx.eval(out, out_flip, out_future)

        # Output: 5 frames (1 bypass + 4 interleaved from 2 remaining)
        assert out.shape[1] == 5
        assert out_flip.shape == out.shape
        assert out_future.shape == out.shape

        first = np.array(out[:, 0])
        np.testing.assert_allclose(
            first,
            np.array(out_flip[:, 0]),
            atol=1e-5,
            err_msg="First frame should bypass time_conv",
        )
        np.testing.assert_allclose(
            first,
            np.array(out_future[:, 0]),
            atol=1e-5,
            err_msg="First frame should depend only on the first input frame",
        )
        # Sanity: time_conv does influence the interleaved frames.
        assert not np.allclose(
            np.array(out[:, 1:]), np.array(out_flip[:, 1:]), atol=1e-5
        )
class TestVAE22ResidualBlock:
    """vae22.ResidualBlock on channels-last [B, T, H, W, C] tensors."""

    @staticmethod
    def _forward(block):
        """Run ``block`` on a random (1, 2, 4, 4, 8) input and evaluate it."""
        y = block(mx.random.normal((1, 2, 4, 4, 8)))
        mx.eval(y)
        return y

    def test_same_dim(self):
        from mlx_video.models.wan.vae22 import ResidualBlock

        assert self._forward(ResidualBlock(8, 8)).shape == (1, 2, 4, 4, 8)

    def test_different_dim(self):
        from mlx_video.models.wan.vae22 import ResidualBlock

        assert self._forward(ResidualBlock(8, 16)).shape == (1, 2, 4, 4, 16)

    def test_shortcut_when_dims_differ(self):
        from mlx_video.models.wan.vae22 import ResidualBlock

        assert ResidualBlock(8, 16).shortcut is not None

    def test_no_shortcut_same_dim(self):
        from mlx_video.models.wan.vae22 import ResidualBlock

        assert ResidualBlock(8, 8).shortcut is None
class TestResidualBlockLayers:
    """Tests for vae22.ResidualBlockLayers naming convention."""

    def test_layer_names_no_underscore_prefix(self):
        """Layer names must NOT start with underscore (MLX ignores them).

        MLX's ``Module.parameters()`` drops underscore-prefixed attributes, so
        scanning its keys for a leading "_" cannot fail by itself. The
        meaningful check is that the parametrised ``layer_N`` children are
        actually registered as parameter groups.
        """
        from mlx_video.models.wan.vae22 import ResidualBlockLayers

        block = ResidualBlockLayers(8, 8)
        params = dict(block.parameters())
        # All param keys should use layer_N, not _layer_N
        for key in params:
            assert not key.startswith("_"), f"Parameter {key} starts with underscore"
        # Norm and conv layers must be visible to MLX (and thus loadable).
        expected = {"layer_0", "layer_2", "layer_3", "layer_6"}
        missing = expected - set(params)
        assert not missing, f"Layers missing from parameters: {sorted(missing)}"

    def test_has_expected_layers(self):
        from mlx_video.models.wan.vae22 import ResidualBlockLayers

        block = ResidualBlockLayers(8, 16)
        assert hasattr(block, "layer_0")  # first RMS_norm
        assert hasattr(block, "layer_2")  # first CausalConv3d
        assert hasattr(block, "layer_3")  # second RMS_norm
        assert hasattr(block, "layer_6")  # second CausalConv3d

    def test_forward_shape(self):
        from mlx_video.models.wan.vae22 import ResidualBlockLayers

        block = ResidualBlockLayers(8, 16)
        x = mx.random.normal((1, 2, 4, 4, 8))
        out = block(x)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 16)
class TestVAE22AttentionBlock:
    """Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""

    def test_output_shape(self):
        from mlx_video.models.wan.vae22 import AttentionBlock

        block = AttentionBlock(16)
        block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
        block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
        x = mx.random.normal((1, 2, 4, 4, 16))
        out = block(x)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 16)

    def test_residual_connection(self):
        """With a zeroed output projection the block reduces to the identity."""
        from mlx_video.models.wan.vae22 import AttentionBlock

        block = AttentionBlock(8)
        block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
        block.proj_weight = mx.zeros(block.proj_weight.shape)
        # The attention branch ends in a projection that has both a weight and
        # a bias (proj_weight / proj_bias). Zero the bias explicitly so the
        # identity check does not depend on how the bias happens to be
        # initialised.
        block.proj_bias = mx.zeros(block.proj_bias.shape)
        x = mx.ones((1, 1, 2, 2, 8))
        out = block(x)
        mx.eval(out)
        # With zero weights, attention output is 0 → residual is identity
        np.testing.assert_allclose(np.array(out), np.array(x), atol=1e-5)
class TestHead22:
    """Tests for vae22.Head22 output head."""

    def test_output_shape(self):
        from mlx_video.models.wan.vae22 import Head22

        head = Head22(16, out_channels=12)
        x = mx.random.normal((1, 2, 4, 4, 16))
        out = head(x)
        mx.eval(out)
        assert out.shape == (1, 2, 4, 4, 12)

    def test_layer_names_no_underscore(self):
        """Head layers must not use underscore prefix.

        MLX's ``Module.parameters()`` drops underscore-prefixed attributes, so
        the key scan alone can never fail. The layers are therefore also
        required to appear among the parameters.
        """
        from mlx_video.models.wan.vae22 import Head22

        head = Head22(8)
        assert hasattr(head, "layer_0")  # RMS_norm
        assert hasattr(head, "layer_2")  # CausalConv3d
        params = dict(head.parameters())
        for key in params:
            assert not key.startswith("_"), f"Head param {key} starts with underscore"
        for name in ("layer_0", "layer_2"):
            assert name in params, f"Head layer {name} missing from parameters"
class TestUnpatchify:
    """vae22._unpatchify: folding channel patches back into pixels."""

    def test_basic_shape(self):
        from mlx_video.models.wan.vae22 import _unpatchify

        # 12 channels = 3 colour channels * (2 * 2) patch positions.
        packed = mx.random.normal((1, 2, 4, 4, 12))
        pixels = _unpatchify(packed, patch_size=2)
        mx.eval(pixels)
        assert pixels.shape == (1, 2, 8, 8, 3)

    def test_patch_size_1_noop(self):
        from mlx_video.models.wan.vae22 import _unpatchify

        packed = mx.random.normal((1, 2, 4, 4, 3))
        pixels = _unpatchify(packed, patch_size=1)
        mx.eval(pixels)
        np.testing.assert_array_equal(np.array(pixels), np.array(packed))

    def test_preserves_content(self):
        """Unpatchify only rearranges values; nothing is dropped or duplicated."""
        from mlx_video.models.wan.vae22 import _unpatchify

        packed = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
        pixels = _unpatchify(packed, patch_size=2)
        mx.eval(pixels)
        values = np.array(pixels).ravel()
        assert values.size == 48
        assert set(values.tolist()) == set(range(48))
class TestDenormalizeLatents:
    """vae22.denormalize_latents: z * std + mean, per channel."""

    def test_output_shape(self):
        from mlx_video.models.wan.vae22 import denormalize_latents

        latents = mx.random.normal((1, 2, 4, 4, 48))
        res = denormalize_latents(latents)
        mx.eval(res)
        assert res.shape == (1, 2, 4, 4, 48)

    def test_custom_mean_std(self):
        from mlx_video.models.wan.vae22 import denormalize_latents

        latents = mx.ones((1, 1, 1, 1, 4))
        mu = mx.array([1.0, 2.0, 3.0, 4.0])
        sigma = mx.array([0.5] * 4)
        res = denormalize_latents(latents, mean=mu, std=sigma)
        mx.eval(res)
        # 1 * 0.5 + [1, 2, 3, 4]
        expected = [1.5, 2.5, 3.5, 4.5]
        np.testing.assert_allclose(np.array(res).ravel(), expected, atol=1e-5)

    def test_uses_default_constants(self):
        from mlx_video.models.wan.vae22 import (  # noqa: F401
            VAE22_MEAN,
            VAE22_STD,
            denormalize_latents,
        )

        # Zero latents must map exactly onto the default mean.
        res = denormalize_latents(mx.zeros((1, 1, 1, 1, 48)))
        mx.eval(res)
        np.testing.assert_allclose(
            np.array(res).ravel(),
            np.array(VAE22_MEAN).ravel(),
            atol=1e-5,
        )
class TestVAE22NormConstants:
    """Shape and sign checks on the VAE22_MEAN / VAE22_STD tables."""

    def test_dimensions(self):
        from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD

        mx.eval(VAE22_MEAN, VAE22_STD)
        # One statistic per latent channel (48 for Wan 2.2).
        for table in (VAE22_MEAN, VAE22_STD):
            assert table.shape == (48,)

    def test_std_positive(self):
        from mlx_video.models.wan.vae22 import VAE22_STD

        mx.eval(VAE22_STD)
        assert np.all(np.array(VAE22_STD) > 0)
class TestWan22VAEDecoder:
    """End-to-end checks on a deliberately tiny Wan22VAEDecoder."""

    @staticmethod
    def _tiny_decoder():
        """Build a decoder with very small widths so the test stays fast."""
        from mlx_video.models.wan.vae22 import Wan22VAEDecoder

        return Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)

    def test_output_shape_small(self):
        """Tiny decoder should produce correct spatial/temporal output."""
        decoder = self._tiny_decoder()
        # Latent layout [B=1, T=3, H=2, W=2, C=4]. Intended progression (not
        # asserted here): temporal 3→5→9→9→9, spatial 2→4→8→16.
        z = mx.random.normal((1, 3, 2, 2, 4)) * 0.1
        video = decoder(z)
        mx.eval(video)
        # RGB, 5-D and clipped to [-1, 1].
        assert video.shape[-1] == 3
        assert video.ndim == 5
        pixels = np.array(video)
        assert pixels.min() >= -1.0
        assert pixels.max() <= 1.0

    def test_output_clipped(self):
        decoder = self._tiny_decoder()
        # Large-magnitude latents stress the output clipping.
        z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0
        video = decoder(z)
        mx.eval(video)
        pixels = np.array(video)
        tol = 1e-6
        assert pixels.min() >= -1.0 - tol
        assert pixels.max() <= 1.0 + tol
class TestSanitizeWan22VAEWeights:
    """vae22.sanitize_wan22_vae_weights: key renaming and layout conversion."""

    def test_skip_encoder(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        checkpoint = {
            "encoder.layer.weight": mx.zeros((4,)),
            "conv1.weight": mx.zeros((4,)),
            "decoder.conv1.bias": mx.zeros((4,)),
        }
        kept = sanitize_wan22_vae_weights(checkpoint)
        # Only decoder entries survive.
        for dropped in ("encoder.layer.weight", "conv1.weight"):
            assert dropped not in kept
        assert "decoder.conv1.bias" in kept

    def test_sequential_index_remapping(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        checkpoint = {
            "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
            "decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
            "decoder.head.0.gamma": mx.ones((4,)),
            "decoder.head.2.bias": mx.zeros((12,)),
        }
        renamed = sanitize_wan22_vae_weights(checkpoint)
        # Sequential indices ".N." become named ".layer_N." children.
        for expected_key in (
            "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma",
            "decoder.upsamples.0.upsamples.0.residual.layer_6.bias",
            "decoder.head.layer_0.gamma",
            "decoder.head.layer_2.bias",
        ):
            assert expected_key in renamed

    def test_resample_conv_remapping(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        checkpoint = {
            "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
            "decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
        }
        renamed = sanitize_wan22_vae_weights(checkpoint)
        for expected_key in (
            "decoder.upsamples.1.upsamples.3.resample_weight",
            "decoder.upsamples.1.upsamples.3.resample_bias",
        ):
            assert expected_key in renamed

    def test_attention_remapping(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        checkpoint = {
            "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
            "decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
            "decoder.middle.1.proj.weight": mx.zeros((8, 8, 1, 1)),
            "decoder.middle.1.proj.bias": mx.zeros((8,)),
        }
        renamed = sanitize_wan22_vae_weights(checkpoint)
        for expected_key in (
            "decoder.middle.1.to_qkv_weight",
            "decoder.middle.1.to_qkv_bias",
            "decoder.middle.1.proj_weight",
            "decoder.middle.1.proj_bias",
        ):
            assert expected_key in renamed

    def test_conv3d_transpose(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        # Conv3d weight layout [O, I, D, H, W] becomes [O, D, H, W, I].
        converted = sanitize_wan22_vae_weights(
            {"decoder.conv1.weight": mx.zeros((16, 8, 3, 3, 3))}
        )
        assert converted["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)

    def test_conv2d_transpose(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        # Conv2d weight layout [O, I, H, W] becomes [O, H, W, I].
        converted = sanitize_wan22_vae_weights(
            {"decoder.upsamples.0.upsamples.2.resample.1.weight": mx.zeros((8, 8, 3, 3))}
        )
        target = "decoder.upsamples.0.upsamples.2.resample_weight"
        assert converted[target].shape == (8, 3, 3, 8)

    def test_gamma_squeeze(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        # Broadcast-shaped gamma (dim, 1, 1, 1) is flattened to (dim,).
        converted = sanitize_wan22_vae_weights(
            {"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((16, 1, 1, 1))}
        )
        target = "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma"
        assert converted[target].shape == (16,)
class TestUpResidualBlock:
    """vae22.Up_ResidualBlock with and without spatial/temporal upsampling."""

    @staticmethod
    def _forward(block):
        """Feed a random (1, 2, 4, 4, 8) clip through ``block`` and evaluate."""
        y = block(mx.random.normal((1, 2, 4, 4, 8)))
        mx.eval(y)
        return y

    def test_no_upsample(self):
        from mlx_video.models.wan.vae22 import Up_ResidualBlock

        block = Up_ResidualBlock(
            8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
        )
        # Neither flag set: the shape is unchanged.
        assert self._forward(block).shape == (1, 2, 4, 4, 8)

    def test_spatial_upsample(self):
        from mlx_video.models.wan.vae22 import Up_ResidualBlock

        block = Up_ResidualBlock(
            8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
        )
        # H and W double; T is untouched.
        assert self._forward(block).shape == (1, 2, 8, 8, 4)

    def test_spatial_temporal_upsample(self):
        from mlx_video.models.wan.vae22 import Up_ResidualBlock

        block = Up_ResidualBlock(
            8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
        )
        # H, W and T all double.
        assert self._forward(block).shape == (1, 4, 8, 8, 4)
class TestPatchify:
    """Round-trip behaviour of vae22._patchify / vae22._unpatchify."""

    def test_roundtrip(self):
        from mlx_video.models.wan.vae22 import _patchify, _unpatchify

        img = mx.random.normal((1, 1, 64, 64, 3))
        packed = _patchify(img, patch_size=2)
        # 2x2 patches: half the resolution, four times the channels.
        assert packed.shape == (1, 1, 32, 32, 12)
        restored = _unpatchify(packed, patch_size=2)
        assert restored.shape == img.shape
        # A pure rearrangement is exactly invertible.
        assert float(mx.abs(img - restored).max()) == 0.0

    def test_identity_patch_1(self):
        from mlx_video.models.wan.vae22 import _patchify, _unpatchify

        img = mx.random.normal((1, 2, 8, 8, 3))
        for fn in (_patchify, _unpatchify):
            assert fn(img, patch_size=1).shape == img.shape
class TestAvgDown3D:
    """vae22.AvgDown3D: averaging-based spatial/temporal downsampling."""

    @staticmethod
    def _forward(layer, frames):
        """Run ``layer`` on a random (1, frames, 8, 8, 8) clip and evaluate."""
        y = layer(mx.random.normal((1, frames, 8, 8, 8)))
        mx.eval(y)
        return y

    def test_spatial_only(self):
        from mlx_video.models.wan.vae22 import AvgDown3D

        layer = AvgDown3D(8, 16, factor_t=1, factor_s=2)
        assert self._forward(layer, 2).shape == (1, 2, 4, 4, 16)

    def test_temporal_and_spatial(self):
        from mlx_video.models.wan.vae22 import AvgDown3D

        layer = AvgDown3D(8, 16, factor_t=2, factor_s=2)
        assert self._forward(layer, 4).shape == (1, 2, 4, 4, 16)

    def test_single_frame(self):
        from mlx_video.models.wan.vae22 import AvgDown3D

        layer = AvgDown3D(8, 8, factor_t=2, factor_s=2)
        # A lone frame with factor_t=2 is padded to T=2, then averaged to T=1.
        assert self._forward(layer, 1).shape == (1, 1, 4, 4, 8)
class TestDownResidualBlock:
    """vae22.Down_ResidualBlock with and without spatial/temporal downsampling."""

    @staticmethod
    def _forward(block, frames):
        """Run ``block`` on a random (1, frames, 8, 8, 8) clip and evaluate."""
        y = block(mx.random.normal((1, frames, 8, 8, 8)))
        mx.eval(y)
        return y

    def test_no_downsample(self):
        from mlx_video.models.wan.vae22 import Down_ResidualBlock

        block = Down_ResidualBlock(
            8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
        )
        assert self._forward(block, 2).shape == (1, 2, 8, 8, 8)

    def test_spatial_downsample(self):
        from mlx_video.models.wan.vae22 import Down_ResidualBlock

        block = Down_ResidualBlock(
            8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
        )
        assert self._forward(block, 2).shape == (1, 2, 4, 4, 16)

    def test_spatial_temporal_downsample(self):
        from mlx_video.models.wan.vae22 import Down_ResidualBlock

        block = Down_ResidualBlock(
            8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
        )
        assert self._forward(block, 4).shape == (1, 2, 4, 4, 16)
class TestEncoder3d:
    """Tests for Encoder3d."""

    def test_output_shape(self):
        from mlx_video.models.wan.vae22 import Encoder3d

        enc = Encoder3d(dim=16, z_dim=8)
        x = mx.random.normal((1, 1, 16, 16, 12))
        mx.eval(enc.parameters())
        out = enc(x)
        mx.eval(out)
        # 3 spatial downsamples ÷8: 16→2
        assert out.shape == (1, 1, 2, 2, 8)

    def test_multi_frame(self):
        """Five frames with (True, True, False) temporal downsampling give T=2."""
        from mlx_video.models.wan.vae22 import Encoder3d

        enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
        x = mx.random.normal((1, 5, 16, 16, 12))
        mx.eval(enc.parameters())
        out = enc(x)
        mx.eval(out)
        # T: 5→3 (1st t_down) →2 (2nd t_down), spatial ÷8
        assert out.shape[0] == 1
        assert out.shape[1] == 2
        assert out.shape[2:] == (2, 2, 8)
class TestWan22VAEEncoder:
    """Wan22VAEEncoder wrapper: RGB frames in, 48-channel latents out."""

    @staticmethod
    def _encode(encoder, size):
        """Encode a random single-frame ``size`` x ``size`` RGB image."""
        image = mx.random.normal((1, 1, size, size, 3))
        mx.eval(encoder.parameters())
        latent = encoder(image)
        mx.eval(latent)
        return latent

    def test_output_shape(self):
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder

        encoder = Wan22VAEEncoder(z_dim=48, dim=16)
        # 32x32: patchify /2 -> 16x16, then three 2x downsamples -> 2x2.
        assert self._encode(encoder, 32).shape == (1, 1, 2, 2, 48)

    def test_full_dim(self):
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder

        encoder = Wan22VAEEncoder(z_dim=48, dim=160)
        # Overall spatial stride 16: 64 / 16 = 4.
        assert self._encode(encoder, 64).shape == (1, 1, 4, 4, 48)
class TestNormalizeLatents:
    """normalize_latents followed by denormalize_latents is the identity."""

    def test_roundtrip(self):
        from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents

        original = mx.random.normal((1, 2, 4, 4, 48))
        restored = denormalize_latents(normalize_latents(original))
        mx.eval(restored)
        max_err = float(mx.abs(original - restored).max())
        assert max_err < 1e-4
class TestVAEEncoderTemporalOrder:
    """Tests that VAE encoder uses (False, True, True) temporal downsample order,
    matching official Wan2.2 vae2_2.py."""

    def test_encoder_temporal_downsample_pattern(self):
        """Encoder3d with (False, True, True): T=5→5→3→2."""
        from mlx_video.models.wan.vae22 import Encoder3d

        enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
        x = mx.random.normal((1, 5, 16, 16, 12))
        mx.eval(enc.parameters())
        out = enc(x)
        mx.eval(out)
        assert out.shape[1] == 2

    def test_wrapper_uses_correct_pattern(self):
        """Wan22VAEEncoder should use (False, True, True) temporal downsample."""
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample

        enc = Wan22VAEEncoder(z_dim=48, dim=16)
        down_blocks = enc.encoder.downsamples
        found_modes = []
        for block in down_blocks:
            for layer in block.downsamples:
                if isinstance(layer, Resample):
                    found_modes.append(layer.mode)
        # Exactly three downsampling stages: first spatial-only, then two with
        # temporal downsampling. Pinning every position distinguishes this
        # order from (True, True, False), which the old `any(...)` check did not.
        assert len(found_modes) == 3, found_modes
        assert found_modes[0] == "downsample2d"
        assert all("3d" in m for m in found_modes[1:]), found_modes

    def test_single_frame_encoder(self):
        """Single frame (T=1) should work with (False, True, True) pattern."""
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder

        enc = Wan22VAEEncoder(z_dim=48, dim=16)
        img = mx.random.normal((1, 1, 32, 32, 3))
        mx.eval(enc.parameters())
        z = enc(img)
        mx.eval(z)
        assert z.shape[1] == 1
        assert z.shape[-1] == 48

    def test_wrong_order_gives_different_result(self):
        """(True, True, False) and (False, True, True) both reduce T=5 to T=2.

        The two orders reach the same temporal length through different stages
        (5→3→2→2 vs 5→5→3→2). Output length alone therefore cannot tell them
        apart; the wrapper's actual order is pinned by
        test_wrapper_uses_correct_pattern.
        """
        from mlx_video.models.wan.vae22 import Encoder3d

        enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
        enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
        x = mx.random.normal((1, 5, 16, 16, 12))
        mx.eval(enc_correct.parameters())
        mx.eval(enc_wrong.parameters())
        out_correct = enc_correct(x)
        out_wrong = enc_wrong(x)
        mx.eval(out_correct, out_wrong)
        # Values are not compared: the encoders are initialised independently,
        # so any numerical difference would be meaningless.
        assert out_correct.shape[1] == 2
        assert out_wrong.shape[1] == 2
# ---------------------------------------------------------------------------
# VAE Encode → Decode Round-Trip Tests
# ---------------------------------------------------------------------------
class TestVAE21RoundTrip:
    """Wan 2.1 VAE (channels-first): encode a tiny clip, then decode it back."""

    def test_encode_decode_shape_and_values(self):
        """Encoder3d -> Decoder3d restores the input shape with finite values."""
        from mlx_video.models.wan.vae import Decoder3d, Encoder3d

        latent_ch, width = 4, 8
        # Temporal resampling is disabled everywhere to keep the shapes simple.
        encoder = Encoder3d(
            dim=width, z_dim=latent_ch, temporal_downsample=[False] * 3
        )
        decoder = Decoder3d(
            dim=width, z_dim=latent_ch, temporal_upsample=[False] * 3
        )
        mx.eval(encoder.parameters(), decoder.parameters())

        clip = mx.random.normal((1, 3, 1, 8, 8)) * 0.5  # [B, C, T, H, W]
        latent = encoder(clip)
        mx.eval(latent)
        # Three 2x spatial downsamples: 8x8 -> 1x1.
        assert latent.shape == (1, latent_ch, 1, 1, 1)

        recon = decoder(latent)
        mx.eval(recon)
        # Three 2x spatial upsamples bring the shape back to the input's.
        assert recon.shape == clip.shape
        recon_np = np.array(recon)
        assert np.isfinite(recon_np).all()
        assert np.abs(recon_np).max() < 1000
class TestVAE22RoundTrip:
    """Wan 2.2 VAE (channels-last): encode an image, then decode it back."""

    def test_encode_decode_shape_and_values(self):
        """Wan22VAEEncoder -> Wan22VAEDecoder keeps H/W and stays in [-1, 1]."""
        from mlx_video.models.wan.vae22 import (
            Wan22VAEDecoder,
            Wan22VAEEncoder,
            denormalize_latents,
        )

        encoder = Wan22VAEEncoder(z_dim=48, dim=16)
        decoder = Wan22VAEDecoder(z_dim=48, dec_dim=8)
        mx.eval(encoder.parameters(), decoder.parameters())

        image = mx.random.normal((1, 1, 32, 32, 3)) * 0.5  # [B, T, H, W, C]
        latent_norm = encoder(image)
        mx.eval(latent_norm)
        # Overall stride 16: patchify (/2) then three 2x downsamples (/8).
        assert latent_norm.shape == (1, 1, 2, 2, 48)

        frames = decoder(denormalize_latents(latent_norm))
        mx.eval(frames)
        # Overall upscale 16: three 2x upsamples (x8) then unpatchify (x2).
        observed = (frames.shape[0], frames.shape[2], frames.shape[3], frames.shape[-1])
        assert observed == (1, 32, 32, 3)  # batch, H, W, RGB

        pixels = np.array(frames)
        assert np.isfinite(pixels).all()
        tol = 1e-6
        assert pixels.min() >= -1.0 - tol
        assert pixels.max() <= 1.0 + tol