add tests

This commit is contained in:
Prince Canuma
2026-01-17 23:36:39 +01:00
parent 7f20840dc7
commit f256c5fb25

301
tests/test_vae_streaming.py Normal file
View File

@@ -0,0 +1,301 @@
"""Tests for VAE streaming and chunked conv features."""
import pytest
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
compute_trapezoidal_mask_1d,
decode_with_tiling,
)
class TestChunkedConv:
"""Tests for chunked conv optimization in DepthToSpaceUpsample."""
def test_chunked_conv_output_matches_regular(self):
"""Verify chunked_conv produces identical output to regular processing."""
mx.random.seed(42)
# Create upsampler with residual (matches decoder config)
upsampler = DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
)
# Initialize weights deterministically
mx.eval(upsampler.parameters())
# Create test input: (B, C, D, H, W) with enough frames to trigger chunking
# chunked_conv activates when d > 4
x = mx.random.normal((1, 256, 8, 8, 8))
mx.eval(x)
# Run without chunked conv
out_regular = upsampler(x, causal=True, chunked_conv=False)
mx.eval(out_regular)
# Run with chunked conv
out_chunked = upsampler(x, causal=True, chunked_conv=True)
mx.eval(out_chunked)
# Outputs should be identical
np.testing.assert_allclose(
np.array(out_regular),
np.array(out_chunked),
rtol=1e-5,
atol=1e-5,
err_msg="Chunked conv output differs from regular output"
)
def test_chunked_conv_small_input_passthrough(self):
"""Verify chunked_conv doesn't activate for small inputs (d <= 4)."""
mx.random.seed(42)
upsampler = DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
)
mx.eval(upsampler.parameters())
# Small input with d=4 (should NOT trigger chunking)
x = mx.random.normal((1, 256, 4, 8, 8))
mx.eval(x)
out_regular = upsampler(x, causal=True, chunked_conv=False)
out_chunked = upsampler(x, causal=True, chunked_conv=True)
mx.eval(out_regular, out_chunked)
# Should be identical since chunking doesn't activate
np.testing.assert_allclose(
np.array(out_regular),
np.array(out_chunked),
rtol=1e-5,
atol=1e-5,
)
def test_chunked_conv_output_shape(self):
"""Verify chunked_conv produces correct output shape."""
mx.random.seed(42)
upsampler = DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
)
mx.eval(upsampler.parameters())
# Input shape: (1, 256, 8, 16, 16)
x = mx.random.normal((1, 256, 8, 16, 16))
mx.eval(x)
out = upsampler(x, causal=True, chunked_conv=True)
mx.eval(out)
# Expected output:
# - Channels: 256 / 2 = 128
# - Temporal: 8 * 2 - 1 = 15 (minus 1 for causal)
# - Spatial: 16 * 2 = 32
assert out.shape == (1, 128, 15, 32, 32), f"Unexpected shape: {out.shape}"
class TestProgressiveFrameSaving:
"""Tests for progressive frame saving via on_frames_ready callback."""
def test_on_frames_ready_called(self):
"""Verify on_frames_ready callback is called during tiled decoding."""
frames_received = []
def on_frames_ready(frames: mx.array, start_idx: int):
frames_received.append({
'shape': frames.shape,
'start_idx': start_idx,
})
# Create a mock decoder that just returns scaled input
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
# Simulate VAE output: upsample 8x temporal, 32x spatial
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
out_h = h * 32
out_w = w * 32
return mx.zeros((b, 3, out_f, out_h, out_w))
# Create tiling config with temporal tiling to trigger callbacks
tiling_config = TilingConfig.temporal_only(tile_size=32, overlap=8)
# Input latents: enough frames to require multiple tiles
# 10 latent frames -> 73 output frames
latents = mx.zeros((1, 128, 10, 4, 4))
# Decode with tiling
output = decode_with_tiling(
decoder_fn=mock_decoder,
latents=latents,
tiling_config=tiling_config,
spatial_scale=32,
temporal_scale=8,
on_frames_ready=on_frames_ready,
)
mx.eval(output)
# Should have received at least one callback
assert len(frames_received) > 0, "on_frames_ready was never called"
# All received frames should have correct channel count
for received in frames_received:
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
def test_on_frames_ready_covers_all_frames(self):
"""Verify all frames are emitted via callbacks."""
all_frame_indices = set()
def on_frames_ready(frames: mx.array, start_idx: int):
num_frames = frames.shape[2]
for i in range(num_frames):
all_frame_indices.add(start_idx + i)
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
out_h = h * 32
out_w = w * 32
return mx.random.normal((b, 3, out_f, out_h, out_w))
tiling_config = TilingConfig.temporal_only(tile_size=32, overlap=8)
# 12 latent frames -> 89 output frames
latents = mx.zeros((1, 128, 12, 4, 4))
output = decode_with_tiling(
decoder_fn=mock_decoder,
latents=latents,
tiling_config=tiling_config,
spatial_scale=32,
temporal_scale=8,
on_frames_ready=on_frames_ready,
)
mx.eval(output)
# Calculate expected frame count
expected_frames = 1 + (12 - 1) * 8 # 89 frames
# All frames should have been emitted
assert len(all_frame_indices) == expected_frames, \
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
assert all_frame_indices == set(range(expected_frames)), \
"Not all frame indices were covered"
class TestAutoChunkedConv:
"""Tests for auto-enabling chunked_conv based on tiling mode."""
@pytest.mark.parametrize("tiling_mode,should_enable", [
("conservative", True),
("none", True),
("auto", True),
("default", True),
("spatial", True),
("aggressive", False),
("temporal", False),
])
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
expected_modes = {"conservative", "none", "auto", "default", "spatial"}
use_chunked_conv = tiling_mode in expected_modes
assert use_chunked_conv == should_enable, \
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
class TestTrapezoidalMask:
"""Tests for trapezoidal blending mask generation."""
def test_mask_values_in_range(self):
"""Verify mask values are always in [0, 1]."""
for length in [16, 32, 64, 128]:
for ramp in [0, 4, 8, 16]:
if ramp < length:
mask = compute_trapezoidal_mask_1d(length, ramp, ramp, False)
assert mx.all(mask >= 0).item(), f"Mask has negative values"
assert mx.all(mask <= 1).item(), f"Mask has values > 1"
def test_mask_center_is_one(self):
"""Verify center of mask is 1.0 when ramps don't overlap."""
mask = compute_trapezoidal_mask_1d(32, 8, 8, False)
# Center region should be 1.0
center = mask[12:20] # Middle portion
np.testing.assert_allclose(np.array(center), 1.0, rtol=1e-5)
def test_mask_ramp_monotonic(self):
"""Verify ramps are monotonically increasing/decreasing."""
mask = compute_trapezoidal_mask_1d(32, 8, 8, False)
mask_np = np.array(mask)
# Left ramp should be increasing
left_ramp = mask_np[:8]
assert np.all(np.diff(left_ramp) >= 0), "Left ramp not monotonically increasing"
# Right ramp should be decreasing
right_ramp = mask_np[-8:]
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
def test_temporal_mask_starts_from_zero(self):
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""
mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=True)
assert mask[0].item() == 0.0, "Temporal mask should start from 0"
def test_spatial_mask_starts_above_zero(self):
"""Verify spatial mask (left_starts_from_0=False) starts above 0."""
mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=False)
assert mask[0].item() > 0.0, "Spatial mask should start above 0"
class TestTilingConfig:
"""Tests for TilingConfig presets."""
def test_default_config(self):
"""Verify default tiling configuration."""
config = TilingConfig.default()
assert config.spatial_config is not None
assert config.temporal_config is not None
assert config.spatial_config.tile_size_in_pixels == 512
assert config.temporal_config.tile_size_in_frames == 64
def test_aggressive_config(self):
"""Verify aggressive tiling configuration."""
config = TilingConfig.aggressive()
assert config.spatial_config.tile_size_in_pixels == 256
assert config.temporal_config.tile_size_in_frames == 32
def test_conservative_config(self):
"""Verify conservative tiling configuration."""
config = TilingConfig.conservative()
assert config.spatial_config.tile_size_in_pixels == 768
assert config.temporal_config.tile_size_in_frames == 96
def test_auto_returns_none_for_small_video(self):
"""Verify auto returns None for small videos."""
config = TilingConfig.auto(height=256, width=256, num_frames=33)
assert config is None, "Auto should return None for small videos"
def test_auto_returns_config_for_large_video(self):
"""Verify auto returns config for large videos."""
config = TilingConfig.auto(height=1024, width=768, num_frames=145)
assert config is not None, "Auto should return config for large videos"
if __name__ == "__main__":
pytest.main([__file__, "-v"])