Add streaming support to video generation
This commit is contained in:
@@ -15,7 +15,7 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
def debug_stats(name, t):
|
||||
@@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||
elif isinstance(block, DepthToSpaceUpsample):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
if debug:
|
||||
@@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module):
|
||||
self,
|
||||
sample: mx.array,
|
||||
tiling_config: Optional[TilingConfig] = None,
|
||||
tiling_mode: str = "auto",
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
on_frames_ready: Optional[callable] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
@@ -495,11 +500,15 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if f > tile_size_latent:
|
||||
needs_temporal_tiling = True
|
||||
|
||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
if debug:
|
||||
print("[Tiling] Input fits within tile size, using regular decode")
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug)
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
|
||||
@@ -513,6 +522,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user