Merge pull request #14 from Blaizzy/pc/add-streaming
Add --stream flag and chunked conv memory optimization for VAE decoding
This commit is contained in:
@@ -215,6 +215,7 @@ def generate_video(
|
||||
image_strength: float = 1.0,
|
||||
image_frame_idx: int = 0,
|
||||
tiling: str = "auto",
|
||||
stream: bool = False,
|
||||
):
|
||||
"""Generate video from text prompt, optionally conditioned on an image.
|
||||
|
||||
@@ -481,17 +482,59 @@ def generate_video(
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Stream mode: write frames as they're decoded
|
||||
video_writer = None
|
||||
stream_pbar = None
|
||||
|
||||
if stream and tiling_config is not None:
|
||||
import cv2
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
"""Callback to write frames as they're finalized."""
|
||||
# frames: (B, 3, num_frames, H, W)
|
||||
frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W)
|
||||
frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3)
|
||||
frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0)
|
||||
frames = (frames * 255).astype(mx.uint8)
|
||||
frames_np = np.array(frames)
|
||||
|
||||
for frame in frames_np:
|
||||
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
stream_pbar.update(1)
|
||||
else:
|
||||
on_frames_ready = None
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose)
|
||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
|
||||
else:
|
||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
mx.clear_cache()
|
||||
|
||||
# Close progressive video writer if used
|
||||
if video_writer is not None:
|
||||
video_writer.release()
|
||||
if stream_pbar is not None:
|
||||
stream_pbar.close()
|
||||
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
|
||||
# Still need video_np for save_frames option
|
||||
video = mx.squeeze(video, axis=0)
|
||||
video = mx.transpose(video, (1, 2, 3, 0))
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
else:
|
||||
# Convert to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
@@ -499,15 +542,12 @@ def generate_video(
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save video normally
|
||||
try:
|
||||
import cv2
|
||||
height, width = video_np.shape[1], video_np.shape[2]
|
||||
h, w = video_np.shape[1], video_np.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
|
||||
for frame in video_np:
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
@@ -654,6 +694,11 @@ Examples:
|
||||
"auto=based on video size, none=disabled, default=512px/64f, "
|
||||
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
|
||||
@@ -15,7 +15,7 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
def debug_stats(name, t):
|
||||
@@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||
elif isinstance(block, DepthToSpaceUpsample):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
if debug:
|
||||
@@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module):
|
||||
self,
|
||||
sample: mx.array,
|
||||
tiling_config: Optional[TilingConfig] = None,
|
||||
tiling_mode: str = "auto",
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
on_frames_ready: Optional[callable] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
@@ -495,14 +500,13 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if f > tile_size_latent:
|
||||
needs_temporal_tiling = True
|
||||
|
||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
if debug:
|
||||
print("[Tiling] Input fits within tile size, using regular decode")
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
@@ -512,7 +516,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
temporal_scale=8, # VAE temporal upsampling factor
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -156,7 +156,7 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
@@ -177,9 +177,12 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
if st > 1:
|
||||
x_residual = x_residual[:, :, 1:, :, :]
|
||||
|
||||
# Use chunked mode for large tensors to reduce peak memory
|
||||
if chunked_conv and d > 4:
|
||||
x = self._chunked_conv_depth_to_space(x, causal)
|
||||
else:
|
||||
# Apply conv
|
||||
x = self.conv(x, causal=causal)
|
||||
|
||||
# Depth to space rearrangement
|
||||
x = self._depth_to_space(x)
|
||||
|
||||
@@ -192,3 +195,81 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
x = x + x_residual
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||
|
||||
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||
Instead of materializing (B, 4096, D, H, W), we process temporal chunks and
|
||||
immediately apply depth_to_space.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, D, H, W)
|
||||
causal: Whether to use causal convolutions
|
||||
|
||||
Returns:
|
||||
Output tensor after conv + depth_to_space
|
||||
"""
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
out_c = self.out_channels
|
||||
|
||||
# Output dimensions
|
||||
out_d = d * st
|
||||
out_h = h * sh
|
||||
out_w = w * sw
|
||||
|
||||
# Chunk size in temporal dimension (process 4 frames at a time)
|
||||
chunk_size = 4
|
||||
kernel_t = 3 # Temporal kernel size
|
||||
|
||||
# For causal conv, we need (kernel_t - 1) frames of padding at the start
|
||||
# For non-causal, we need (kernel_t - 1) // 2 on each side
|
||||
if causal:
|
||||
# Pad start with first frame repeated
|
||||
pad_start = kernel_t - 1
|
||||
pad_end = 0
|
||||
else:
|
||||
pad_start = (kernel_t - 1) // 2
|
||||
pad_end = (kernel_t - 1) // 2
|
||||
|
||||
# Allocate output
|
||||
outputs = []
|
||||
|
||||
# Process in chunks with overlap for conv kernel
|
||||
t_pos = 0
|
||||
while t_pos < d:
|
||||
t_end = min(t_pos + chunk_size, d)
|
||||
|
||||
# Calculate input range with padding for kernel
|
||||
in_start = max(0, t_pos - pad_start)
|
||||
in_end = min(d, t_end + pad_end)
|
||||
|
||||
# Extract chunk
|
||||
chunk = x[:, :, in_start:in_end, :, :]
|
||||
|
||||
# Apply conv to chunk
|
||||
chunk_conv = self.conv(chunk, causal=causal)
|
||||
|
||||
# Apply depth_to_space
|
||||
chunk_out = self._depth_to_space(chunk_conv)
|
||||
|
||||
# Calculate valid output range (excluding padding effects)
|
||||
# Each input frame produces st output frames
|
||||
out_start = (t_pos - in_start) * st
|
||||
out_end = out_start + (t_end - t_pos) * st
|
||||
|
||||
# Extract valid portion
|
||||
chunk_out = chunk_out[:, :, out_start:out_end, :, :]
|
||||
|
||||
outputs.append(chunk_out)
|
||||
|
||||
# Evaluate to free intermediate memory
|
||||
mx.eval(outputs[-1])
|
||||
|
||||
t_pos = t_end
|
||||
|
||||
# Concatenate all chunks
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return mx.concatenate(outputs, axis=2)
|
||||
|
||||
@@ -8,8 +8,8 @@ Default configuration (from PyTorch):
|
||||
- Temporal: 64 frames with 24 frame overlap
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@@ -284,7 +284,8 @@ def decode_with_tiling(
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
@@ -296,7 +297,10 @@ def decode_with_tiling(
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
debug: Whether to print debug info.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
start_idx: Starting frame index in the full video.
|
||||
|
||||
Returns:
|
||||
Decoded video.
|
||||
@@ -337,10 +341,6 @@ def decode_with_tiling(
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
|
||||
print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
@@ -375,15 +375,11 @@ def decode_with_tiling(
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
|
||||
f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
@@ -454,17 +450,60 @@ def decode_with_tiling(
|
||||
except Exception:
|
||||
pass # May not be available on all platforms
|
||||
|
||||
# After completing all spatial tiles for this temporal tile,
|
||||
# check if any frames are now finalized (no future tiles will contribute)
|
||||
if on_frames_ready is not None and num_t_tiles > 1:
|
||||
# Determine the finalized frame boundary
|
||||
# Frames before the start of the next tile's output region are finalized
|
||||
if t_idx < num_t_tiles - 1:
|
||||
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||
# Map to output frame index (first frame of next tile's contribution)
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
else:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
if next_tile_start_out > emitted:
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
del finalized_output, finalized_weights
|
||||
gc.collect()
|
||||
|
||||
# Normalize by weights
|
||||
weights = mx.maximum(weights, 1e-8)
|
||||
output = output / weights
|
||||
mx.eval(output)
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Done. Final shape: {output.shape}")
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
|
||||
301
tests/test_vae_streaming.py
Normal file
301
tests/test_vae_streaming.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Tests for VAE streaming and chunked conv features."""
|
||||
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
compute_trapezoidal_mask_1d,
|
||||
decode_with_tiling,
|
||||
)
|
||||
|
||||
|
||||
class TestChunkedConv:
|
||||
"""Tests for chunked conv optimization in DepthToSpaceUpsample."""
|
||||
|
||||
def test_chunked_conv_output_matches_regular(self):
|
||||
"""Verify chunked_conv produces identical output to regular processing."""
|
||||
mx.random.seed(42)
|
||||
|
||||
# Create upsampler with residual (matches decoder config)
|
||||
upsampler = DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=256,
|
||||
stride=(2, 2, 2),
|
||||
residual=True,
|
||||
out_channels_reduction_factor=2,
|
||||
)
|
||||
|
||||
# Initialize weights deterministically
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
# Create test input: (B, C, D, H, W) with enough frames to trigger chunking
|
||||
# chunked_conv activates when d > 4
|
||||
x = mx.random.normal((1, 256, 8, 8, 8))
|
||||
mx.eval(x)
|
||||
|
||||
# Run without chunked conv
|
||||
out_regular = upsampler(x, causal=True, chunked_conv=False)
|
||||
mx.eval(out_regular)
|
||||
|
||||
# Run with chunked conv
|
||||
out_chunked = upsampler(x, causal=True, chunked_conv=True)
|
||||
mx.eval(out_chunked)
|
||||
|
||||
# Outputs should be identical
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular),
|
||||
np.array(out_chunked),
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
err_msg="Chunked conv output differs from regular output"
|
||||
)
|
||||
|
||||
def test_chunked_conv_small_input_passthrough(self):
|
||||
"""Verify chunked_conv doesn't activate for small inputs (d <= 4)."""
|
||||
mx.random.seed(42)
|
||||
|
||||
upsampler = DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=256,
|
||||
stride=(2, 2, 2),
|
||||
residual=True,
|
||||
out_channels_reduction_factor=2,
|
||||
)
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
# Small input with d=4 (should NOT trigger chunking)
|
||||
x = mx.random.normal((1, 256, 4, 8, 8))
|
||||
mx.eval(x)
|
||||
|
||||
out_regular = upsampler(x, causal=True, chunked_conv=False)
|
||||
out_chunked = upsampler(x, causal=True, chunked_conv=True)
|
||||
mx.eval(out_regular, out_chunked)
|
||||
|
||||
# Should be identical since chunking doesn't activate
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular),
|
||||
np.array(out_chunked),
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
def test_chunked_conv_output_shape(self):
|
||||
"""Verify chunked_conv produces correct output shape."""
|
||||
mx.random.seed(42)
|
||||
|
||||
upsampler = DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=256,
|
||||
stride=(2, 2, 2),
|
||||
residual=True,
|
||||
out_channels_reduction_factor=2,
|
||||
)
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
# Input shape: (1, 256, 8, 16, 16)
|
||||
x = mx.random.normal((1, 256, 8, 16, 16))
|
||||
mx.eval(x)
|
||||
|
||||
out = upsampler(x, causal=True, chunked_conv=True)
|
||||
mx.eval(out)
|
||||
|
||||
# Expected output:
|
||||
# - Channels: 256 / 2 = 128
|
||||
# - Temporal: 8 * 2 - 1 = 15 (minus 1 for causal)
|
||||
# - Spatial: 16 * 2 = 32
|
||||
assert out.shape == (1, 128, 15, 32, 32), f"Unexpected shape: {out.shape}"
|
||||
|
||||
|
||||
class TestProgressiveFrameSaving:
|
||||
"""Tests for progressive frame saving via on_frames_ready callback."""
|
||||
|
||||
def test_on_frames_ready_called(self):
|
||||
"""Verify on_frames_ready callback is called during tiled decoding."""
|
||||
frames_received = []
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
frames_received.append({
|
||||
'shape': frames.shape,
|
||||
'start_idx': start_idx,
|
||||
})
|
||||
|
||||
# Create a mock decoder that just returns scaled input
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
# Simulate VAE output: upsample 8x temporal, 32x spatial
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
out_h = h * 32
|
||||
out_w = w * 32
|
||||
return mx.zeros((b, 3, out_f, out_h, out_w))
|
||||
|
||||
# Create tiling config with temporal tiling to trigger callbacks
|
||||
tiling_config = TilingConfig.temporal_only(tile_size=32, overlap=8)
|
||||
|
||||
# Input latents: enough frames to require multiple tiles
|
||||
# 10 latent frames -> 73 output frames
|
||||
latents = mx.zeros((1, 128, 10, 4, 4))
|
||||
|
||||
# Decode with tiling
|
||||
output = decode_with_tiling(
|
||||
decoder_fn=mock_decoder,
|
||||
latents=latents,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=32,
|
||||
temporal_scale=8,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
mx.eval(output)
|
||||
|
||||
# Should have received at least one callback
|
||||
assert len(frames_received) > 0, "on_frames_ready was never called"
|
||||
|
||||
# All received frames should have correct channel count
|
||||
for received in frames_received:
|
||||
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
|
||||
|
||||
def test_on_frames_ready_covers_all_frames(self):
|
||||
"""Verify all frames are emitted via callbacks."""
|
||||
all_frame_indices = set()
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
num_frames = frames.shape[2]
|
||||
for i in range(num_frames):
|
||||
all_frame_indices.add(start_idx + i)
|
||||
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
out_h = h * 32
|
||||
out_w = w * 32
|
||||
return mx.random.normal((b, 3, out_f, out_h, out_w))
|
||||
|
||||
tiling_config = TilingConfig.temporal_only(tile_size=32, overlap=8)
|
||||
|
||||
# 12 latent frames -> 89 output frames
|
||||
latents = mx.zeros((1, 128, 12, 4, 4))
|
||||
|
||||
output = decode_with_tiling(
|
||||
decoder_fn=mock_decoder,
|
||||
latents=latents,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=32,
|
||||
temporal_scale=8,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
mx.eval(output)
|
||||
|
||||
# Calculate expected frame count
|
||||
expected_frames = 1 + (12 - 1) * 8 # 89 frames
|
||||
|
||||
# All frames should have been emitted
|
||||
assert len(all_frame_indices) == expected_frames, \
|
||||
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||
assert all_frame_indices == set(range(expected_frames)), \
|
||||
"Not all frame indices were covered"
|
||||
|
||||
|
||||
class TestAutoChunkedConv:
|
||||
"""Tests for auto-enabling chunked_conv based on tiling mode."""
|
||||
|
||||
@pytest.mark.parametrize("tiling_mode,should_enable", [
|
||||
("conservative", True),
|
||||
("none", True),
|
||||
("auto", True),
|
||||
("default", True),
|
||||
("spatial", True),
|
||||
("aggressive", False),
|
||||
("temporal", False),
|
||||
])
|
||||
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
|
||||
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
|
||||
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
expected_modes = {"conservative", "none", "auto", "default", "spatial"}
|
||||
|
||||
use_chunked_conv = tiling_mode in expected_modes
|
||||
|
||||
assert use_chunked_conv == should_enable, \
|
||||
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||
|
||||
|
||||
class TestTrapezoidalMask:
|
||||
"""Tests for trapezoidal blending mask generation."""
|
||||
|
||||
def test_mask_values_in_range(self):
|
||||
"""Verify mask values are always in [0, 1]."""
|
||||
for length in [16, 32, 64, 128]:
|
||||
for ramp in [0, 4, 8, 16]:
|
||||
if ramp < length:
|
||||
mask = compute_trapezoidal_mask_1d(length, ramp, ramp, False)
|
||||
assert mx.all(mask >= 0).item(), f"Mask has negative values"
|
||||
assert mx.all(mask <= 1).item(), f"Mask has values > 1"
|
||||
|
||||
def test_mask_center_is_one(self):
|
||||
"""Verify center of mask is 1.0 when ramps don't overlap."""
|
||||
mask = compute_trapezoidal_mask_1d(32, 8, 8, False)
|
||||
# Center region should be 1.0
|
||||
center = mask[12:20] # Middle portion
|
||||
np.testing.assert_allclose(np.array(center), 1.0, rtol=1e-5)
|
||||
|
||||
def test_mask_ramp_monotonic(self):
|
||||
"""Verify ramps are monotonically increasing/decreasing."""
|
||||
mask = compute_trapezoidal_mask_1d(32, 8, 8, False)
|
||||
mask_np = np.array(mask)
|
||||
|
||||
# Left ramp should be increasing
|
||||
left_ramp = mask_np[:8]
|
||||
assert np.all(np.diff(left_ramp) >= 0), "Left ramp not monotonically increasing"
|
||||
|
||||
# Right ramp should be decreasing
|
||||
right_ramp = mask_np[-8:]
|
||||
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
|
||||
|
||||
def test_temporal_mask_starts_from_zero(self):
|
||||
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""
|
||||
mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=True)
|
||||
assert mask[0].item() == 0.0, "Temporal mask should start from 0"
|
||||
|
||||
def test_spatial_mask_starts_above_zero(self):
|
||||
"""Verify spatial mask (left_starts_from_0=False) starts above 0."""
|
||||
mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=False)
|
||||
assert mask[0].item() > 0.0, "Spatial mask should start above 0"
|
||||
|
||||
|
||||
class TestTilingConfig:
|
||||
"""Tests for TilingConfig presets."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Verify default tiling configuration."""
|
||||
config = TilingConfig.default()
|
||||
assert config.spatial_config is not None
|
||||
assert config.temporal_config is not None
|
||||
assert config.spatial_config.tile_size_in_pixels == 512
|
||||
assert config.temporal_config.tile_size_in_frames == 64
|
||||
|
||||
def test_aggressive_config(self):
|
||||
"""Verify aggressive tiling configuration."""
|
||||
config = TilingConfig.aggressive()
|
||||
assert config.spatial_config.tile_size_in_pixels == 256
|
||||
assert config.temporal_config.tile_size_in_frames == 32
|
||||
|
||||
def test_conservative_config(self):
|
||||
"""Verify conservative tiling configuration."""
|
||||
config = TilingConfig.conservative()
|
||||
assert config.spatial_config.tile_size_in_pixels == 768
|
||||
assert config.temporal_config.tile_size_in_frames == 96
|
||||
|
||||
def test_auto_returns_none_for_small_video(self):
|
||||
"""Verify auto returns None for small videos."""
|
||||
config = TilingConfig.auto(height=256, width=256, num_frames=33)
|
||||
assert config is None, "Auto should return None for small videos"
|
||||
|
||||
def test_auto_returns_config_for_large_video(self):
|
||||
"""Verify auto returns config for large videos."""
|
||||
config = TilingConfig.auto(height=1024, width=768, num_frames=145)
|
||||
assert config is not None, "Auto should return config for large videos"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user