Add streaming support to video generation

This commit is contained in:
Prince Canuma
2026-01-17 23:17:08 +01:00
parent f33f496fba
commit 7f20840dc7
4 changed files with 229 additions and 34 deletions

View File

@@ -156,8 +156,8 @@ class DepthToSpaceUpsample(nn.Module):
return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
@@ -177,11 +177,14 @@ class DepthToSpaceUpsample(nn.Module):
if st > 1:
x_residual = x_residual[:, :, 1:, :, :]
# Apply conv
x = self.conv(x, causal=causal)
# Depth to space rearrangement
x = self._depth_to_space(x)
# Use chunked mode for large tensors to reduce peak memory
if chunked_conv and d > 4:
x = self._chunked_conv_depth_to_space(x, causal)
else:
# Apply conv
x = self.conv(x, causal=causal)
# Depth to space rearrangement
x = self._depth_to_space(x)
# Remove first frame for causal temporal upsampling
if st > 1:
@@ -192,3 +195,81 @@ class DepthToSpaceUpsample(nn.Module):
x = x + x_residual
return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor.
Instead of materializing (B, 4096, D, H, W), we process temporal chunks and
immediately apply depth_to_space.
Args:
x: Input tensor of shape (B, C, D, H, W)
causal: Whether to use causal convolutions
Returns:
Output tensor after conv + depth_to_space
"""
b, c, d, h, w = x.shape
st, sh, sw = self.stride
out_c = self.out_channels
# Output dimensions
out_d = d * st
out_h = h * sh
out_w = w * sw
# Chunk size in temporal dimension (process 4 frames at a time)
chunk_size = 4
kernel_t = 3 # Temporal kernel size
# For causal conv, we need (kernel_t - 1) frames of padding at the start
# For non-causal, we need (kernel_t - 1) // 2 on each side
if causal:
# Pad start with first frame repeated
pad_start = kernel_t - 1
pad_end = 0
else:
pad_start = (kernel_t - 1) // 2
pad_end = (kernel_t - 1) // 2
# Allocate output
outputs = []
# Process in chunks with overlap for conv kernel
t_pos = 0
while t_pos < d:
t_end = min(t_pos + chunk_size, d)
# Calculate input range with padding for kernel
in_start = max(0, t_pos - pad_start)
in_end = min(d, t_end + pad_end)
# Extract chunk
chunk = x[:, :, in_start:in_end, :, :]
# Apply conv to chunk
chunk_conv = self.conv(chunk, causal=causal)
# Apply depth_to_space
chunk_out = self._depth_to_space(chunk_conv)
# Calculate valid output range (excluding padding effects)
# Each input frame produces st output frames
out_start = (t_pos - in_start) * st
out_end = out_start + (t_end - t_pos) * st
# Extract valid portion
chunk_out = chunk_out[:, :, out_start:out_end, :, :]
outputs.append(chunk_out)
# Evaluate to free intermediate memory
mx.eval(outputs[-1])
t_pos = t_end
# Concatenate all chunks
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=2)