Add streaming support to video generation

This commit is contained in:
Prince Canuma
2026-01-17 23:17:08 +01:00
parent f33f496fba
commit 7f20840dc7
4 changed files with 229 additions and 34 deletions

View File

@@ -215,6 +215,7 @@ def generate_video(
image_strength: float = 1.0,
image_frame_idx: int = 0,
tiling: str = "auto",
stream: bool = False,
):
"""Generate video from text prompt, optionally conditioned on an image.
@@ -481,39 +482,79 @@ def generate_video(
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Stream mode: write frames as they're decoded
video_writer = None
frames_written = [0] # Use list to allow mutation in closure
if stream and tiling_config is not None:
import cv2
fourcc = cv2.VideoWriter_fourcc(*'avc1')
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
def on_frames_ready(frames: mx.array, start_idx: int):
"""Callback to write frames as they're finalized."""
# frames: (B, 3, num_frames, H, W)
frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W)
frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3)
frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0)
frames = (frames * 255).astype(mx.uint8)
frames_np = np.array(frames)
for frame in frames_np:
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
frames_written[0] += 1
print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}")
print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}")
else:
on_frames_ready = None
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose)
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Close progressive video writer if used
if video_writer is not None:
video_writer.release()
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
# Still need video_np for save_frames option
video = mx.squeeze(video, axis=0)
video = mx.transpose(video, (1, 2, 3, 0))
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
else:
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
import cv2
height, width = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
# Save video normally
try:
import cv2
h, w = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
@@ -654,6 +695,11 @@ Examples:
"auto=based on video size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
)
parser.add_argument(
"--stream",
action="store_true",
help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner."
)
args = parser.parse_args()
generate_video(

View File

@@ -15,7 +15,7 @@ Architecture (from PyTorch weights):
"""
import math
from typing import List, Optional
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
@@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module):
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
chunked_conv: bool = False,
) -> mx.array:
def debug_stats(name, t):
@@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module):
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep)
elif isinstance(block, DepthToSpaceUpsample):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
if debug:
@@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module):
self,
sample: mx.array,
tiling_config: Optional[TilingConfig] = None,
tiling_mode: str = "auto",
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
on_frames_ready: Optional[callable] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
@@ -495,11 +500,15 @@ class LTX2VideoDecoder(nn.Module):
if f > tile_size_latent:
needs_temporal_tiling = True
# Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
if debug:
print("[Tiling] Input fits within tile size, using regular decode")
return self(sample, causal=causal, timestep=timestep, debug=debug)
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
if debug:
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
@@ -513,6 +522,8 @@ class LTX2VideoDecoder(nn.Module):
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)

View File

@@ -156,7 +156,7 @@ class DepthToSpaceUpsample(nn.Module):
return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
@@ -177,11 +177,14 @@ class DepthToSpaceUpsample(nn.Module):
if st > 1:
x_residual = x_residual[:, :, 1:, :, :]
# Apply conv
x = self.conv(x, causal=causal)
# Depth to space rearrangement
x = self._depth_to_space(x)
# Use chunked mode for large tensors to reduce peak memory
if chunked_conv and d > 4:
x = self._chunked_conv_depth_to_space(x, causal)
else:
# Apply conv
x = self.conv(x, causal=causal)
# Depth to space rearrangement
x = self._depth_to_space(x)
# Remove first frame for causal temporal upsampling
if st > 1:
@@ -192,3 +195,81 @@ class DepthToSpaceUpsample(nn.Module):
x = x + x_residual
return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor.
Instead of materializing (B, 4096, D, H, W), we process temporal chunks and
immediately apply depth_to_space.
Args:
x: Input tensor of shape (B, C, D, H, W)
causal: Whether to use causal convolutions
Returns:
Output tensor after conv + depth_to_space
"""
b, c, d, h, w = x.shape
st, sh, sw = self.stride
out_c = self.out_channels
# Output dimensions
out_d = d * st
out_h = h * sh
out_w = w * sw
# Chunk size in temporal dimension (process 4 frames at a time)
chunk_size = 4
kernel_t = 3 # Temporal kernel size
# For causal conv, we need (kernel_t - 1) frames of padding at the start
# For non-causal, we need (kernel_t - 1) // 2 on each side
if causal:
# Pad start with first frame repeated
pad_start = kernel_t - 1
pad_end = 0
else:
pad_start = (kernel_t - 1) // 2
pad_end = (kernel_t - 1) // 2
# Allocate output
outputs = []
# Process in chunks with overlap for conv kernel
t_pos = 0
while t_pos < d:
t_end = min(t_pos + chunk_size, d)
# Calculate input range with padding for kernel
in_start = max(0, t_pos - pad_start)
in_end = min(d, t_end + pad_end)
# Extract chunk
chunk = x[:, :, in_start:in_end, :, :]
# Apply conv to chunk
chunk_conv = self.conv(chunk, causal=causal)
# Apply depth_to_space
chunk_out = self._depth_to_space(chunk_conv)
# Calculate valid output range (excluding padding effects)
# Each input frame produces st output frames
out_start = (t_pos - in_start) * st
out_end = out_start + (t_end - t_pos) * st
# Extract valid portion
chunk_out = chunk_out[:, :, out_start:out_end, :, :]
outputs.append(chunk_out)
# Evaluate to free intermediate memory
mx.eval(outputs[-1])
t_pos = t_end
# Concatenate all chunks
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=2)

View File

@@ -8,8 +8,8 @@ Default configuration (from PyTorch):
- Temporal: 64 frames with 24 frame overlap
"""
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import mlx.core as mx
@@ -285,6 +285,8 @@ def decode_with_tiling(
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
@@ -297,6 +299,10 @@ def decode_with_tiling(
causal: Whether to use causal convolutions.
timestep: Optional timestep for conditioning.
debug: Whether to print debug info.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
start_idx: Starting frame index in the full video.
Returns:
Decoded video.
@@ -383,7 +389,7 @@ def decode_with_tiling(
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -454,11 +460,62 @@ def decode_with_tiling(
except Exception:
pass # May not be available on all platforms
# After completing all spatial tiles for this temporal tile,
# check if any frames are now finalized (no future tiles will contribute)
if on_frames_ready is not None and num_t_tiles > 1:
# Determine the finalized frame boundary
# Frames before the start of the next tile's output region are finalized
if t_idx < num_t_tiles - 1:
# Next tile starts at temporal_intervals.starts[t_idx + 1]
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
if next_tile_start_out > emitted:
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
if debug:
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
on_frames_ready(finalized_output, emitted)
decode_with_tiling._emitted_frames = next_tile_start_out
del finalized_output, finalized_weights
gc.collect()
# Normalize by weights
weights = mx.maximum(weights, 1e-8)
output = output / weights
mx.eval(output)
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
if debug:
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
on_frames_ready(remaining_output, emitted)
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
del decode_with_tiling._emitted_frames
# Clean up weights
del weights
gc.collect()