Add streaming support to video generation
This commit is contained in:
@@ -8,8 +8,8 @@ Default configuration (from PyTorch):
|
||||
- Temporal: 64 frames with 24 frame overlap
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@@ -285,6 +285,8 @@ def decode_with_tiling(
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
@@ -297,6 +299,10 @@ def decode_with_tiling(
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
debug: Whether to print debug info.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
start_idx: Starting frame index in the full video.
|
||||
|
||||
Returns:
|
||||
Decoded video.
|
||||
@@ -383,7 +389,7 @@ def decode_with_tiling(
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
@@ -454,11 +460,62 @@ def decode_with_tiling(
|
||||
except Exception:
|
||||
pass # May not be available on all platforms
|
||||
|
||||
# After completing all spatial tiles for this temporal tile,
|
||||
# check if any frames are now finalized (no future tiles will contribute)
|
||||
if on_frames_ready is not None and num_t_tiles > 1:
|
||||
# Determine the finalized frame boundary
|
||||
# Frames before the start of the next tile's output region are finalized
|
||||
if t_idx < num_t_tiles - 1:
|
||||
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||
# Map to output frame index (first frame of next tile's contribution)
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
else:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
if next_tile_start_out > emitted:
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
del finalized_output, finalized_weights
|
||||
gc.collect()
|
||||
|
||||
# Normalize by weights
|
||||
weights = mx.maximum(weights, 1e-8)
|
||||
output = output / weights
|
||||
mx.eval(output)
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
if debug:
|
||||
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user