From 7f20840dc7e82ba5cd2d70da283a780de5e756e0 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 23:17:08 +0100 Subject: [PATCH] Add streaming support to video generation --- mlx_video/generate.py | 90 +++++++++++++++----- mlx_video/models/ltx/video_vae/decoder.py | 15 +++- mlx_video/models/ltx/video_vae/sampling.py | 95 ++++++++++++++++++++-- mlx_video/models/ltx/video_vae/tiling.py | 63 +++++++++++++- 4 files changed, 229 insertions(+), 34 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 17f3770..c400081 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -215,6 +215,7 @@ def generate_video( image_strength: float = 1.0, image_frame_idx: int = 0, tiling: str = "auto", + stream: bool = False, ): """Generate video from text prompt, optionally conditioned on an image. @@ -481,39 +482,79 @@ def generate_video( print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") tiling_config = TilingConfig.auto(height, width, num_frames) + # Save outputs + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Stream mode: write frames as they're decoded + video_writer = None + frames_written = [0] # Use list to allow mutation in closure + + if stream and tiling_config is not None: + import cv2 + fourcc = cv2.VideoWriter_fourcc(*'avc1') + video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) + + def on_frames_ready(frames: mx.array, start_idx: int): + """Callback to write frames as they're finalized.""" + # frames: (B, 3, num_frames, H, W) + frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W) + frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3) + frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0) + frames = (frames * 255).astype(mx.uint8) + frames_np = np.array(frames) + + for frame in frames_np: + video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + frames_written[0] += 1 + + print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}") + + print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}") + else: + on_frames_ready = None + if tiling_config is not None: spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") - video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose) + video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready) else: print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") video = vae_decoder(latents) mx.eval(video) mx.clear_cache() - # Convert to uint8 frames - video = mx.squeeze(video, axis=0) # (C, F, H, W) - video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) - video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) - video = (video * 255).astype(mx.uint8) - video_np = np.array(video) + # Close progressive video writer if used + if video_writer is not None: + video_writer.release() + print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}") + # Still need video_np for save_frames option + video = mx.squeeze(video, axis=0) + video = mx.transpose(video, (1, 2, 3, 0)) + video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) + video = (video * 255).astype(mx.uint8) + video_np = np.array(video) + else: + # Convert to uint8 frames + video = mx.squeeze(video, axis=0) # (C, F, H, W) + video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) + video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) + video = (video * 255).astype(mx.uint8) + video_np = np.array(video) - # Save outputs - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - try: - import cv2 - height, width = video_np.shape[1], video_np.shape[2] - fourcc = cv2.VideoWriter_fourcc(*'avc1') - out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) - for frame in video_np: - out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - out.release() - print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") - except Exception as e: - print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") + # Save video normally + try: + import cv2 + h, w = video_np.shape[1], video_np.shape[2] + fourcc = cv2.VideoWriter_fourcc(*'avc1') + out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h)) + for frame in video_np: + out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + out.release() + print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") + except Exception as e: + print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") if save_frames: frames_dir = output_path.parent / f"{output_path.stem}_frames" @@ -654,6 +695,11 @@ Examples: "auto=based on video size, none=disabled, default=512px/64f, " "aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only" ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner." + ) args = parser.parse_args() generate_video( diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 0cb0d7b..464b873 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -15,7 +15,7 @@ Architecture (from PyTorch weights): """ import math -from typing import List, Optional +from typing import Optional import mlx.core as mx import mlx.nn as nn @@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module): causal: bool = False, timestep: Optional[mx.array] = None, debug: bool = False, + chunked_conv: bool = False, ) -> mx.array: def debug_stats(name, t): @@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module): for i, block in self.up_blocks.items(): if isinstance(block, ResBlockGroup): x = block(x, causal=causal, timestep=scaled_timestep) + elif isinstance(block, DepthToSpaceUpsample): + x = block(x, causal=causal, chunked_conv=chunked_conv) else: x = block(x, causal=causal) if debug: @@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module): self, sample: mx.array, tiling_config: Optional[TilingConfig] = None, + tiling_mode: str = "auto", causal: bool = False, timestep: Optional[mx.array] = None, debug: bool = False, + on_frames_ready: Optional[callable] = None, ) -> mx.array: """Decode latents using tiling to reduce memory usage. @@ -495,11 +500,15 @@ class LTX2VideoDecoder(nn.Module): if f > tile_size_latent: needs_temporal_tiling = True + # Auto-enable chunked conv for modes where it helps (larger tiles) + # Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks + use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial") + if not needs_spatial_tiling and not needs_temporal_tiling: # No tiling needed, use regular decode if debug: print("[Tiling] Input fits within tile size, using regular decode") - return self(sample, causal=causal, timestep=timestep, debug=debug) + return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv) if debug: print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})") @@ -513,6 +522,8 @@ class LTX2VideoDecoder(nn.Module): causal=causal, timestep=timestep, debug=debug, + chunked_conv=use_chunked_conv, + on_frames_ready=on_frames_ready, ) diff --git a/mlx_video/models/ltx/video_vae/sampling.py b/mlx_video/models/ltx/video_vae/sampling.py index 6ca3d41..76a96bf 100644 --- a/mlx_video/models/ltx/video_vae/sampling.py +++ b/mlx_video/models/ltx/video_vae/sampling.py @@ -156,8 +156,8 @@ class DepthToSpaceUpsample(nn.Module): return x - def __call__(self, x: mx.array, causal: bool = True) -> mx.array: - + def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array: + b, c, d, h, w = x.shape st, sh, sw = self.stride @@ -177,11 +177,14 @@ class DepthToSpaceUpsample(nn.Module): if st > 1: x_residual = x_residual[:, :, 1:, :, :] - # Apply conv - x = self.conv(x, causal=causal) - - # Depth to space rearrangement - x = self._depth_to_space(x) + # Use chunked mode for large tensors to reduce peak memory + if chunked_conv and d > 4: + x = self._chunked_conv_depth_to_space(x, causal) + else: + # Apply conv + x = self.conv(x, causal=causal) + # Depth to space rearrangement + x = self._depth_to_space(x) # Remove first frame for causal temporal upsampling if st > 1: @@ -192,3 +195,81 @@ class DepthToSpaceUpsample(nn.Module): x = x + x_residual return x + + def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array: + """Chunked conv + depth_to_space that processes in temporal chunks. + + This reduces peak memory by avoiding the full high-channel intermediate tensor. + Instead of materializing (B, 4096, D, H, W), we process temporal chunks and + immediately apply depth_to_space. + + Args: + x: Input tensor of shape (B, C, D, H, W) + causal: Whether to use causal convolutions + + Returns: + Output tensor after conv + depth_to_space + """ + b, c, d, h, w = x.shape + st, sh, sw = self.stride + out_c = self.out_channels + + # Output dimensions + out_d = d * st + out_h = h * sh + out_w = w * sw + + # Chunk size in temporal dimension (process 4 frames at a time) + chunk_size = 4 + kernel_t = 3 # Temporal kernel size + + # For causal conv, we need (kernel_t - 1) frames of padding at the start + # For non-causal, we need (kernel_t - 1) // 2 on each side + if causal: + # Pad start with first frame repeated + pad_start = kernel_t - 1 + pad_end = 0 + else: + pad_start = (kernel_t - 1) // 2 + pad_end = (kernel_t - 1) // 2 + + # Allocate output + outputs = [] + + # Process in chunks with overlap for conv kernel + t_pos = 0 + while t_pos < d: + t_end = min(t_pos + chunk_size, d) + + # Calculate input range with padding for kernel + in_start = max(0, t_pos - pad_start) + in_end = min(d, t_end + pad_end) + + # Extract chunk + chunk = x[:, :, in_start:in_end, :, :] + + # Apply conv to chunk + chunk_conv = self.conv(chunk, causal=causal) + + # Apply depth_to_space + chunk_out = self._depth_to_space(chunk_conv) + + # Calculate valid output range (excluding padding effects) + # Each input frame produces st output frames + out_start = (t_pos - in_start) * st + out_end = out_start + (t_end - t_pos) * st + + # Extract valid portion + chunk_out = chunk_out[:, :, out_start:out_end, :, :] + + outputs.append(chunk_out) + + # Evaluate to free intermediate memory + mx.eval(outputs[-1]) + + t_pos = t_end + + # Concatenate all chunks + if len(outputs) == 1: + return outputs[0] + return mx.concatenate(outputs, axis=2) diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py index 20950fc..ee55026 100644 --- a/mlx_video/models/ltx/video_vae/tiling.py +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -8,8 +8,8 @@ Default configuration (from PyTorch): - Temporal: 64 frames with 24 frame overlap """ -from dataclasses import dataclass, replace -from typing import List, Optional, Tuple +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple import mlx.core as mx @@ -285,6 +285,8 @@ def decode_with_tiling( causal: bool = False, timestep: Optional[mx.array] = None, debug: bool = False, + chunked_conv: bool = False, + on_frames_ready: Optional[Callable[[mx.array, int], None]] = None, ) -> mx.array: """Decode latents using tiling to reduce memory usage. @@ -297,6 +299,10 @@ def decode_with_tiling( causal: Whether to use causal convolutions. timestep: Optional timestep for conditioning. debug: Whether to print debug info. + chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory). + on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized. + frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames. + start_idx: Starting frame index in the full video. Returns: Decoded video. @@ -383,7 +389,7 @@ def decode_with_tiling( tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] # Decode tile - tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False) + tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv) mx.eval(tile_output) # Clear tile_latents reference @@ -454,11 +460,62 @@ def decode_with_tiling( except Exception: pass # May not be available on all platforms + # After completing all spatial tiles for this temporal tile, + # check if any frames are now finalized (no future tiles will contribute) + if on_frames_ready is not None and num_t_tiles > 1: + # Determine the finalized frame boundary + # Frames before the start of the next tile's output region are finalized + if t_idx < num_t_tiles - 1: + # Next tile starts at temporal_intervals.starts[t_idx + 1] + next_tile_start_latent = temporal_intervals.starts[t_idx + 1] + # Map to output frame index (first frame of next tile's contribution) + if next_tile_start_latent == 0: + next_tile_start_out = 0 + else: + next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale + + # We need to track how many frames we've already emitted + if not hasattr(decode_with_tiling, '_emitted_frames'): + decode_with_tiling._emitted_frames = 0 + emitted = decode_with_tiling._emitted_frames + + if next_tile_start_out > emitted: + # Normalize and emit frames [emitted, next_tile_start_out) + finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :] + finalized_weights = mx.maximum(finalized_weights, 1e-8) + finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights + finalized_output = finalized_output.astype(latents.dtype) + mx.eval(finalized_output) + + if debug: + print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}") + + on_frames_ready(finalized_output, emitted) + decode_with_tiling._emitted_frames = next_tile_start_out + + del finalized_output, finalized_weights + gc.collect() + # Normalize by weights weights = mx.maximum(weights, 1e-8) output = output / weights mx.eval(output) + # Emit remaining frames if callback provided + if on_frames_ready is not None: + emitted = getattr(decode_with_tiling, '_emitted_frames', 0) + if emitted < out_f: + remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) + mx.eval(remaining_output) + if debug: + print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}") + on_frames_ready(remaining_output, emitted) + del remaining_output + + # Reset emitted frames counter for next call + if hasattr(decode_with_tiling, '_emitted_frames'): + del decode_with_tiling._emitted_frames + # Clean up weights del weights gc.collect()