This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,8 +1,8 @@
"""Tests for VAE streaming and chunked conv features."""
import pytest
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import (
@@ -50,7 +50,7 @@ class TestChunkedConv:
np.array(out_chunked),
rtol=1e-5,
atol=1e-5,
err_msg="Chunked conv output differs from regular output"
err_msg="Chunked conv output differs from regular output",
)
def test_chunked_conv_small_input_passthrough(self):
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
frames_received = []
def on_frames_ready(frames: mx.array, start_idx: int):
frames_received.append({
'shape': frames.shape,
'start_idx': start_idx,
})
frames_received.append(
{
"shape": frames.shape,
"start_idx": start_idx,
}
)
# Create a mock decoder that just returns scaled input
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
# Simulate VAE output: upsample 8x temporal, 32x spatial
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
# All received frames should have correct channel count
for received in frames_received:
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
assert (
received["shape"][1] == 3
), f"Expected 3 channels, got {received['shape'][1]}"
def test_on_frames_ready_covers_all_frames(self):
"""Verify all frames are emitted via callbacks."""
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
for i in range(num_frames):
all_frame_indices.add(start_idx + i)
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
out_h = h * 32
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
expected_frames = 1 + (12 - 1) * 8 # 89 frames
# All frames should have been emitted
assert len(all_frame_indices) == expected_frames, \
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
assert all_frame_indices == set(range(expected_frames)), \
"Not all frame indices were covered"
assert (
len(all_frame_indices) == expected_frames
), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
assert all_frame_indices == set(
range(expected_frames)
), "Not all frame indices were covered"
class TestAutoChunkedConv:
"""Tests for auto-enabling chunked_conv based on tiling mode."""
@pytest.mark.parametrize("tiling_mode,should_enable", [
("conservative", True),
("none", True),
("auto", True),
("default", True),
("spatial", True),
("aggressive", False),
("temporal", False),
])
@pytest.mark.parametrize(
"tiling_mode,should_enable",
[
("conservative", True),
("none", True),
("auto", True),
("default", True),
("spatial", True),
("aggressive", False),
("temporal", False),
],
)
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
use_chunked_conv = tiling_mode in expected_modes
assert use_chunked_conv == should_enable, \
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
assert (
use_chunked_conv == should_enable
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
class TestTrapezoidalMask:
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
# Right ramp should be decreasing
right_ramp = mask_np[-8:]
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
assert np.all(
np.diff(right_ramp) <= 0
), "Right ramp not monotonically decreasing"
def test_temporal_mask_starts_from_zero(self):
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""