Merge pull request #14 from Blaizzy/pc/add-streaming

Add --stream flag and chunked conv memory optimization for VAE decoding
This commit is contained in:
Prince Canuma
2026-01-21 15:42:55 +01:00
committed by GitHub
5 changed files with 524 additions and 53 deletions

View File

@@ -215,6 +215,7 @@ def generate_video(
image_strength: float = 1.0, image_strength: float = 1.0,
image_frame_idx: int = 0, image_frame_idx: int = 0,
tiling: str = "auto", tiling: str = "auto",
stream: bool = False,
): ):
"""Generate video from text prompt, optionally conditioned on an image. """Generate video from text prompt, optionally conditioned on an image.
@@ -481,17 +482,59 @@ def generate_video(
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames) tiling_config = TilingConfig.auto(height, width, num_frames)
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Stream mode: write frames as they're decoded
video_writer = None
stream_pbar = None
if stream and tiling_config is not None:
import cv2
fourcc = cv2.VideoWriter_fourcc(*'avc1')
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
def on_frames_ready(frames: mx.array, start_idx: int):
"""Callback to write frames as they're finalized."""
# frames: (B, 3, num_frames, H, W)
frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W)
frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3)
frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0)
frames = (frames * 255).astype(mx.uint8)
frames_np = np.array(frames)
for frame in frames_np:
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
stream_pbar.update(1)
else:
on_frames_ready = None
if tiling_config is not None: if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose) video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
else: else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(latents) video = vae_decoder(latents)
mx.eval(video) mx.eval(video)
mx.clear_cache() mx.clear_cache()
# Close progressive video writer if used
if video_writer is not None:
video_writer.release()
if stream_pbar is not None:
stream_pbar.close()
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
# Still need video_np for save_frames option
video = mx.squeeze(video, axis=0)
video = mx.transpose(video, (1, 2, 3, 0))
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
else:
# Convert to uint8 frames # Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W) video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
@@ -499,15 +542,12 @@ def generate_video(
video = (video * 255).astype(mx.uint8) video = (video * 255).astype(mx.uint8)
video_np = np.array(video) video_np = np.array(video)
# Save outputs # Save video normally
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
try: try:
import cv2 import cv2
height, width = video_np.shape[1], video_np.shape[2] h, w = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1') fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
for frame in video_np: for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release() out.release()
@@ -654,6 +694,11 @@ Examples:
"auto=based on video size, none=disabled, default=512px/64f, " "auto=based on video size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only" "aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
) )
parser.add_argument(
"--stream",
action="store_true",
help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner."
)
args = parser.parse_args() args = parser.parse_args()
generate_video( generate_video(

View File

@@ -15,7 +15,7 @@ Architecture (from PyTorch weights):
""" """
import math import math
from typing import List, Optional from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module):
causal: bool = False, causal: bool = False,
timestep: Optional[mx.array] = None, timestep: Optional[mx.array] = None,
debug: bool = False, debug: bool = False,
chunked_conv: bool = False,
) -> mx.array: ) -> mx.array:
def debug_stats(name, t): def debug_stats(name, t):
@@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module):
for i, block in self.up_blocks.items(): for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup): if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep) x = block(x, causal=causal, timestep=scaled_timestep)
elif isinstance(block, DepthToSpaceUpsample):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else: else:
x = block(x, causal=causal) x = block(x, causal=causal)
if debug: if debug:
@@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module):
self, self,
sample: mx.array, sample: mx.array,
tiling_config: Optional[TilingConfig] = None, tiling_config: Optional[TilingConfig] = None,
tiling_mode: str = "auto",
causal: bool = False, causal: bool = False,
timestep: Optional[mx.array] = None, timestep: Optional[mx.array] = None,
debug: bool = False, debug: bool = False,
on_frames_ready: Optional[callable] = None,
) -> mx.array: ) -> mx.array:
"""Decode latents using tiling to reduce memory usage. """Decode latents using tiling to reduce memory usage.
@@ -495,14 +500,13 @@ class LTX2VideoDecoder(nn.Module):
if f > tile_size_latent: if f > tile_size_latent:
needs_temporal_tiling = True needs_temporal_tiling = True
# Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
if not needs_spatial_tiling and not needs_temporal_tiling: if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode # No tiling needed, use regular decode
if debug: return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
print("[Tiling] Input fits within tile size, using regular decode")
return self(sample, causal=causal, timestep=timestep, debug=debug)
if debug:
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
return decode_with_tiling( return decode_with_tiling(
decoder_fn=self, decoder_fn=self,
@@ -512,7 +516,8 @@ class LTX2VideoDecoder(nn.Module):
temporal_scale=8, # VAE temporal upsampling factor temporal_scale=8, # VAE temporal upsampling factor
causal=causal, causal=causal,
timestep=timestep, timestep=timestep,
debug=debug, chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
) )

View File

@@ -156,7 +156,7 @@ class DepthToSpaceUpsample(nn.Module):
return x return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array: def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
b, c, d, h, w = x.shape b, c, d, h, w = x.shape
st, sh, sw = self.stride st, sh, sw = self.stride
@@ -177,9 +177,12 @@ class DepthToSpaceUpsample(nn.Module):
if st > 1: if st > 1:
x_residual = x_residual[:, :, 1:, :, :] x_residual = x_residual[:, :, 1:, :, :]
# Use chunked mode for large tensors to reduce peak memory
if chunked_conv and d > 4:
x = self._chunked_conv_depth_to_space(x, causal)
else:
# Apply conv # Apply conv
x = self.conv(x, causal=causal) x = self.conv(x, causal=causal)
# Depth to space rearrangement # Depth to space rearrangement
x = self._depth_to_space(x) x = self._depth_to_space(x)
@@ -192,3 +195,81 @@ class DepthToSpaceUpsample(nn.Module):
x = x + x_residual x = x + x_residual
return x return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
    """Run conv + depth-to-space over temporal chunks to lower peak memory.

    The input is walked along its temporal axis (axis 2) in chunks of four
    frames. Each chunk is sliced together with the neighbouring frames the
    temporal kernel needs, passed through ``self.conv`` and
    ``self._depth_to_space``, and trimmed back to the output frames that
    belong to the chunk itself. Each trimmed piece is evaluated immediately
    so the full-length conv output is never held at once.

    Args:
        x: Input tensor of shape (B, C, D, H, W).
        causal: Forwarded to ``self.conv``; also selects whether context
            frames are taken only from the past (causal) or from both sides.

    Returns:
        The per-chunk results concatenated along the temporal axis.
    """
    # Only the temporal extent and temporal stride are needed here; the
    # 5-element unpack still rejects inputs that are not 5-D.
    _, _, d, _, _ = x.shape
    st, _, _ = self.stride

    # Number of input frames whose outputs are produced per iteration.
    chunk_size = 4
    # NOTE(review): assumes self.conv uses a temporal kernel of 3 — confirm
    # against the conv definition if the kernel size ever changes.
    kernel_t = 3

    # Context frames required so the kept outputs see real neighbours
    # instead of the padding the conv adds at chunk edges.
    if causal:
        pad_start, pad_end = kernel_t - 1, 0
    else:
        pad_start = pad_end = (kernel_t - 1) // 2

    outputs = []
    for t_pos in range(0, d, chunk_size):
        t_end = min(t_pos + chunk_size, d)

        # Widen the slice by the context frames, clamped to the tensor bounds.
        in_start = max(0, t_pos - pad_start)
        in_end = min(d, t_end + pad_end)

        chunk_out = self._depth_to_space(
            self.conv(x[:, :, in_start:in_end, :, :], causal=causal)
        )

        # Each input frame yields `st` output frames; drop the ones that
        # were produced by the leading/trailing context frames.
        out_start = (t_pos - in_start) * st
        out_end = out_start + (t_end - t_pos) * st
        chunk_out = chunk_out[:, :, out_start:out_end, :, :]

        # Force evaluation now so the chunk's intermediates can be released
        # before the next chunk is processed.
        mx.eval(chunk_out)
        outputs.append(chunk_out)

    if len(outputs) == 1:
        return outputs[0]
    return mx.concatenate(outputs, axis=2)

View File

@@ -8,8 +8,8 @@ Default configuration (from PyTorch):
- Temporal: 64 frames with 24 frame overlap - Temporal: 64 frames with 24 frame overlap
""" """
from dataclasses import dataclass, replace from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
@@ -284,7 +284,8 @@ def decode_with_tiling(
temporal_scale: int = 8, temporal_scale: int = 8,
causal: bool = False, causal: bool = False,
timestep: Optional[mx.array] = None, timestep: Optional[mx.array] = None,
debug: bool = False, chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array: ) -> mx.array:
"""Decode latents using tiling to reduce memory usage. """Decode latents using tiling to reduce memory usage.
@@ -296,7 +297,10 @@ def decode_with_tiling(
temporal_scale: Temporal scale factor (8 for LTX VAE). temporal_scale: Temporal scale factor (8 for LTX VAE).
causal: Whether to use causal convolutions. causal: Whether to use causal convolutions.
timestep: Optional timestep for conditioning. timestep: Optional timestep for conditioning.
debug: Whether to print debug info. chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
start_idx: Starting frame index in the full video.
Returns: Returns:
Decoded video. Decoded video.
@@ -337,10 +341,6 @@ def decode_with_tiling(
num_w_tiles = len(width_intervals.starts) num_w_tiles = len(width_intervals.starts)
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
if debug:
print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
# Initialize output and weight accumulator # Initialize output and weight accumulator
# Use float32 for accumulation to avoid precision issues # Use float32 for accumulation to avoid precision issues
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32) output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
@@ -375,15 +375,11 @@ def decode_with_tiling(
# Map width coordinates # Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
if debug:
print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
# Extract tile latents (small slice) # Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
# Decode tile # Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False) tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
mx.eval(tile_output) mx.eval(tile_output)
# Clear tile_latents reference # Clear tile_latents reference
@@ -454,17 +450,60 @@ def decode_with_tiling(
except Exception: except Exception:
pass # May not be available on all platforms pass # May not be available on all platforms
# After completing all spatial tiles for this temporal tile,
# check if any frames are now finalized (no future tiles will contribute)
if on_frames_ready is not None and num_t_tiles > 1:
# Determine the finalized frame boundary
# Frames before the start of the next tile's output region are finalized
if t_idx < num_t_tiles - 1:
# Next tile starts at temporal_intervals.starts[t_idx + 1]
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
if next_tile_start_out > emitted:
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
on_frames_ready(finalized_output, emitted)
decode_with_tiling._emitted_frames = next_tile_start_out
del finalized_output, finalized_weights
gc.collect()
# Normalize by weights # Normalize by weights
weights = mx.maximum(weights, 1e-8) weights = mx.maximum(weights, 1e-8)
output = output / weights output = output / weights
mx.eval(output) mx.eval(output)
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
on_frames_ready(remaining_output, emitted)
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
del decode_with_tiling._emitted_frames
# Clean up weights # Clean up weights
del weights del weights
gc.collect() gc.collect()
if debug:
print(f"[Tiling] Done. Final shape: {output.shape}")
# Convert back to original dtype if needed # Convert back to original dtype if needed
return output.astype(latents.dtype) return output.astype(latents.dtype)

301
tests/test_vae_streaming.py Normal file
View File

@@ -0,0 +1,301 @@
"""Tests for VAE streaming and chunked conv features."""
import pytest
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
compute_trapezoidal_mask_1d,
decode_with_tiling,
)
class TestChunkedConv:
    """Compare the chunked conv path of DepthToSpaceUpsample with the regular path."""

    @staticmethod
    def _seeded_upsampler():
        """Seed the RNG, then build and materialize the residual 2x2x2 upsampler."""
        mx.random.seed(42)
        module = DepthToSpaceUpsample(
            dims=3,
            in_channels=256,
            stride=(2, 2, 2),
            residual=True,
            out_channels_reduction_factor=2,
        )
        mx.eval(module.parameters())
        return module

    @staticmethod
    def _random_input(shape):
        """Draw and materialize a normal-distributed (B, C, D, H, W) input."""
        sample = mx.random.normal(shape)
        mx.eval(sample)
        return sample

    def test_chunked_conv_output_matches_regular(self):
        """Chunked and regular processing must agree within tolerance."""
        module = self._seeded_upsampler()

        # d=8 is above the d > 4 threshold, so the chunked path is exercised.
        sample = self._random_input((1, 256, 8, 8, 8))

        reference = module(sample, causal=True, chunked_conv=False)
        mx.eval(reference)

        candidate = module(sample, causal=True, chunked_conv=True)
        mx.eval(candidate)

        np.testing.assert_allclose(
            np.array(reference),
            np.array(candidate),
            rtol=1e-5,
            atol=1e-5,
            err_msg="Chunked conv output differs from regular output",
        )

    def test_chunked_conv_small_input_passthrough(self):
        """With d <= 4 the chunked flag must not change the result."""
        module = self._seeded_upsampler()

        # d=4 sits exactly on the threshold and must take the regular path.
        sample = self._random_input((1, 256, 4, 8, 8))

        reference = module(sample, causal=True, chunked_conv=False)
        candidate = module(sample, causal=True, chunked_conv=True)
        mx.eval(reference, candidate)

        np.testing.assert_allclose(
            np.array(reference),
            np.array(candidate),
            rtol=1e-5,
            atol=1e-5,
        )

    def test_chunked_conv_output_shape(self):
        """The chunked path must produce the expected output shape."""
        module = self._seeded_upsampler()
        sample = self._random_input((1, 256, 8, 16, 16))

        out = module(sample, causal=True, chunked_conv=True)
        mx.eval(out)

        # channels: 256 / 2 = 128; temporal: 8 * 2 - 1 = 15; spatial: 16 * 2 = 32
        expected_shape = (1, 128, 15, 32, 32)
        assert out.shape == expected_shape, f"Unexpected shape: {out.shape}"
class TestProgressiveFrameSaving:
    """Exercise the on_frames_ready callback of decode_with_tiling."""

    @staticmethod
    def _decoded_shape(x):
        """Output shape of the stand-in decoder: 8x temporal, 32x spatial, 3 channels."""
        batch, _, latent_f, latent_h, latent_w = x.shape
        return (batch, 3, 1 + (latent_f - 1) * 8, latent_h * 32, latent_w * 32)

    def test_on_frames_ready_called(self):
        """The callback fires at least once and always receives 3-channel frames."""
        received_shapes = []

        def on_frames_ready(frames: mx.array, start_idx: int):
            received_shapes.append(frames.shape)

        def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
            # Stand-in for the VAE: only the output shape matters here.
            return mx.zeros(self._decoded_shape(x))

        # Temporal-only tiling; 10 latent frames -> 73 output frames.
        output = decode_with_tiling(
            decoder_fn=mock_decoder,
            latents=mx.zeros((1, 128, 10, 4, 4)),
            tiling_config=TilingConfig.temporal_only(tile_size=32, overlap=8),
            spatial_scale=32,
            temporal_scale=8,
            on_frames_ready=on_frames_ready,
        )
        mx.eval(output)

        assert len(received_shapes) > 0, "on_frames_ready was never called"
        for shape in received_shapes:
            assert shape[1] == 3, f"Expected 3 channels, got {shape[1]}"

    def test_on_frames_ready_covers_all_frames(self):
        """Every output frame index is delivered through the callback."""
        all_frame_indices = set()

        def on_frames_ready(frames: mx.array, start_idx: int):
            all_frame_indices.update(range(start_idx, start_idx + frames.shape[2]))

        def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
            return mx.random.normal(self._decoded_shape(x))

        latent_frames = 12
        output = decode_with_tiling(
            decoder_fn=mock_decoder,
            latents=mx.zeros((1, 128, latent_frames, 4, 4)),
            tiling_config=TilingConfig.temporal_only(tile_size=32, overlap=8),
            spatial_scale=32,
            temporal_scale=8,
            on_frames_ready=on_frames_ready,
        )
        mx.eval(output)

        # 12 latent frames -> 89 output frames.
        expected_frames = 1 + (latent_frames - 1) * 8

        assert len(all_frame_indices) == expected_frames, \
            f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
        assert all_frame_indices == set(range(expected_frames)), \
            "Not all frame indices were covered"
class TestAutoChunkedConv:
    """Checks which tiling modes are expected to switch chunked_conv on."""

    @pytest.mark.parametrize("tiling_mode,should_enable", [
        ("conservative", True),
        ("none", True),
        ("auto", True),
        ("default", True),
        ("spatial", True),
        ("aggressive", False),
        ("temporal", False),
    ])
    def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
        """Membership in the enabling set must match the expected flag.

        NOTE(review): this evaluates a local copy of the membership rule and
        does not call the decoder, so it will not notice if the decoder's own
        list of modes changes.
        """
        enabling_modes = frozenset(
            ("conservative", "none", "auto", "default", "spatial")
        )
        enabled = tiling_mode in enabling_modes
        assert enabled == should_enable, \
            f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
class TestTrapezoidalMask:
    """Checks on the 1-D trapezoidal blending mask."""

    def test_mask_values_in_range(self):
        """All mask values stay within [0, 1] for a grid of lengths and ramps."""
        for length in [16, 32, 64, 128]:
            for ramp in [0, 4, 8, 16]:
                if ramp >= length:
                    # Only ramps shorter than the mask are checked.
                    continue
                mask = compute_trapezoidal_mask_1d(length, ramp, ramp, False)
                assert mx.all(mask >= 0).item(), "Mask has negative values"
                assert mx.all(mask <= 1).item(), "Mask has values > 1"

    def test_mask_center_is_one(self):
        """The middle of a 32-long mask with 8-wide ramps is exactly 1.0."""
        mask = compute_trapezoidal_mask_1d(32, 8, 8, False)
        middle = np.array(mask[12:20])
        np.testing.assert_allclose(middle, 1.0, rtol=1e-5)

    def test_mask_ramp_monotonic(self):
        """The left ramp never decreases and the right ramp never increases."""
        values = np.array(compute_trapezoidal_mask_1d(32, 8, 8, False))

        rising = np.diff(values[:8])
        assert np.all(rising >= 0), "Left ramp not monotonically increasing"

        falling = np.diff(values[-8:])
        assert np.all(falling <= 0), "Right ramp not monotonically decreasing"

    def test_temporal_mask_starts_from_zero(self):
        """With left_starts_from_0=True the first mask value is 0."""
        mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=True)
        first = mask[0].item()
        assert first == 0.0, "Temporal mask should start from 0"

    def test_spatial_mask_starts_above_zero(self):
        """With left_starts_from_0=False the first mask value is positive."""
        mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=False)
        first = mask[0].item()
        assert first > 0.0, "Spatial mask should start above 0"
class TestTilingConfig:
    """Checks the tile sizes of the TilingConfig presets and the auto heuristic."""

    def test_default_config(self):
        """default(): both tilings present, 512 px spatial and 64 frame temporal."""
        cfg = TilingConfig.default()
        spatial, temporal = cfg.spatial_config, cfg.temporal_config
        assert spatial is not None
        assert temporal is not None
        assert spatial.tile_size_in_pixels == 512
        assert temporal.tile_size_in_frames == 64

    def test_aggressive_config(self):
        """aggressive(): 256 px spatial and 32 frame temporal tiles."""
        cfg = TilingConfig.aggressive()
        assert cfg.spatial_config.tile_size_in_pixels == 256
        assert cfg.temporal_config.tile_size_in_frames == 32

    def test_conservative_config(self):
        """conservative(): 768 px spatial and 96 frame temporal tiles."""
        cfg = TilingConfig.conservative()
        assert cfg.spatial_config.tile_size_in_pixels == 768
        assert cfg.temporal_config.tile_size_in_frames == 96

    def test_auto_returns_none_for_small_video(self):
        """auto() yields no tiling config for a 256x256, 33-frame video."""
        cfg = TilingConfig.auto(height=256, width=256, num_frames=33)
        assert cfg is None, "Auto should return None for small videos"

    def test_auto_returns_config_for_large_video(self):
        """auto() yields a tiling config for a 1024x768, 145-frame video."""
        cfg = TilingConfig.auto(height=1024, width=768, num_frames=145)
        assert cfg is not None, "Auto should return config for large videos"
# Allow running this test module directly as a script: hands this file to
# pytest in verbose mode. The return value of pytest.main is not propagated
# as the process exit status here.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])