format
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
"""Tests for VAE streaming and chunked conv features."""
|
||||
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
@@ -50,7 +50,7 @@ class TestChunkedConv:
|
||||
np.array(out_chunked),
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
err_msg="Chunked conv output differs from regular output"
|
||||
err_msg="Chunked conv output differs from regular output",
|
||||
)
|
||||
|
||||
def test_chunked_conv_small_input_passthrough(self):
|
||||
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
|
||||
frames_received = []
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
frames_received.append({
|
||||
'shape': frames.shape,
|
||||
'start_idx': start_idx,
|
||||
})
|
||||
frames_received.append(
|
||||
{
|
||||
"shape": frames.shape,
|
||||
"start_idx": start_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# Create a mock decoder that just returns scaled input
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
def mock_decoder(
|
||||
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||
):
|
||||
# Simulate VAE output: upsample 8x temporal, 32x spatial
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
|
||||
|
||||
# All received frames should have correct channel count
|
||||
for received in frames_received:
|
||||
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
|
||||
assert (
|
||||
received["shape"][1] == 3
|
||||
), f"Expected 3 channels, got {received['shape'][1]}"
|
||||
|
||||
def test_on_frames_ready_covers_all_frames(self):
|
||||
"""Verify all frames are emitted via callbacks."""
|
||||
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
|
||||
for i in range(num_frames):
|
||||
all_frame_indices.add(start_idx + i)
|
||||
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
def mock_decoder(
|
||||
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||
):
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
out_h = h * 32
|
||||
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
|
||||
expected_frames = 1 + (12 - 1) * 8 # 89 frames
|
||||
|
||||
# All frames should have been emitted
|
||||
assert len(all_frame_indices) == expected_frames, \
|
||||
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||
assert all_frame_indices == set(range(expected_frames)), \
|
||||
"Not all frame indices were covered"
|
||||
assert (
|
||||
len(all_frame_indices) == expected_frames
|
||||
), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||
assert all_frame_indices == set(
|
||||
range(expected_frames)
|
||||
), "Not all frame indices were covered"
|
||||
|
||||
|
||||
class TestAutoChunkedConv:
|
||||
"""Tests for auto-enabling chunked_conv based on tiling mode."""
|
||||
|
||||
@pytest.mark.parametrize("tiling_mode,should_enable", [
|
||||
("conservative", True),
|
||||
("none", True),
|
||||
("auto", True),
|
||||
("default", True),
|
||||
("spatial", True),
|
||||
("aggressive", False),
|
||||
("temporal", False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"tiling_mode,should_enable",
|
||||
[
|
||||
("conservative", True),
|
||||
("none", True),
|
||||
("auto", True),
|
||||
("default", True),
|
||||
("spatial", True),
|
||||
("aggressive", False),
|
||||
("temporal", False),
|
||||
],
|
||||
)
|
||||
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
|
||||
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
|
||||
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
|
||||
|
||||
use_chunked_conv = tiling_mode in expected_modes
|
||||
|
||||
assert use_chunked_conv == should_enable, \
|
||||
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||
assert (
|
||||
use_chunked_conv == should_enable
|
||||
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||
|
||||
|
||||
class TestTrapezoidalMask:
|
||||
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
|
||||
|
||||
# Right ramp should be decreasing
|
||||
right_ramp = mask_np[-8:]
|
||||
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
|
||||
assert np.all(
|
||||
np.diff(right_ramp) <= 0
|
||||
), "Right ramp not monotonically decreasing"
|
||||
|
||||
def test_temporal_mask_starts_from_zero(self):
|
||||
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""
|
||||
|
||||
Reference in New Issue
Block a user