Add streaming support to video generation

This commit is contained in:
Prince Canuma
2026-01-17 23:17:08 +01:00
parent f33f496fba
commit 7f20840dc7
4 changed files with 229 additions and 34 deletions

View File

@@ -8,8 +8,8 @@ Default configuration (from PyTorch):
- Temporal: 64 frames with 24 frame overlap
"""
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import mlx.core as mx
@@ -285,6 +285,8 @@ def decode_with_tiling(
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
@@ -297,6 +299,10 @@ def decode_with_tiling(
causal: Whether to use causal convolutions.
timestep: Optional timestep for conditioning.
debug: Whether to print debug info.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
start_idx: Starting frame index in the full video.
Returns:
Decoded video.
@@ -383,7 +389,7 @@ def decode_with_tiling(
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -454,11 +460,62 @@ def decode_with_tiling(
except Exception:
pass # May not be available on all platforms
# After completing all spatial tiles for this temporal tile,
# check if any frames are now finalized (no future tiles will contribute)
if on_frames_ready is not None and num_t_tiles > 1:
# Determine the finalized frame boundary
# Frames before the start of the next tile's output region are finalized
if t_idx < num_t_tiles - 1:
# Next tile starts at temporal_intervals.starts[t_idx + 1]
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
if next_tile_start_out > emitted:
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
if debug:
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
on_frames_ready(finalized_output, emitted)
decode_with_tiling._emitted_frames = next_tile_start_out
del finalized_output, finalized_weights
gc.collect()
# Normalize by weights
weights = mx.maximum(weights, 1e-8)
output = output / weights
mx.eval(output)
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
if debug:
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
on_frames_ready(remaining_output, emitted)
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
del decode_with_tiling._emitted_frames
# Clean up weights
del weights
gc.collect()