format
This commit is contained in:
@@ -4,16 +4,16 @@ import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VAE 2.1 Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCausalConv3d:
|
||||
def test_output_shape_stride1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||
# Initialize weights
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
@@ -29,6 +29,7 @@ class TestCausalConv3d:
|
||||
|
||||
def test_output_shape_kernel1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
x = mx.random.normal((1, 4, 2, 4, 4))
|
||||
@@ -39,6 +40,7 @@ class TestCausalConv3d:
|
||||
def test_causal_padding(self):
|
||||
"""Causal conv should only use past/current frames, not future."""
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros((2,))
|
||||
@@ -55,6 +57,7 @@ class TestCausalConv3d:
|
||||
class TestResidualBlock:
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
@@ -63,6 +66,7 @@ class TestResidualBlock:
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
@@ -71,11 +75,13 @@ class TestResidualBlock:
|
||||
|
||||
def test_shortcut_exists_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_when_dims_same(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
@@ -83,6 +89,7 @@ class TestResidualBlock:
|
||||
class TestAttentionBlock:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
@@ -91,6 +98,7 @@ class TestAttentionBlock:
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||
out = block(x)
|
||||
@@ -102,13 +110,15 @@ class TestAttentionBlock:
|
||||
class TestWanVAE:
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
assert vae.z_dim == 16
|
||||
assert vae.mean.shape == (16,)
|
||||
assert vae.std.shape == (16,)
|
||||
|
||||
def test_normalization_stats(self):
|
||||
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
|
||||
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
|
||||
|
||||
assert len(VAE_MEAN) == 16
|
||||
assert len(VAE_STD) == 16
|
||||
assert all(s > 0 for s in VAE_STD)
|
||||
@@ -124,6 +134,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_output_shape_k3(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||
out = conv(x)
|
||||
@@ -132,6 +143,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_output_shape_k1(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = conv(x)
|
||||
@@ -141,6 +153,7 @@ class TestVAE22CausalConv3d:
|
||||
def test_temporal_causal(self):
|
||||
"""Output at t=0 should not depend on t>0."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros(conv.bias.shape)
|
||||
@@ -151,10 +164,13 @@ class TestVAE22CausalConv3d:
|
||||
t0_ref = np.array(out_zero[0, 0])
|
||||
|
||||
# Modify t=2..3; output at t=0 should be unchanged
|
||||
x_mod = mx.concatenate([
|
||||
x[:, :2],
|
||||
mx.ones((1, 2, 4, 4, 2)),
|
||||
], axis=1)
|
||||
x_mod = mx.concatenate(
|
||||
[
|
||||
x[:, :2],
|
||||
mx.ones((1, 2, 4, 4, 2)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
out_mod = conv(x_mod)
|
||||
mx.eval(out_mod)
|
||||
t0_mod = np.array(out_mod[0, 0])
|
||||
@@ -163,6 +179,7 @@ class TestVAE22CausalConv3d:
|
||||
def test_channels_last_format(self):
|
||||
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||
out = conv(x)
|
||||
@@ -175,6 +192,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(16)
|
||||
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||
out = norm(x)
|
||||
@@ -184,6 +202,7 @@ class TestRMSNorm:
|
||||
def test_l2_normalization(self):
|
||||
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
dim = 32
|
||||
norm = RMS_norm(dim)
|
||||
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
|
||||
@@ -197,6 +216,7 @@ class TestRMSNorm:
|
||||
def test_scale_invariant(self):
|
||||
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(8)
|
||||
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||
out1 = norm(x)
|
||||
@@ -207,6 +227,7 @@ class TestRMSNorm:
|
||||
def test_gamma_effect(self):
|
||||
"""Non-unit gamma should scale output."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(4)
|
||||
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||
x = mx.ones((1, 1, 1, 1, 4))
|
||||
@@ -221,6 +242,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out = up(x)
|
||||
@@ -229,6 +251,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||
out = up(x)
|
||||
@@ -237,6 +260,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_first_chunk_trims(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
@@ -248,6 +272,7 @@ class TestDupUp3D:
|
||||
|
||||
def test_no_temporal_first_chunk_noop(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
@@ -262,6 +287,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample2d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample2d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -271,6 +297,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample3d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -280,6 +307,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample3d_first_chunk(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -291,6 +319,7 @@ class TestVAE22Resample:
|
||||
def test_upsample3d_first_chunk_single_frame(self):
|
||||
"""Single-frame input with first_chunk: no temporal upsample."""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 1, 4, 4, 8))
|
||||
@@ -308,6 +337,7 @@ class TestVAE22Resample:
|
||||
the first input frame (not on time_conv parameters).
|
||||
"""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
|
||||
C = 8
|
||||
r = Resample(C, "upsample3d")
|
||||
# Set time_conv weights to large values so its effect is detectable
|
||||
@@ -334,8 +364,9 @@ class TestVAE22Resample:
|
||||
# Compare first output frame to reference
|
||||
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
|
||||
mx.eval(first_out)
|
||||
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
|
||||
"First frame should bypass time_conv and match spatial-only upsample"
|
||||
assert mx.allclose(
|
||||
first_out, ref, atol=1e-5
|
||||
).item(), "First frame should bypass time_conv and match spatial-only upsample"
|
||||
|
||||
|
||||
class TestVAE22ResidualBlock:
|
||||
@@ -343,6 +374,7 @@ class TestVAE22ResidualBlock:
|
||||
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
@@ -351,6 +383,7 @@ class TestVAE22ResidualBlock:
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
@@ -359,11 +392,13 @@ class TestVAE22ResidualBlock:
|
||||
|
||||
def test_shortcut_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
@@ -374,6 +409,7 @@ class TestResidualBlockLayers:
|
||||
def test_layer_names_no_underscore_prefix(self):
|
||||
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 8)
|
||||
params = dict(block.parameters())
|
||||
# All param keys should use layer_N, not _layer_N
|
||||
@@ -382,6 +418,7 @@ class TestResidualBlockLayers:
|
||||
|
||||
def test_has_expected_layers(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
assert hasattr(block, "layer_0") # first RMS_norm
|
||||
assert hasattr(block, "layer_2") # first CausalConv3d
|
||||
@@ -390,6 +427,7 @@ class TestResidualBlockLayers:
|
||||
|
||||
def test_forward_shape(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
@@ -402,6 +440,7 @@ class TestVAE22AttentionBlock:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(16)
|
||||
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
|
||||
@@ -412,6 +451,7 @@ class TestVAE22AttentionBlock:
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||
block.proj_weight = mx.zeros(block.proj_weight.shape)
|
||||
@@ -427,6 +467,7 @@ class TestHead22:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
|
||||
head = Head22(16, out_channels=12)
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
out = head(x)
|
||||
@@ -436,6 +477,7 @@ class TestHead22:
|
||||
def test_layer_names_no_underscore(self):
|
||||
"""Head layers must not use underscore prefix."""
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
|
||||
head = Head22(8)
|
||||
assert hasattr(head, "layer_0") # RMS_norm
|
||||
assert hasattr(head, "layer_2") # CausalConv3d
|
||||
@@ -449,6 +491,7 @@ class TestUnpatchify:
|
||||
|
||||
def test_basic_shape(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
@@ -456,6 +499,7 @@ class TestUnpatchify:
|
||||
|
||||
def test_patch_size_1_noop(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||
out = _unpatchify(x, patch_size=1)
|
||||
mx.eval(out)
|
||||
@@ -464,6 +508,7 @@ class TestUnpatchify:
|
||||
def test_preserves_content(self):
|
||||
"""Unpatchify should be a lossless rearrangement."""
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
|
||||
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
@@ -477,6 +522,7 @@ class TestDenormalizeLatents:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
out = denormalize_latents(z)
|
||||
mx.eval(out)
|
||||
@@ -484,16 +530,23 @@ class TestDenormalizeLatents:
|
||||
|
||||
def test_custom_mean_std(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
|
||||
z = mx.ones((1, 1, 1, 1, 4))
|
||||
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
std = mx.array([0.5, 0.5, 0.5, 0.5])
|
||||
out = denormalize_latents(z, mean=mean, std=std)
|
||||
mx.eval(out)
|
||||
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
|
||||
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
|
||||
np.testing.assert_allclose(
|
||||
np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5
|
||||
)
|
||||
|
||||
def test_uses_default_constants(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
|
||||
from mlx_video.models.wan.vae22 import (
|
||||
VAE22_MEAN,
|
||||
denormalize_latents,
|
||||
)
|
||||
|
||||
# Should not raise with default constants
|
||||
z = mx.zeros((1, 1, 1, 1, 48))
|
||||
out = denormalize_latents(z)
|
||||
@@ -511,12 +564,14 @@ class TestVAE22NormConstants:
|
||||
|
||||
def test_dimensions(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
||||
|
||||
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||
assert VAE22_MEAN.shape == (48,)
|
||||
assert VAE22_STD.shape == (48,)
|
||||
|
||||
def test_std_positive(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_STD
|
||||
|
||||
mx.eval(VAE22_STD)
|
||||
assert (np.array(VAE22_STD) > 0).all()
|
||||
|
||||
@@ -527,6 +582,7 @@ class TestWan22VAEDecoder:
|
||||
def test_output_shape_small(self):
|
||||
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
|
||||
# Use very small dims to keep test fast
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
# Latent: [B=1, T=3, H=2, W=2, C=4]
|
||||
@@ -542,6 +598,7 @@ class TestWan22VAEDecoder:
|
||||
|
||||
def test_output_clipped(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||
out = dec(z)
|
||||
@@ -555,6 +612,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_skip_encoder(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.layer.weight": mx.zeros((4,)),
|
||||
"conv1.weight": mx.zeros((4,)),
|
||||
@@ -567,6 +625,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_sequential_index_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
|
||||
@@ -581,6 +640,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_resample_conv_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
|
||||
@@ -591,6 +651,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_attention_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
|
||||
@@ -605,6 +666,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||
w = mx.zeros((16, 8, 3, 3, 3))
|
||||
weights = {"decoder.conv1.weight": w}
|
||||
@@ -613,6 +675,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||
w = mx.zeros((8, 8, 3, 3))
|
||||
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
|
||||
@@ -622,6 +685,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
|
||||
def test_gamma_squeeze(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||
w = mx.ones((16, 1, 1, 1))
|
||||
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
|
||||
@@ -635,7 +699,10 @@ class TestUpResidualBlock:
|
||||
|
||||
def test_no_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
|
||||
)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -644,7 +711,10 @@ class TestUpResidualBlock:
|
||||
|
||||
def test_spatial_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -653,7 +723,10 @@ class TestUpResidualBlock:
|
||||
|
||||
def test_spatial_temporal_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -720,7 +793,9 @@ class TestDownResidualBlock:
|
||||
def test_no_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
|
||||
block = Down_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
|
||||
)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -729,7 +804,9 @@ class TestDownResidualBlock:
|
||||
def test_spatial_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -738,7 +815,9 @@ class TestDownResidualBlock:
|
||||
def test_spatial_temporal_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
|
||||
)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
@@ -817,6 +896,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
def test_encoder_temporal_downsample_pattern(self):
|
||||
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
@@ -826,7 +906,8 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_wrapper_uses_correct_pattern(self):
|
||||
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
|
||||
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
down_blocks = enc.encoder.downsamples
|
||||
found_modes = []
|
||||
@@ -841,6 +922,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
def test_single_frame_encoder(self):
|
||||
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
mx.eval(enc.parameters())
|
||||
@@ -852,7 +934,10 @@ class TestVAEEncoderTemporalOrder:
|
||||
def test_wrong_order_gives_different_result(self):
|
||||
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
|
||||
enc_correct = Encoder3d(
|
||||
dim=16, z_dim=8, temperal_downsample=(False, True, True)
|
||||
)
|
||||
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
@@ -883,12 +968,8 @@ class TestVAE21RoundTrip:
|
||||
z_dim = 4
|
||||
dim = 8
|
||||
# No temporal up/downsampling to keep the test simple
|
||||
enc = Encoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
|
||||
)
|
||||
dec = Decoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
|
||||
)
|
||||
enc = Encoder3d(dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False])
|
||||
dec = Decoder3d(dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False])
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, C=3, T=1, H=8, W=8]
|
||||
@@ -937,15 +1018,12 @@ class TestVAE22RoundTrip:
|
||||
mx.eval(out)
|
||||
|
||||
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
|
||||
assert out.shape[0] == 1 # batch
|
||||
assert out.shape[2] == 32 # H recovered
|
||||
assert out.shape[3] == 32 # W recovered
|
||||
assert out.shape[-1] == 3 # RGB
|
||||
assert out.shape[0] == 1 # batch
|
||||
assert out.shape[2] == 32 # H recovered
|
||||
assert out.shape[3] == 32 # W recovered
|
||||
assert out.shape[-1] == 3 # RGB
|
||||
|
||||
out_np = np.array(out)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert out_np.min() >= -1.0 - 1e-6
|
||||
assert out_np.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user