This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -4,16 +4,16 @@ import math
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# VAE 2.1 Tests
# ---------------------------------------------------------------------------
class TestCausalConv3d:
def test_output_shape_stride1(self):
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
@@ -29,6 +29,7 @@ class TestCausalConv3d:
def test_output_shape_kernel1(self):
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
x = mx.random.normal((1, 4, 2, 4, 4))
@@ -39,6 +40,7 @@ class TestCausalConv3d:
def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros((2,))
@@ -55,6 +57,7 @@ class TestCausalConv3d:
class TestResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
@@ -63,6 +66,7 @@ class TestResidualBlock:
def test_different_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
@@ -71,11 +75,13 @@ class TestResidualBlock:
def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -83,6 +89,7 @@ class TestResidualBlock:
class TestAttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
@@ -91,6 +98,7 @@ class TestAttentionBlock:
def test_residual_connection(self):
from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3))
out = block(x)
@@ -102,13 +110,15 @@ class TestAttentionBlock:
class TestWanVAE:
def test_instantiation(self):
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
assert vae.z_dim == 16
assert vae.mean.shape == (16,)
assert vae.std.shape == (16,)
def test_normalization_stats(self):
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16
assert all(s > 0 for s in VAE_STD)
@@ -124,6 +134,7 @@ class TestVAE22CausalConv3d:
def test_output_shape_k3(self):
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
out = conv(x)
@@ -132,6 +143,7 @@ class TestVAE22CausalConv3d:
def test_output_shape_k1(self):
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8))
out = conv(x)
@@ -141,6 +153,7 @@ class TestVAE22CausalConv3d:
def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0."""
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros(conv.bias.shape)
@@ -151,10 +164,13 @@ class TestVAE22CausalConv3d:
t0_ref = np.array(out_zero[0, 0])
# Modify t=2..3; output at t=0 should be unchanged
x_mod = mx.concatenate([
x[:, :2],
mx.ones((1, 2, 4, 4, 2)),
], axis=1)
x_mod = mx.concatenate(
[
x[:, :2],
mx.ones((1, 2, 4, 4, 2)),
],
axis=1,
)
out_mod = conv(x_mod)
mx.eval(out_mod)
t0_mod = np.array(out_mod[0, 0])
@@ -163,6 +179,7 @@ class TestVAE22CausalConv3d:
def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4))
out = conv(x)
@@ -175,6 +192,7 @@ class TestRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16))
out = norm(x)
@@ -184,6 +202,7 @@ class TestRMSNorm:
def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan.vae22 import RMS_norm
dim = 32
norm = RMS_norm(dim)
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
@@ -197,6 +216,7 @@ class TestRMSNorm:
def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8))
out1 = norm(x)
@@ -207,6 +227,7 @@ class TestRMSNorm:
def test_gamma_effect(self):
"""Non-unit gamma should scale output."""
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
x = mx.ones((1, 1, 1, 1, 4))
@@ -221,6 +242,7 @@ class TestDupUp3D:
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out = up(x)
@@ -229,6 +251,7 @@ class TestDupUp3D:
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16))
out = up(x)
@@ -237,6 +260,7 @@ class TestDupUp3D:
def test_first_chunk_trims(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False)
@@ -248,6 +272,7 @@ class TestDupUp3D:
def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False)
@@ -262,6 +287,7 @@ class TestVAE22Resample:
def test_upsample2d_shape(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -271,6 +297,7 @@ class TestVAE22Resample:
def test_upsample3d_shape(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -280,6 +307,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
@@ -291,6 +319,7 @@ class TestVAE22Resample:
def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 1, 4, 4, 8))
@@ -308,6 +337,7 @@ class TestVAE22Resample:
the first input frame (not on time_conv parameters).
"""
from mlx_video.models.wan.vae22 import Resample
C = 8
r = Resample(C, "upsample3d")
# Set time_conv weights to large values so its effect is detectable
@@ -334,8 +364,9 @@ class TestVAE22Resample:
# Compare first output frame to reference
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
mx.eval(first_out)
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
"First frame should bypass time_conv and match spatial-only upsample"
assert mx.allclose(
first_out, ref, atol=1e-5
).item(), "First frame should bypass time_conv and match spatial-only upsample"
class TestVAE22ResidualBlock:
@@ -343,6 +374,7 @@ class TestVAE22ResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
@@ -351,6 +383,7 @@ class TestVAE22ResidualBlock:
def test_different_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
@@ -359,11 +392,13 @@ class TestVAE22ResidualBlock:
def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
@@ -374,6 +409,7 @@ class TestResidualBlockLayers:
def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8)
params = dict(block.parameters())
# All param keys should use layer_N, not _layer_N
@@ -382,6 +418,7 @@ class TestResidualBlockLayers:
def test_has_expected_layers(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm
assert hasattr(block, "layer_2") # first CausalConv3d
@@ -390,6 +427,7 @@ class TestResidualBlockLayers:
def test_forward_shape(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
@@ -402,6 +440,7 @@ class TestVAE22AttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
@@ -412,6 +451,7 @@ class TestVAE22AttentionBlock:
def test_residual_connection(self):
from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
block.proj_weight = mx.zeros(block.proj_weight.shape)
@@ -427,6 +467,7 @@ class TestHead22:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Head22
head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16))
out = head(x)
@@ -436,6 +477,7 @@ class TestHead22:
def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix."""
from mlx_video.models.wan.vae22 import Head22
head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm
assert hasattr(head, "layer_2") # CausalConv3d
@@ -449,6 +491,7 @@ class TestUnpatchify:
def test_basic_shape(self):
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2)
mx.eval(out)
@@ -456,6 +499,7 @@ class TestUnpatchify:
def test_patch_size_1_noop(self):
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1)
mx.eval(out)
@@ -464,6 +508,7 @@ class TestUnpatchify:
def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2)
mx.eval(out)
@@ -477,6 +522,7 @@ class TestDenormalizeLatents:
def test_output_shape(self):
from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z)
mx.eval(out)
@@ -484,16 +530,23 @@ class TestDenormalizeLatents:
def test_custom_mean_std(self):
from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0])
std = mx.array([0.5, 0.5, 0.5, 0.5])
out = denormalize_latents(z, mean=mean, std=std)
mx.eval(out)
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
np.testing.assert_allclose(
np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5
)
def test_uses_default_constants(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
from mlx_video.models.wan.vae22 import (
VAE22_MEAN,
denormalize_latents,
)
# Should not raise with default constants
z = mx.zeros((1, 1, 1, 1, 48))
out = denormalize_latents(z)
@@ -511,12 +564,14 @@ class TestVAE22NormConstants:
def test_dimensions(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,)
def test_std_positive(self):
from mlx_video.models.wan.vae22 import VAE22_STD
mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all()
@@ -527,6 +582,7 @@ class TestWan22VAEDecoder:
def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
# Latent: [B=1, T=3, H=2, W=2, C=4]
@@ -542,6 +598,7 @@ class TestWan22VAEDecoder:
def test_output_clipped(self):
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
out = dec(z)
@@ -555,6 +612,7 @@ class TestSanitizeWan22VAEWeights:
def test_skip_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.layer.weight": mx.zeros((4,)),
"conv1.weight": mx.zeros((4,)),
@@ -567,6 +625,7 @@ class TestSanitizeWan22VAEWeights:
def test_sequential_index_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
@@ -581,6 +640,7 @@ class TestSanitizeWan22VAEWeights:
def test_resample_conv_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
@@ -591,6 +651,7 @@ class TestSanitizeWan22VAEWeights:
def test_attention_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
@@ -605,6 +666,7 @@ class TestSanitizeWan22VAEWeights:
def test_conv3d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3))
weights = {"decoder.conv1.weight": w}
@@ -613,6 +675,7 @@ class TestSanitizeWan22VAEWeights:
def test_conv2d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3))
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
@@ -622,6 +685,7 @@ class TestSanitizeWan22VAEWeights:
def test_gamma_squeeze(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1))
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
@@ -635,7 +699,10 @@ class TestUpResidualBlock:
def test_no_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
block = Up_ResidualBlock(
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
@@ -644,7 +711,10 @@ class TestUpResidualBlock:
def test_spatial_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
@@ -653,7 +723,10 @@ class TestUpResidualBlock:
def test_spatial_temporal_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
block = Up_ResidualBlock(
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
@@ -720,7 +793,9 @@ class TestDownResidualBlock:
def test_no_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
block = Down_ResidualBlock(
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
)
x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x)
mx.eval(out)
@@ -729,7 +804,9 @@ class TestDownResidualBlock:
def test_spatial_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
)
x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x)
mx.eval(out)
@@ -738,7 +815,9 @@ class TestDownResidualBlock:
def test_spatial_temporal_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
block = Down_ResidualBlock(
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
)
x = mx.random.normal((1, 4, 8, 8, 8))
out = block(x)
mx.eval(out)
@@ -817,6 +896,7 @@ class TestVAEEncoderTemporalOrder:
def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12))
mx.eval(enc.parameters())
@@ -826,7 +906,8 @@ class TestVAEEncoderTemporalOrder:
def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples
found_modes = []
@@ -841,6 +922,7 @@ class TestVAEEncoderTemporalOrder:
def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3))
mx.eval(enc.parameters())
@@ -852,7 +934,10 @@ class TestVAEEncoderTemporalOrder:
def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan.vae22 import Encoder3d
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
enc_correct = Encoder3d(
dim=16, z_dim=8, temperal_downsample=(False, True, True)
)
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12))
@@ -883,12 +968,8 @@ class TestVAE21RoundTrip:
z_dim = 4
dim = 8
# No temporal up/downsampling to keep the test simple
enc = Encoder3d(
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
)
dec = Decoder3d(
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
)
enc = Encoder3d(dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False])
dec = Decoder3d(dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False])
mx.eval(enc.parameters(), dec.parameters())
# [B=1, C=3, T=1, H=8, W=8]
@@ -937,15 +1018,12 @@ class TestVAE22RoundTrip:
mx.eval(out)
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB
assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB
out_np = np.array(out)
assert np.all(np.isfinite(out_np))
assert out_np.min() >= -1.0 - 1e-6
assert out_np.max() <= 1.0 + 1e-6