Add streaming support to video generation

This commit is contained in:
Prince Canuma
2026-01-17 23:17:08 +01:00
parent f33f496fba
commit 7f20840dc7
4 changed files with 229 additions and 34 deletions

View File

@@ -15,7 +15,7 @@ Architecture (from PyTorch weights):
"""
import math
from typing import List, Optional
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
@@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module):
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
chunked_conv: bool = False,
) -> mx.array:
def debug_stats(name, t):
@@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module):
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep)
elif isinstance(block, DepthToSpaceUpsample):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
if debug:
@@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module):
self,
sample: mx.array,
tiling_config: Optional[TilingConfig] = None,
tiling_mode: str = "auto",
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
on_frames_ready: Optional[callable] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
@@ -495,11 +500,15 @@ class LTX2VideoDecoder(nn.Module):
if f > tile_size_latent:
needs_temporal_tiling = True
# Automatically enable chunked conv for the tiling modes where it helps (those using larger tiles).
# Chunked conv reduces memory by processing conv + depth_to_space in temporal chunks.
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
if debug:
print("[Tiling] Input fits within tile size, using regular decode")
return self(sample, causal=causal, timestep=timestep, debug=debug)
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
if debug:
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
@@ -513,6 +522,8 @@ class LTX2VideoDecoder(nn.Module):
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)