feat(wan): Add Wan2.2 I2V support
This commit is contained in:
871
tests/test_wan_vae.py
Normal file
871
tests/test_wan_vae.py
Normal file
@@ -0,0 +1,871 @@
|
||||
"""Tests for Wan VAE 2.1 and 2.2 components."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VAE 2.1 Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCausalConv3d:
|
||||
def test_output_shape_stride1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||
# Initialize weights
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
x = mx.random.normal((1, 4, 3, 8, 8)) # [B, C, T, H, W]
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
# With causal padding and padding=1 on spatial, dims should be preserved
|
||||
assert out.shape[0] == 1
|
||||
assert out.shape[1] == 8 # out_channels
|
||||
assert out.shape[2] == 3 # T preserved
|
||||
assert out.shape[3] == 8 # H preserved
|
||||
assert out.shape[4] == 8 # W preserved
|
||||
|
||||
def test_output_shape_kernel1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
x = mx.random.normal((1, 4, 2, 4, 4))
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_causal_padding(self):
|
||||
"""Causal conv should only use past/current frames, not future."""
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros((2,))
|
||||
# Create input where only the first frame has signal
|
||||
x = mx.zeros((1, 2, 4, 4, 4))
|
||||
x_np = np.zeros((1, 2, 4, 4, 4), dtype=np.float32)
|
||||
x_np[:, :, 0, :, :] = 1.0
|
||||
x = mx.array(x_np)
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
# Due to causal padding, the output at t=0 should only depend on t=0
|
||||
|
||||
|
||||
class TestResidualBlock:
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_shortcut_exists_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_when_dims_same(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
|
||||
class TestAttentionBlock:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||
out = block(x)
|
||||
mx.eval(x, out)
|
||||
# Residual: output should not be zero even with random init
|
||||
assert np.abs(np.array(out)).max() > 0
|
||||
|
||||
|
||||
class TestWanVAE:
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
vae = WanVAE(z_dim=16)
|
||||
assert vae.z_dim == 16
|
||||
assert vae.mean.shape == (16,)
|
||||
assert vae.std.shape == (16,)
|
||||
|
||||
def test_normalization_stats(self):
|
||||
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
|
||||
assert len(VAE_MEAN) == 16
|
||||
assert len(VAE_STD) == 16
|
||||
assert all(s > 0 for s in VAE_STD)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.2 VAE Component Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVAE22CausalConv3d:
|
||||
"""Tests for vae22.CausalConv3d (channels-last)."""
|
||||
|
||||
def test_output_shape_k3(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 8, 8, 16)
|
||||
|
||||
def test_output_shape_k1(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_temporal_causal(self):
|
||||
"""Output at t=0 should not depend on t>0."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
conv.bias = mx.zeros(conv.bias.shape)
|
||||
|
||||
x = mx.zeros((1, 4, 4, 4, 2))
|
||||
out_zero = conv(x)
|
||||
mx.eval(out_zero)
|
||||
t0_ref = np.array(out_zero[0, 0])
|
||||
|
||||
# Modify t=2..3; output at t=0 should be unchanged
|
||||
x_mod = mx.concatenate([
|
||||
x[:, :2],
|
||||
mx.ones((1, 2, 4, 4, 2)),
|
||||
], axis=1)
|
||||
out_mod = conv(x_mod)
|
||||
mx.eval(out_mod)
|
||||
t0_mod = np.array(out_mod[0, 0])
|
||||
np.testing.assert_allclose(t0_ref, t0_mod, atol=1e-5)
|
||||
|
||||
def test_channels_last_format(self):
|
||||
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||
out = conv(x)
|
||||
mx.eval(out)
|
||||
assert out.shape[-1] == 8 # last dim = out_channels
|
||||
|
||||
|
||||
class TestRMSNorm:
|
||||
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
norm = RMS_norm(16)
|
||||
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == x.shape
|
||||
|
||||
def test_l2_normalization(self):
|
||||
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
dim = 32
|
||||
norm = RMS_norm(dim)
|
||||
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# After L2 norm * scale(=sqrt(dim)) * gamma(=1): ||out|| = sqrt(dim)
|
||||
out_np = np.array(out).flatten()
|
||||
l2 = np.linalg.norm(out_np)
|
||||
np.testing.assert_allclose(l2, math.sqrt(dim), rtol=1e-3)
|
||||
|
||||
def test_scale_invariant(self):
|
||||
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
norm = RMS_norm(8)
|
||||
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||
out1 = norm(x)
|
||||
out2 = norm(x * 10.0)
|
||||
mx.eval(out1, out2)
|
||||
np.testing.assert_allclose(np.array(out1), np.array(out2), atol=1e-4)
|
||||
|
||||
def test_gamma_effect(self):
|
||||
"""Non-unit gamma should scale output."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
norm = RMS_norm(4)
|
||||
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||
x = mx.ones((1, 1, 1, 1, 4))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# With gamma=2, each component is 2 * sqrt(4) * x/||x|| = 2 * 2 * 1/2 = 2
|
||||
np.testing.assert_allclose(np.array(out).flatten(), 2.0, atol=1e-4)
|
||||
|
||||
|
||||
class TestDupUp3D:
|
||||
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out = up(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 3, 8, 8, 4)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||
out = up(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 6, 8, 8, 8)
|
||||
|
||||
def test_first_chunk_trims(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
out_trimmed = up(x, first_chunk=True)
|
||||
mx.eval(out_normal, out_trimmed)
|
||||
# first_chunk removes factor_t-1=1 temporal frame
|
||||
assert out_normal.shape[1] == 6
|
||||
assert out_trimmed.shape[1] == 5
|
||||
|
||||
def test_no_temporal_first_chunk_noop(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
out_normal = up(x, first_chunk=False)
|
||||
out_trimmed = up(x, first_chunk=True)
|
||||
mx.eval(out_normal, out_trimmed)
|
||||
# factor_t=1, so first_chunk removes 0 frames
|
||||
assert out_normal.shape == out_trimmed.shape
|
||||
|
||||
|
||||
class TestVAE22Resample:
|
||||
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
|
||||
|
||||
def test_upsample2d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample2d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
|
||||
|
||||
def test_upsample3d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
|
||||
|
||||
def test_upsample3d_first_chunk(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = r(x, first_chunk=True)
|
||||
mx.eval(out)
|
||||
# first_chunk: 1 (bypass) + 2*(T-1) (interleaved) = 2T-1 = 3
|
||||
assert out.shape == (1, 3, 8, 8, 8)
|
||||
|
||||
def test_upsample3d_first_chunk_single_frame(self):
|
||||
"""Single-frame input with first_chunk: no temporal upsample."""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 1, 4, 4, 8))
|
||||
out = r(x, first_chunk=True)
|
||||
mx.eval(out)
|
||||
# Single frame with first_chunk: falls through to non-first path
|
||||
# time_conv on 1 frame → 2 interleaved
|
||||
assert out.shape == (1, 2, 8, 8, 8)
|
||||
|
||||
def test_upsample3d_first_frame_bypasses_time_conv(self):
|
||||
"""First frame of first_chunk should NOT go through time_conv.
|
||||
|
||||
Official Wan2.2 skips time_conv for the very first frame entirely.
|
||||
We verify this by checking that the first output frame depends only on
|
||||
the first input frame (not on time_conv parameters).
|
||||
"""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
C = 8
|
||||
r = Resample(C, "upsample3d")
|
||||
# Set time_conv weights to large values so its effect is detectable
|
||||
r.time_conv.weight = mx.ones(r.time_conv.weight.shape) * 10.0
|
||||
r.time_conv.bias = mx.zeros(r.time_conv.bias.shape)
|
||||
# Set spatial conv to identity-like
|
||||
r.resample_weight = mx.zeros(r.resample_weight.shape)
|
||||
r.resample_bias = mx.zeros(r.resample_bias.shape)
|
||||
|
||||
x = mx.random.normal((1, 3, 2, 2, C))
|
||||
out = r(x, first_chunk=True)
|
||||
mx.eval(out)
|
||||
# Output: 5 frames (1 bypass + 4 interleaved from 2 remaining)
|
||||
assert out.shape[1] == 5
|
||||
|
||||
# First frame should be spatial upsample of x[:, 0:1] only.
|
||||
# Run just the first frame through spatial upsample for reference
|
||||
first_only = x[:, 0:1]
|
||||
ref = r._upsample2x(first_only.reshape(1, 2, 2, C))
|
||||
ref = mx.pad(ref, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
||||
ref = mx.conv_general(ref, r.resample_weight) + r.resample_bias
|
||||
mx.eval(ref)
|
||||
|
||||
# Compare first output frame to reference
|
||||
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
|
||||
mx.eval(first_out)
|
||||
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
|
||||
"First frame should bypass time_conv and match spatial-only upsample"
|
||||
|
||||
|
||||
class TestVAE22ResidualBlock:
|
||||
"""Tests for vae22.ResidualBlock."""
|
||||
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_shortcut_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
|
||||
|
||||
class TestResidualBlockLayers:
|
||||
"""Tests for vae22.ResidualBlockLayers naming convention."""
|
||||
|
||||
def test_layer_names_no_underscore_prefix(self):
|
||||
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
block = ResidualBlockLayers(8, 8)
|
||||
params = dict(block.parameters())
|
||||
# All param keys should use layer_N, not _layer_N
|
||||
for key in params:
|
||||
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
|
||||
|
||||
def test_has_expected_layers(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
assert hasattr(block, "layer_0") # first RMS_norm
|
||||
assert hasattr(block, "layer_2") # first CausalConv3d
|
||||
assert hasattr(block, "layer_3") # second RMS_norm
|
||||
assert hasattr(block, "layer_6") # second CausalConv3d
|
||||
|
||||
def test_forward_shape(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
|
||||
class TestVAE22AttentionBlock:
|
||||
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
block = AttentionBlock(16)
|
||||
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
block = AttentionBlock(8)
|
||||
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||
block.proj_weight = mx.zeros(block.proj_weight.shape)
|
||||
x = mx.ones((1, 1, 2, 2, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# With zero weights, attention output is 0 → residual is identity
|
||||
np.testing.assert_allclose(np.array(out), np.array(x), atol=1e-5)
|
||||
|
||||
|
||||
class TestHead22:
|
||||
"""Tests for vae22.Head22 output head."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
head = Head22(16, out_channels=12)
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
out = head(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 12)
|
||||
|
||||
def test_layer_names_no_underscore(self):
|
||||
"""Head layers must not use underscore prefix."""
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
head = Head22(8)
|
||||
assert hasattr(head, "layer_0") # RMS_norm
|
||||
assert hasattr(head, "layer_2") # CausalConv3d
|
||||
params = dict(head.parameters())
|
||||
for key in params:
|
||||
assert not key.startswith("_"), f"Head param {key} starts with underscore"
|
||||
|
||||
|
||||
class TestUnpatchify:
|
||||
"""Tests for vae22._unpatchify."""
|
||||
|
||||
def test_basic_shape(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 8, 8, 3)
|
||||
|
||||
def test_patch_size_1_noop(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||
out = _unpatchify(x, patch_size=1)
|
||||
mx.eval(out)
|
||||
np.testing.assert_array_equal(np.array(out), np.array(x))
|
||||
|
||||
def test_preserves_content(self):
|
||||
"""Unpatchify should be a lossless rearrangement."""
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
mx.eval(out)
|
||||
# All elements should be preserved
|
||||
assert np.array(out).size == 48
|
||||
assert set(np.array(out).flatten().tolist()) == set(range(48))
|
||||
|
||||
|
||||
class TestDenormalizeLatents:
|
||||
"""Tests for vae22.denormalize_latents."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
out = denormalize_latents(z)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 48)
|
||||
|
||||
def test_custom_mean_std(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
z = mx.ones((1, 1, 1, 1, 4))
|
||||
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
std = mx.array([0.5, 0.5, 0.5, 0.5])
|
||||
out = denormalize_latents(z, mean=mean, std=std)
|
||||
mx.eval(out)
|
||||
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
|
||||
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
|
||||
|
||||
def test_uses_default_constants(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
|
||||
# Should not raise with default constants
|
||||
z = mx.zeros((1, 1, 1, 1, 48))
|
||||
out = denormalize_latents(z)
|
||||
mx.eval(out)
|
||||
# z=0 → result = 0 * std + mean = mean
|
||||
np.testing.assert_allclose(
|
||||
np.array(out).flatten(),
|
||||
np.array(VAE22_MEAN).flatten(),
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
|
||||
class TestVAE22NormConstants:
|
||||
"""Tests for VAE22_MEAN and VAE22_STD constants."""
|
||||
|
||||
def test_dimensions(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
||||
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||
assert VAE22_MEAN.shape == (48,)
|
||||
assert VAE22_STD.shape == (48,)
|
||||
|
||||
def test_std_positive(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_STD
|
||||
mx.eval(VAE22_STD)
|
||||
assert (np.array(VAE22_STD) > 0).all()
|
||||
|
||||
|
||||
class TestWan22VAEDecoder:
|
||||
"""Tests for the full Wan22VAEDecoder (tiny configuration)."""
|
||||
|
||||
def test_output_shape_small(self):
|
||||
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
# Use very small dims to keep test fast
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
# Latent: [B=1, T=3, H=2, W=2, C=4]
|
||||
# Expected: temporal 3→5→9→9→9 (two temporal upsamples), spatial 2→4→8→16
|
||||
z = mx.random.normal((1, 3, 2, 2, 4)) * 0.1
|
||||
out = dec(z)
|
||||
mx.eval(out)
|
||||
# Output should have 3 RGB channels and be clipped to [-1, 1]
|
||||
assert out.shape[-1] == 3
|
||||
assert out.ndim == 5
|
||||
assert np.array(out).min() >= -1.0
|
||||
assert np.array(out).max() <= 1.0
|
||||
|
||||
def test_output_clipped(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||
out = dec(z)
|
||||
mx.eval(out)
|
||||
assert np.array(out).min() >= -1.0 - 1e-6
|
||||
assert np.array(out).max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
class TestSanitizeWan22VAEWeights:
|
||||
"""Tests for vae22.sanitize_wan22_vae_weights."""
|
||||
|
||||
def test_skip_encoder(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"encoder.layer.weight": mx.zeros((4,)),
|
||||
"conv1.weight": mx.zeros((4,)),
|
||||
"decoder.conv1.bias": mx.zeros((4,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "encoder.layer.weight" not in out
|
||||
assert "conv1.weight" not in out
|
||||
assert "decoder.conv1.bias" in out
|
||||
|
||||
def test_sequential_index_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
|
||||
"decoder.head.0.gamma": mx.ones((4,)),
|
||||
"decoder.head.2.bias": mx.zeros((12,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma" in out
|
||||
assert "decoder.upsamples.0.upsamples.0.residual.layer_6.bias" in out
|
||||
assert "decoder.head.layer_0.gamma" in out
|
||||
assert "decoder.head.layer_2.bias" in out
|
||||
|
||||
def test_resample_conv_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "decoder.upsamples.1.upsamples.3.resample_weight" in out
|
||||
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
|
||||
|
||||
def test_attention_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = {
|
||||
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
|
||||
"decoder.middle.1.proj.weight": mx.zeros((8, 8, 1, 1)),
|
||||
"decoder.middle.1.proj.bias": mx.zeros((8,)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert "decoder.middle.1.to_qkv_weight" in out
|
||||
assert "decoder.middle.1.to_qkv_bias" in out
|
||||
assert "decoder.middle.1.proj_weight" in out
|
||||
assert "decoder.middle.1.proj_bias" in out
|
||||
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||
w = mx.zeros((16, 8, 3, 3, 3))
|
||||
weights = {"decoder.conv1.weight": w}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||
w = mx.zeros((8, 8, 3, 3))
|
||||
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
key = "decoder.upsamples.0.upsamples.2.resample_weight"
|
||||
assert out[key].shape == (8, 3, 3, 8)
|
||||
|
||||
def test_gamma_squeeze(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||
w = mx.ones((16, 1, 1, 1))
|
||||
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
|
||||
out = sanitize_wan22_vae_weights(weights)
|
||||
key = "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma"
|
||||
assert out[key].shape == (16,)
|
||||
|
||||
|
||||
class TestUpResidualBlock:
|
||||
"""Tests for vae22.Up_ResidualBlock."""
|
||||
|
||||
def test_no_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# No upsample: same shape
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_spatial_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# 2x spatial upsample, no temporal
|
||||
assert out.shape == (1, 2, 8, 8, 4)
|
||||
|
||||
def test_spatial_temporal_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
# 2x spatial + 2x temporal
|
||||
assert out.shape == (1, 4, 8, 8, 4)
|
||||
|
||||
|
||||
class TestPatchify:
|
||||
"""Tests for _patchify and _unpatchify round-trip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 1, 64, 64, 3))
|
||||
p = _patchify(x, patch_size=2)
|
||||
assert p.shape == (1, 1, 32, 32, 12)
|
||||
back = _unpatchify(p, patch_size=2)
|
||||
assert back.shape == x.shape
|
||||
assert float(mx.abs(x - back).max()) == 0.0
|
||||
|
||||
def test_identity_patch_1(self):
|
||||
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 8, 8, 3))
|
||||
assert _patchify(x, patch_size=1).shape == x.shape
|
||||
assert _unpatchify(x, patch_size=1).shape == x.shape
|
||||
|
||||
|
||||
class TestAvgDown3D:
|
||||
"""Tests for AvgDown3D downsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = down(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
out = down(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_single_frame(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 1, 8, 8, 8))
|
||||
out = down(x)
|
||||
mx.eval(out)
|
||||
# T=1 with factor_t=2: pads to T=2 then averages → T=1
|
||||
assert out.shape == (1, 1, 4, 4, 8)
|
||||
|
||||
|
||||
class TestDownResidualBlock:
|
||||
"""Tests for Down_ResidualBlock."""
|
||||
|
||||
def test_no_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 8, 8, 8)
|
||||
|
||||
def test_spatial_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_spatial_temporal_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
out = block(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
|
||||
class TestEncoder3d:
|
||||
"""Tests for Encoder3d."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8)
|
||||
x = mx.random.normal((1, 1, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
# 3 spatial downsamples ÷8: 16→2
|
||||
assert out.shape == (1, 1, 2, 2, 8)
|
||||
|
||||
def test_multi_frame(self):
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
# T: 5→3 (1st t_down) →2 (2nd t_down), spatial ÷8
|
||||
assert out.shape[2:] == (2, 2, 8)
|
||||
|
||||
|
||||
class TestWan22VAEEncoder:
|
||||
"""Tests for Wan22VAEEncoder wrapper."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
mx.eval(enc.parameters())
|
||||
z = enc(img)
|
||||
mx.eval(z)
|
||||
assert z.shape == (1, 1, 2, 2, 48)
|
||||
|
||||
def test_full_dim(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=160)
|
||||
img = mx.random.normal((1, 1, 64, 64, 3))
|
||||
mx.eval(enc.parameters())
|
||||
z = enc(img)
|
||||
mx.eval(z)
|
||||
# 64 / 16 = 4 (vae stride 16×)
|
||||
assert z.shape == (1, 1, 4, 4, 48)
|
||||
|
||||
|
||||
class TestNormalizeLatents:
|
||||
"""Tests for normalize/denormalize latent roundtrip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
z_norm = normalize_latents(z)
|
||||
z_back = denormalize_latents(z_norm)
|
||||
mx.eval(z_back)
|
||||
assert float(mx.abs(z - z_back).max()) < 1e-4
|
||||
|
||||
|
||||
class TestVAEEncoderTemporalOrder:
|
||||
"""Tests that VAE encoder uses (False, True, True) temporal downsample order,
|
||||
matching official Wan2.2 vae2_2.py."""
|
||||
|
||||
def test_encoder_temporal_downsample_pattern(self):
|
||||
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc.parameters())
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
assert out.shape[1] == 2
|
||||
|
||||
def test_wrapper_uses_correct_pattern(self):
|
||||
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
down_blocks = enc.encoder.downsamples
|
||||
found_modes = []
|
||||
for block in down_blocks:
|
||||
for layer in block.downsamples:
|
||||
if isinstance(layer, Resample):
|
||||
found_modes.append(layer.mode)
|
||||
# First spatial-only, then two with temporal
|
||||
assert found_modes[0] == "downsample2d"
|
||||
assert any("3d" in m for m in found_modes)
|
||||
|
||||
def test_single_frame_encoder(self):
|
||||
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
mx.eval(enc.parameters())
|
||||
z = enc(img)
|
||||
mx.eval(z)
|
||||
assert z.shape[1] == 1
|
||||
assert z.shape[-1] == 48
|
||||
|
||||
def test_wrong_order_gives_different_result(self):
|
||||
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
mx.eval(enc_correct.parameters())
|
||||
mx.eval(enc_wrong.parameters())
|
||||
|
||||
out_correct = enc_correct(x)
|
||||
out_wrong = enc_wrong(x)
|
||||
mx.eval(out_correct, out_wrong)
|
||||
|
||||
# Both give T=2 but spatial processing path differs
|
||||
assert out_correct.shape[1] == 2
|
||||
assert out_wrong.shape[1] == 2
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user