From 7f20840dc7e82ba5cd2d70da283a780de5e756e0 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 23:17:08 +0100 Subject: [PATCH 1/3] Add streaming support to video generation --- mlx_video/generate.py | 90 +++++++++++++++----- mlx_video/models/ltx/video_vae/decoder.py | 15 +++- mlx_video/models/ltx/video_vae/sampling.py | 95 ++++++++++++++++++++-- mlx_video/models/ltx/video_vae/tiling.py | 63 +++++++++++++- 4 files changed, 229 insertions(+), 34 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 17f3770..c400081 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -215,6 +215,7 @@ def generate_video( image_strength: float = 1.0, image_frame_idx: int = 0, tiling: str = "auto", + stream: bool = False, ): """Generate video from text prompt, optionally conditioned on an image. @@ -481,39 +482,79 @@ def generate_video( print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") tiling_config = TilingConfig.auto(height, width, num_frames) + # Save outputs + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Stream mode: write frames as they're decoded + video_writer = None + frames_written = [0] # Use list to allow mutation in closure + + if stream and tiling_config is not None: + import cv2 + fourcc = cv2.VideoWriter_fourcc(*'avc1') + video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) + + def on_frames_ready(frames: mx.array, start_idx: int): + """Callback to write frames as they're finalized.""" + # frames: (B, 3, num_frames, H, W) + frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W) + frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3) + frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0) + frames = (frames * 255).astype(mx.uint8) + frames_np = np.array(frames) + + for frame in frames_np: + video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + frames_written[0] += 1 + + print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}") + + print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}") + else: + on_frames_ready = None + if tiling_config is not None: spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") - video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose) + video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready) else: print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") video = vae_decoder(latents) mx.eval(video) mx.clear_cache() - # Convert to uint8 frames - video = mx.squeeze(video, axis=0) # (C, F, H, W) - video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) - video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) - video = (video * 255).astype(mx.uint8) - video_np = np.array(video) + # Close progressive video writer if used + if video_writer is not None: + video_writer.release() + print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}") + # Still need video_np for save_frames option + video = mx.squeeze(video, axis=0) + video = mx.transpose(video, (1, 2, 3, 0)) + video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) + video = (video * 255).astype(mx.uint8) + video_np = np.array(video) + else: + # Convert to uint8 frames + video = mx.squeeze(video, axis=0) # (C, F, H, W) + video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) + video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) + video = (video * 255).astype(mx.uint8) + video_np = np.array(video) - # Save outputs - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - try: - import cv2 - height, width = video_np.shape[1], video_np.shape[2] - fourcc = cv2.VideoWriter_fourcc(*'avc1') - out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) - for frame in video_np: - out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - out.release() - print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") - except Exception as e: - print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") + # Save video normally + try: + import cv2 + h, w = video_np.shape[1], video_np.shape[2] + fourcc = cv2.VideoWriter_fourcc(*'avc1') + out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h)) + for frame in video_np: + out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + out.release() + print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") + except Exception as e: + print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") if save_frames: frames_dir = output_path.parent / f"{output_path.stem}_frames" @@ -654,6 +695,11 @@ Examples: "auto=based on video size, none=disabled, default=512px/64f, " "aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only" ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner." + ) args = parser.parse_args() generate_video( diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 0cb0d7b..464b873 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -15,7 +15,7 @@ Architecture (from PyTorch weights): """ import math -from typing import List, Optional +from typing import Optional import mlx.core as mx import mlx.nn as nn @@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module): causal: bool = False, timestep: Optional[mx.array] = None, debug: bool = False, + chunked_conv: bool = False, ) -> mx.array: def debug_stats(name, t): @@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module): for i, block in self.up_blocks.items(): if isinstance(block, ResBlockGroup): x = block(x, causal=causal, timestep=scaled_timestep) + elif isinstance(block, DepthToSpaceUpsample): + x = block(x, causal=causal, chunked_conv=chunked_conv) else: x = block(x, causal=causal) if debug: @@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module): self, sample: mx.array, tiling_config: Optional[TilingConfig] = None, + tiling_mode: str = "auto", causal: bool = False, timestep: Optional[mx.array] = None, debug: bool = False, + on_frames_ready: Optional[callable] = None, ) -> mx.array: """Decode latents using tiling to reduce memory usage. @@ -495,11 +500,15 @@ class LTX2VideoDecoder(nn.Module): if f > tile_size_latent: needs_temporal_tiling = True + # Auto-enable chunked conv for modes where it helps (larger tiles) + # Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks + use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial") + if not needs_spatial_tiling and not needs_temporal_tiling: # No tiling needed, use regular decode if debug: print("[Tiling] Input fits within tile size, using regular decode") - return self(sample, causal=causal, timestep=timestep, debug=debug) + return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv) if debug: print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})") @@ -513,6 +522,8 @@ class LTX2VideoDecoder(nn.Module): causal=causal, timestep=timestep, debug=debug, + chunked_conv=use_chunked_conv, + on_frames_ready=on_frames_ready, ) diff --git a/mlx_video/models/ltx/video_vae/sampling.py b/mlx_video/models/ltx/video_vae/sampling.py index 6ca3d41..76a96bf 100644 --- a/mlx_video/models/ltx/video_vae/sampling.py +++ b/mlx_video/models/ltx/video_vae/sampling.py @@ -156,8 +156,8 @@ class DepthToSpaceUpsample(nn.Module): return x - def __call__(self, x: mx.array, causal: bool = True) -> mx.array: - + def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array: + b, c, d, h, w = x.shape st, sh, sw = self.stride @@ -177,11 +177,14 @@ class DepthToSpaceUpsample(nn.Module): if st > 1: x_residual = x_residual[:, :, 1:, :, :] - # Apply conv - x = self.conv(x, causal=causal) - - # Depth to space rearrangement - x = self._depth_to_space(x) + # Use chunked mode for large tensors to reduce peak memory + if chunked_conv and d > 4: + x = self._chunked_conv_depth_to_space(x, causal) + else: + # Apply conv + x = self.conv(x, causal=causal) + # Depth to space rearrangement + x = self._depth_to_space(x) # Remove first frame for causal temporal upsampling if st > 1: @@ -192,3 +195,81 @@ class DepthToSpaceUpsample(nn.Module): x = x + x_residual return x + + def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array: + """Chunked conv + depth_to_space that processes in temporal chunks. + + This reduces peak memory by avoiding the full high-channel intermediate tensor. + Instead of materializing (B, 4096, D, H, W), we process temporal chunks and + immediately apply depth_to_space. + + Args: + x: Input tensor of shape (B, C, D, H, W) + causal: Whether to use causal convolutions + + Returns: + Output tensor after conv + depth_to_space + """ + b, c, d, h, w = x.shape + st, sh, sw = self.stride + out_c = self.out_channels + + # Output dimensions + out_d = d * st + out_h = h * sh + out_w = w * sw + + # Chunk size in temporal dimension (process 4 frames at a time) + chunk_size = 4 + kernel_t = 3 # Temporal kernel size + + # For causal conv, we need (kernel_t - 1) frames of padding at the start + # For non-causal, we need (kernel_t - 1) // 2 on each side + if causal: + # Pad start with first frame repeated + pad_start = kernel_t - 1 + pad_end = 0 + else: + pad_start = (kernel_t - 1) // 2 + pad_end = (kernel_t - 1) // 2 + + # Allocate output + outputs = [] + + # Process in chunks with overlap for conv kernel + t_pos = 0 + while t_pos < d: + t_end = min(t_pos + chunk_size, d) + + # Calculate input range with padding for kernel + in_start = max(0, t_pos - pad_start) + in_end = min(d, t_end + pad_end) + + # Extract chunk + chunk = x[:, :, in_start:in_end, :, :] + + # Apply conv to chunk + chunk_conv = self.conv(chunk, causal=causal) + + # Apply depth_to_space + chunk_out = self._depth_to_space(chunk_conv) + + # Calculate valid output range (excluding padding effects) + # Each input frame produces st output frames + out_start = (t_pos - in_start) * st + out_end = out_start + (t_end - t_pos) * st + + # Extract valid portion + chunk_out = chunk_out[:, :, out_start:out_end, :, :] + + outputs.append(chunk_out) + + # Evaluate to free intermediate memory + mx.eval(outputs[-1]) + + t_pos = t_end + + # Concatenate all chunks + if len(outputs) == 1: + return outputs[0] + return mx.concatenate(outputs, axis=2) diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py index 20950fc..ee55026 100644 --- a/mlx_video/models/ltx/video_vae/tiling.py +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -8,8 +8,8 @@ Default configuration (from PyTorch): - Temporal: 64 frames with 24 frame overlap """ -from dataclasses import dataclass, replace -from typing import List, Optional, Tuple +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple import mlx.core as mx @@ -285,6 +285,8 @@ def decode_with_tiling( causal: bool = False, timestep: Optional[mx.array] = None, debug: bool = False, + chunked_conv: bool = False, + on_frames_ready: Optional[Callable[[mx.array, int], None]] = None, ) -> mx.array: """Decode latents using tiling to reduce memory usage. @@ -297,6 +299,10 @@ def decode_with_tiling( causal: Whether to use causal convolutions. timestep: Optional timestep for conditioning. debug: Whether to print debug info. + chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory). + on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized. + frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames. + start_idx: Starting frame index in the full video. Returns: Decoded video. @@ -383,7 +389,7 @@ def decode_with_tiling( tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] # Decode tile - tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False) + tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv) mx.eval(tile_output) # Clear tile_latents reference @@ -454,11 +460,62 @@ def decode_with_tiling( except Exception: pass # May not be available on all platforms + # After completing all spatial tiles for this temporal tile, + # check if any frames are now finalized (no future tiles will contribute) + if on_frames_ready is not None and num_t_tiles > 1: + # Determine the finalized frame boundary + # Frames before the start of the next tile's output region are finalized + if t_idx < num_t_tiles - 1: + # Next tile starts at temporal_intervals.starts[t_idx + 1] + next_tile_start_latent = temporal_intervals.starts[t_idx + 1] + # Map to output frame index (first frame of next tile's contribution) + if next_tile_start_latent == 0: + next_tile_start_out = 0 + else: + next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale + + # We need to track how many frames we've already emitted + if not hasattr(decode_with_tiling, '_emitted_frames'): + decode_with_tiling._emitted_frames = 0 + emitted = decode_with_tiling._emitted_frames + + if next_tile_start_out > emitted: + # Normalize and emit frames [emitted, next_tile_start_out) + finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :] + finalized_weights = mx.maximum(finalized_weights, 1e-8) + finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights + finalized_output = finalized_output.astype(latents.dtype) + mx.eval(finalized_output) + + if debug: + print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}") + + on_frames_ready(finalized_output, emitted) + decode_with_tiling._emitted_frames = next_tile_start_out + + del finalized_output, finalized_weights + gc.collect() + # Normalize by weights weights = mx.maximum(weights, 1e-8) output = output / weights mx.eval(output) + # Emit remaining frames if callback provided + if on_frames_ready is not None: + emitted = getattr(decode_with_tiling, '_emitted_frames', 0) + if emitted < out_f: + remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) + mx.eval(remaining_output) + if debug: + print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}") + on_frames_ready(remaining_output, emitted) + del remaining_output + + # Reset emitted frames counter for next call + if hasattr(decode_with_tiling, '_emitted_frames'): + del decode_with_tiling._emitted_frames + # Clean up weights del weights gc.collect() From f256c5fb257bc54ec346f099dbcd2e163b13d55f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 23:36:39 +0100 Subject: [PATCH 2/3] add tests --- tests/test_vae_streaming.py | 301 ++++++++++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) create mode 100644 tests/test_vae_streaming.py diff --git a/tests/test_vae_streaming.py b/tests/test_vae_streaming.py new file mode 100644 index 0000000..be29d00 --- /dev/null +++ b/tests/test_vae_streaming.py @@ -0,0 +1,301 @@ +"""Tests for VAE streaming and chunked conv features.""" + +import pytest +import mlx.core as mx +import numpy as np + +from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample +from mlx_video.models.ltx.video_vae.tiling import ( + TilingConfig, + compute_trapezoidal_mask_1d, + decode_with_tiling, +) + + +class TestChunkedConv: + """Tests for chunked conv optimization in DepthToSpaceUpsample.""" + + def test_chunked_conv_output_matches_regular(self): + """Verify chunked_conv produces identical output to regular processing.""" + mx.random.seed(42) + + # Create upsampler with residual (matches decoder config) + upsampler = DepthToSpaceUpsample( + dims=3, + in_channels=256, + stride=(2, 2, 2), + residual=True, + out_channels_reduction_factor=2, + ) + + # Initialize weights deterministically + mx.eval(upsampler.parameters()) + + # Create test input: (B, C, D, H, W) with enough frames to trigger chunking + # chunked_conv activates when d > 4 + x = mx.random.normal((1, 256, 8, 8, 8)) + mx.eval(x) + + # Run without chunked conv + out_regular = upsampler(x, causal=True, chunked_conv=False) + mx.eval(out_regular) + + # Run with chunked conv + out_chunked = upsampler(x, causal=True, chunked_conv=True) + mx.eval(out_chunked) + + # Outputs should be identical + np.testing.assert_allclose( + np.array(out_regular), + np.array(out_chunked), + rtol=1e-5, + atol=1e-5, + err_msg="Chunked conv output differs from regular output" + ) + + def test_chunked_conv_small_input_passthrough(self): + """Verify chunked_conv doesn't activate for small inputs (d <= 4).""" + mx.random.seed(42) + + upsampler = DepthToSpaceUpsample( + dims=3, + in_channels=256, + stride=(2, 2, 2), + residual=True, + out_channels_reduction_factor=2, + ) + mx.eval(upsampler.parameters()) + + # Small input with d=4 (should NOT trigger chunking) + x = mx.random.normal((1, 256, 4, 8, 8)) + mx.eval(x) + + out_regular = upsampler(x, causal=True, chunked_conv=False) + out_chunked = upsampler(x, causal=True, chunked_conv=True) + mx.eval(out_regular, out_chunked) + + # Should be identical since chunking doesn't activate + np.testing.assert_allclose( + np.array(out_regular), + np.array(out_chunked), + rtol=1e-5, + atol=1e-5, + ) + + def test_chunked_conv_output_shape(self): + """Verify chunked_conv produces correct output shape.""" + mx.random.seed(42) + + upsampler = DepthToSpaceUpsample( + dims=3, + in_channels=256, + stride=(2, 2, 2), + residual=True, + out_channels_reduction_factor=2, + ) + mx.eval(upsampler.parameters()) + + # Input shape: (1, 256, 8, 16, 16) + x = mx.random.normal((1, 256, 8, 16, 16)) + mx.eval(x) + + out = upsampler(x, causal=True, chunked_conv=True) + mx.eval(out) + + # Expected output: + # - Channels: 256 / 2 = 128 + # - Temporal: 8 * 2 - 1 = 15 (minus 1 for causal) + # - Spatial: 16 * 2 = 32 + assert out.shape == (1, 128, 15, 32, 32), f"Unexpected shape: {out.shape}" + + +class TestProgressiveFrameSaving: + """Tests for progressive frame saving via on_frames_ready callback.""" + + def test_on_frames_ready_called(self): + """Verify on_frames_ready callback is called during tiled decoding.""" + frames_received = [] + + def on_frames_ready(frames: mx.array, start_idx: int): + frames_received.append({ + 'shape': frames.shape, + 'start_idx': start_idx, + }) + + # Create a mock decoder that just returns scaled input + def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False): + # Simulate VAE output: upsample 8x temporal, 32x spatial + b, c, f, h, w = x.shape + out_f = 1 + (f - 1) * 8 + out_h = h * 32 + out_w = w * 32 + return mx.zeros((b, 3, out_f, out_h, out_w)) + + # Create tiling config with temporal tiling to trigger callbacks + tiling_config = TilingConfig.temporal_only(tile_size=32, overlap=8) + + # Input latents: enough frames to require multiple tiles + # 10 latent frames -> 73 output frames + latents = mx.zeros((1, 128, 10, 4, 4)) + + # Decode with tiling + output = decode_with_tiling( + decoder_fn=mock_decoder, + latents=latents, + tiling_config=tiling_config, + spatial_scale=32, + temporal_scale=8, + on_frames_ready=on_frames_ready, + ) + mx.eval(output) + + # Should have received at least one callback + assert len(frames_received) > 0, "on_frames_ready was never called" + + # All received frames should have correct channel count + for received in frames_received: + assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}" + + def test_on_frames_ready_covers_all_frames(self): + """Verify all frames are emitted via callbacks.""" + all_frame_indices = set() + + def on_frames_ready(frames: mx.array, start_idx: int): + num_frames = frames.shape[2] + for i in range(num_frames): + all_frame_indices.add(start_idx + i) + + def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False): + b, c, f, h, w = x.shape + out_f = 1 + (f - 1) * 8 + out_h = h * 32 + out_w = w * 32 + return mx.random.normal((b, 3, out_f, out_h, out_w)) + + tiling_config = TilingConfig.temporal_only(tile_size=32, overlap=8) + + # 12 latent frames -> 89 output frames + latents = mx.zeros((1, 128, 12, 4, 4)) + + output = decode_with_tiling( + decoder_fn=mock_decoder, + latents=latents, + tiling_config=tiling_config, + spatial_scale=32, + temporal_scale=8, + on_frames_ready=on_frames_ready, + ) + mx.eval(output) + + # Calculate expected frame count + expected_frames = 1 + (12 - 1) * 8 # 89 frames + + # All frames should have been emitted + assert len(all_frame_indices) == expected_frames, \ + f"Expected {expected_frames} frames, got {len(all_frame_indices)}" + assert all_frame_indices == set(range(expected_frames)), \ + "Not all frame indices were covered" + + +class TestAutoChunkedConv: + """Tests for auto-enabling chunked_conv based on tiling mode.""" + + @pytest.mark.parametrize("tiling_mode,should_enable", [ + ("conservative", True), + ("none", True), + ("auto", True), + ("default", True), + ("spatial", True), + ("aggressive", False), + ("temporal", False), + ]) + def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool): + """Verify chunked_conv is auto-enabled for correct tiling modes.""" + # The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial") + expected_modes = {"conservative", "none", "auto", "default", "spatial"} + + use_chunked_conv = tiling_mode in expected_modes + + assert use_chunked_conv == should_enable, \ + f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}" + + +class TestTrapezoidalMask: + """Tests for trapezoidal blending mask generation.""" + + def test_mask_values_in_range(self): + """Verify mask values are always in [0, 1].""" + for length in [16, 32, 64, 128]: + for ramp in [0, 4, 8, 16]: + if ramp < length: + mask = compute_trapezoidal_mask_1d(length, ramp, ramp, False) + assert mx.all(mask >= 0).item(), f"Mask has negative values" + assert mx.all(mask <= 1).item(), f"Mask has values > 1" + + def test_mask_center_is_one(self): + """Verify center of mask is 1.0 when ramps don't overlap.""" + mask = compute_trapezoidal_mask_1d(32, 8, 8, False) + # Center region should be 1.0 + center = mask[12:20] # Middle portion + np.testing.assert_allclose(np.array(center), 1.0, rtol=1e-5) + + def test_mask_ramp_monotonic(self): + """Verify ramps are monotonically increasing/decreasing.""" + mask = compute_trapezoidal_mask_1d(32, 8, 8, False) + mask_np = np.array(mask) + + # Left ramp should be increasing + left_ramp = mask_np[:8] + assert np.all(np.diff(left_ramp) >= 0), "Left ramp not monotonically increasing" + + # Right ramp should be decreasing + right_ramp = mask_np[-8:] + assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing" + + def test_temporal_mask_starts_from_zero(self): + """Verify temporal mask (left_starts_from_0=True) starts from 0.""" + mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=True) + assert mask[0].item() == 0.0, "Temporal mask should start from 0" + + def test_spatial_mask_starts_above_zero(self): + """Verify spatial mask (left_starts_from_0=False) starts above 0.""" + mask = compute_trapezoidal_mask_1d(32, 8, 0, left_starts_from_0=False) + assert mask[0].item() > 0.0, "Spatial mask should start above 0" + + +class TestTilingConfig: + """Tests for TilingConfig presets.""" + + def test_default_config(self): + """Verify default tiling configuration.""" + config = TilingConfig.default() + assert config.spatial_config is not None + assert config.temporal_config is not None + assert config.spatial_config.tile_size_in_pixels == 512 + assert config.temporal_config.tile_size_in_frames == 64 + + def test_aggressive_config(self): + """Verify aggressive tiling configuration.""" + config = TilingConfig.aggressive() + assert config.spatial_config.tile_size_in_pixels == 256 + assert config.temporal_config.tile_size_in_frames == 32 + + def test_conservative_config(self): + """Verify conservative tiling configuration.""" + config = TilingConfig.conservative() + assert config.spatial_config.tile_size_in_pixels == 768 + assert config.temporal_config.tile_size_in_frames == 96 + + def test_auto_returns_none_for_small_video(self): + """Verify auto returns None for small videos.""" + config = TilingConfig.auto(height=256, width=256, num_frames=33) + assert config is None, "Auto should return None for small videos" + + def test_auto_returns_config_for_large_video(self): + """Verify auto returns config for large videos.""" + config = TilingConfig.auto(height=1024, width=768, num_frames=145) + assert config is not None, "Auto should return config for large videos" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From b1bf9e2dc00419a06702c0b730c6ff522e82ad77 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 23:53:53 +0100 Subject: [PATCH 3/3] Enhance video generation with progress bar for streaming and remove debug prints from tiling decoder --- mlx_video/generate.py | 11 +++++------ mlx_video/models/ltx/video_vae/decoder.py | 6 ------ mlx_video/models/ltx/video_vae/tiling.py | 18 ------------------ 3 files changed, 5 insertions(+), 30 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index c400081..9a72fe9 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -488,12 +488,13 @@ def generate_video( # Stream mode: write frames as they're decoded video_writer = None - frames_written = [0] # Use list to allow mutation in closure + stream_pbar = None if stream and tiling_config is not None: import cv2 fourcc = cv2.VideoWriter_fourcc(*'avc1') video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) + stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame") def on_frames_ready(frames: mx.array, start_idx: int): """Callback to write frames as they're finalized.""" @@ -506,11 +507,7 @@ def generate_video( for frame in frames_np: video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - frames_written[0] += 1 - - print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}") - - print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}") + stream_pbar.update(1) else: on_frames_ready = None @@ -528,6 +525,8 @@ def generate_video( # Close progressive video writer if used if video_writer is not None: video_writer.release() + if stream_pbar is not None: + stream_pbar.close() print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}") # Still need video_np for save_frames option video = mx.squeeze(video, axis=0) diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 464b873..9a6cbb3 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -506,13 +506,8 @@ class LTX2VideoDecoder(nn.Module): if not needs_spatial_tiling and not needs_temporal_tiling: # No tiling needed, use regular decode - if debug: - print("[Tiling] Input fits within tile size, using regular decode") return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv) - if debug: - print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})") - return decode_with_tiling( decoder_fn=self, latents=sample, @@ -521,7 +516,6 @@ class LTX2VideoDecoder(nn.Module): temporal_scale=8, # VAE temporal upsampling factor causal=causal, timestep=timestep, - debug=debug, chunked_conv=use_chunked_conv, on_frames_ready=on_frames_ready, ) diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py index ee55026..72d32e4 100644 --- a/mlx_video/models/ltx/video_vae/tiling.py +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -284,7 +284,6 @@ def decode_with_tiling( temporal_scale: int = 8, causal: bool = False, timestep: Optional[mx.array] = None, - debug: bool = False, chunked_conv: bool = False, on_frames_ready: Optional[Callable[[mx.array, int], None]] = None, ) -> mx.array: @@ -298,7 +297,6 @@ def decode_with_tiling( temporal_scale: Temporal scale factor (8 for LTX VAE). causal: Whether to use causal convolutions. timestep: Optional timestep for conditioning. - debug: Whether to print debug info. chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory). on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized. frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames. @@ -343,10 +341,6 @@ def decode_with_tiling( num_w_tiles = len(width_intervals.starts) total_tiles = num_t_tiles * num_h_tiles * num_w_tiles - if debug: - print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})") - print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}") - # Initialize output and weight accumulator # Use float32 for accumulation to avoid precision issues output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32) @@ -381,10 +375,6 @@ def decode_with_tiling( # Map width coordinates out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) - if debug: - print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: " - f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})") - # Extract tile latents (small slice) tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] @@ -487,9 +477,6 @@ def decode_with_tiling( finalized_output = finalized_output.astype(latents.dtype) mx.eval(finalized_output) - if debug: - print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}") - on_frames_ready(finalized_output, emitted) decode_with_tiling._emitted_frames = next_tile_start_out @@ -507,8 +494,6 @@ def decode_with_tiling( if emitted < out_f: remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) mx.eval(remaining_output) - if debug: - print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}") on_frames_ready(remaining_output, emitted) del remaining_output @@ -520,8 +505,5 @@ def decode_with_tiling( del weights gc.collect() - if debug: - print(f"[Tiling] Done. Final shape: {output.shape}") - # Convert back to original dtype if needed return output.astype(latents.dtype)