Files
mlx-video/tests/test_wan_tiling.py

214 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for Wan VAE tiled decoding."""
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
decode_with_tiling,
split_in_spatial,
)
class TestNonCausalTemporal:
"""Tests for the causal_temporal=False path in decode_with_tiling."""
def test_split_spatial_for_temporal(self):
"""Non-causal temporal should use split_in_spatial (no causal shift)."""
intervals = split_in_spatial(8, 2, 20)
# No causal adjustment: starts should be evenly spaced
assert intervals.starts[0] == 0
for i in range(1, len(intervals.starts)):
assert intervals.starts[i] == intervals.starts[i - 1] + (8 - 2)
def test_causal_vs_noncausal_output_size(self):
"""Causal temporal gives 1+(T-1)*S frames, non-causal gives T*S."""
mx.random.seed(42)
b, c, t, h, w = 1, 4, 4, 4, 4
latents = mx.random.normal((b, c, t, h, w))
scale = 4
# Simple passthrough decoder: just repeat along dimensions
def dummy_decoder_causal(x, **kwargs):
b, c, t, h, w = x.shape
out_t = 1 + (t - 1) * scale
out_h = h * scale
out_w = w * scale
return mx.ones((b, 3, out_t, out_h, out_w))
def dummy_decoder_noncausal(x, **kwargs):
b, c, t, h, w = x.shape
out_t = t * scale
out_h = h * scale
out_w = w * scale
return mx.ones((b, 3, out_t, out_h, out_w))
config = TilingConfig.spatial_only(tile_size=128, overlap=64)
# Causal: 1 + (4-1)*4 = 13
out_causal = decode_with_tiling(
dummy_decoder_causal,
latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=True,
)
mx.eval(out_causal)
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
# Non-causal: 4*4 = 16
out_noncausal = decode_with_tiling(
dummy_decoder_noncausal,
latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=False,
)
mx.eval(out_noncausal)
assert out_noncausal.shape[2] == t * scale # 16
class TestWan22TiledDecoding:
"""Tests for Wan2.2 VAE tiled decoding."""
def _make_small_wan22_decoder(self):
"""Create a small Wan2.2 decoder for testing."""
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
# Use very small dimensions for fast testing
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
mx.eval(vae.parameters())
return vae
def test_decode_tiled_output_shape(self):
"""Tiled decode should produce same shape as non-tiled."""
mx.random.seed(42)
vae = self._make_small_wan22_decoder()
# Small input: [B=1, T=3, H=2, W=2, C=48]
z = mx.random.normal((1, 3, 2, 2, 48))
mx.eval(z)
# Non-tiled
out_regular = vae(z)
mx.eval(out_regular)
# Tiled (force tiling with very small tile sizes)
# Use spatial tile=32px (2 latent @ scale 16) and temporal=8 frames (2 latent @ scale 4)
config = TilingConfig(
spatial_config=None, # Don't tile spatially (input is tiny)
temporal_config=None, # Don't tile temporally (input is tiny)
)
# With no tiling config, decode_tiled should fall through to regular decode
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
# Both should produce the same shape
assert (
out_regular.shape == out_tiled.shape
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
mx.random.seed(42)
vae = self._make_small_wan22_decoder()
# Input smaller than any tile size
z = mx.random.normal((1, 2, 2, 2, 48))
mx.eval(z)
out_regular = vae(z)
mx.eval(out_regular)
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular),
np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
class TestWan21TiledDecoding:
"""Tests for Wan2.1 VAE tiled decoding."""
def _make_small_wan21_vae(self):
"""Create a small Wan2.1 VAE for testing."""
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16)
mx.eval(vae.parameters())
return vae
def test_decode_tiled_output_shape(self):
"""Tiled decode should produce correct output shape."""
mx.random.seed(42)
vae = self._make_small_wan21_vae()
# [B=1, C=16, T=3, H=4, W=4]
z = mx.random.normal((1, 16, 3, 4, 4))
mx.eval(z)
out_regular = vae.decode(z)
mx.eval(out_regular)
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
assert (
out_regular.shape == out_tiled.shape
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
mx.random.seed(42)
vae = self._make_small_wan21_vae()
z = mx.random.normal((1, 16, 2, 4, 4))
mx.eval(z)
out_regular = vae.decode(z)
mx.eval(out_regular)
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular),
np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
class TestWan21TemporalScale:
"""Verify Wan2.1 decoder temporal output is T*4 (non-causal)."""
def test_wan21_decoder_temporal_output(self):
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
from mlx_video.models.wan_2.vae import Decoder3d
# Small decoder for fast test
dec = Decoder3d(
dim=16,
z_dim=4,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temporal_upsample=[True, True, False],
)
mx.eval(dec.parameters())
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
mx.eval(x)
out = dec(x)
mx.eval(out)
# With two temporal 2× upsamples: T=3 → 6 → 12
assert out.shape[2] == 3 * 4, f"Expected T=12, got T={out.shape[2]}"