Add streaming support to video generation
This commit is contained in:
@@ -215,6 +215,7 @@ def generate_video(
|
|||||||
image_strength: float = 1.0,
|
image_strength: float = 1.0,
|
||||||
image_frame_idx: int = 0,
|
image_frame_idx: int = 0,
|
||||||
tiling: str = "auto",
|
tiling: str = "auto",
|
||||||
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
"""Generate video from text prompt, optionally conditioned on an image.
|
"""Generate video from text prompt, optionally conditioned on an image.
|
||||||
|
|
||||||
@@ -481,39 +482,79 @@ def generate_video(
|
|||||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||||
|
|
||||||
|
# Save outputs
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Stream mode: write frames as they're decoded
|
||||||
|
video_writer = None
|
||||||
|
frames_written = [0] # Use list to allow mutation in closure
|
||||||
|
|
||||||
|
if stream and tiling_config is not None:
|
||||||
|
import cv2
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||||
|
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||||
|
|
||||||
|
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||||
|
"""Callback to write frames as they're finalized."""
|
||||||
|
# frames: (B, 3, num_frames, H, W)
|
||||||
|
frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W)
|
||||||
|
frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3)
|
||||||
|
frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0)
|
||||||
|
frames = (frames * 255).astype(mx.uint8)
|
||||||
|
frames_np = np.array(frames)
|
||||||
|
|
||||||
|
for frame in frames_np:
|
||||||
|
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
|
frames_written[0] += 1
|
||||||
|
|
||||||
|
print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}")
|
||||||
|
|
||||||
|
print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}")
|
||||||
|
else:
|
||||||
|
on_frames_ready = None
|
||||||
|
|
||||||
if tiling_config is not None:
|
if tiling_config is not None:
|
||||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose)
|
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
|
||||||
else:
|
else:
|
||||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||||
video = vae_decoder(latents)
|
video = vae_decoder(latents)
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# Convert to uint8 frames
|
# Close progressive video writer if used
|
||||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
if video_writer is not None:
|
||||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
video_writer.release()
|
||||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
|
||||||
video = (video * 255).astype(mx.uint8)
|
# Still need video_np for save_frames option
|
||||||
video_np = np.array(video)
|
video = mx.squeeze(video, axis=0)
|
||||||
|
video = mx.transpose(video, (1, 2, 3, 0))
|
||||||
|
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||||
|
video = (video * 255).astype(mx.uint8)
|
||||||
|
video_np = np.array(video)
|
||||||
|
else:
|
||||||
|
# Convert to uint8 frames
|
||||||
|
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||||
|
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||||
|
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||||
|
video = (video * 255).astype(mx.uint8)
|
||||||
|
video_np = np.array(video)
|
||||||
|
|
||||||
# Save outputs
|
# Save video normally
|
||||||
output_path = Path(output_path)
|
try:
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
import cv2
|
||||||
|
h, w = video_np.shape[1], video_np.shape[2]
|
||||||
try:
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||||
import cv2
|
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
|
||||||
height, width = video_np.shape[1], video_np.shape[2]
|
for frame in video_np:
|
||||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
out.release()
|
||||||
for frame in video_np:
|
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
except Exception as e:
|
||||||
out.release()
|
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||||
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
|
||||||
|
|
||||||
if save_frames:
|
if save_frames:
|
||||||
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||||
@@ -654,6 +695,11 @@ Examples:
|
|||||||
"auto=based on video size, none=disabled, default=512px/64f, "
|
"auto=based on video size, none=disabled, default=512px/64f, "
|
||||||
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
|
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stream",
|
||||||
|
action="store_true",
|
||||||
|
help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner."
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
generate_video(
|
generate_video(
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ Architecture (from PyTorch weights):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -364,6 +364,7 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
timestep: Optional[mx.array] = None,
|
timestep: Optional[mx.array] = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
|
chunked_conv: bool = False,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
|
||||||
def debug_stats(name, t):
|
def debug_stats(name, t):
|
||||||
@@ -403,6 +404,8 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
for i, block in self.up_blocks.items():
|
for i, block in self.up_blocks.items():
|
||||||
if isinstance(block, ResBlockGroup):
|
if isinstance(block, ResBlockGroup):
|
||||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||||
|
elif isinstance(block, DepthToSpaceUpsample):
|
||||||
|
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||||
else:
|
else:
|
||||||
x = block(x, causal=causal)
|
x = block(x, causal=causal)
|
||||||
if debug:
|
if debug:
|
||||||
@@ -450,9 +453,11 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
sample: mx.array,
|
sample: mx.array,
|
||||||
tiling_config: Optional[TilingConfig] = None,
|
tiling_config: Optional[TilingConfig] = None,
|
||||||
|
tiling_mode: str = "auto",
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
timestep: Optional[mx.array] = None,
|
timestep: Optional[mx.array] = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
|
on_frames_ready: Optional[callable] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Decode latents using tiling to reduce memory usage.
|
"""Decode latents using tiling to reduce memory usage.
|
||||||
|
|
||||||
@@ -495,11 +500,15 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
if f > tile_size_latent:
|
if f > tile_size_latent:
|
||||||
needs_temporal_tiling = True
|
needs_temporal_tiling = True
|
||||||
|
|
||||||
|
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||||
|
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||||
|
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||||
|
|
||||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||||
# No tiling needed, use regular decode
|
# No tiling needed, use regular decode
|
||||||
if debug:
|
if debug:
|
||||||
print("[Tiling] Input fits within tile size, using regular decode")
|
print("[Tiling] Input fits within tile size, using regular decode")
|
||||||
return self(sample, causal=causal, timestep=timestep, debug=debug)
|
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
|
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
|
||||||
@@ -513,6 +522,8 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
causal=causal,
|
causal=causal,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
|
chunked_conv=use_chunked_conv,
|
||||||
|
on_frames_ready=on_frames_ready,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -156,8 +156,8 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
||||||
|
|
||||||
b, c, d, h, w = x.shape
|
b, c, d, h, w = x.shape
|
||||||
st, sh, sw = self.stride
|
st, sh, sw = self.stride
|
||||||
|
|
||||||
@@ -177,11 +177,14 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
if st > 1:
|
if st > 1:
|
||||||
x_residual = x_residual[:, :, 1:, :, :]
|
x_residual = x_residual[:, :, 1:, :, :]
|
||||||
|
|
||||||
# Apply conv
|
# Use chunked mode for large tensors to reduce peak memory
|
||||||
x = self.conv(x, causal=causal)
|
if chunked_conv and d > 4:
|
||||||
|
x = self._chunked_conv_depth_to_space(x, causal)
|
||||||
# Depth to space rearrangement
|
else:
|
||||||
x = self._depth_to_space(x)
|
# Apply conv
|
||||||
|
x = self.conv(x, causal=causal)
|
||||||
|
# Depth to space rearrangement
|
||||||
|
x = self._depth_to_space(x)
|
||||||
|
|
||||||
# Remove first frame for causal temporal upsampling
|
# Remove first frame for causal temporal upsampling
|
||||||
if st > 1:
|
if st > 1:
|
||||||
@@ -192,3 +195,81 @@ class DepthToSpaceUpsample(nn.Module):
|
|||||||
x = x + x_residual
|
x = x + x_residual
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||||
|
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||||
|
|
||||||
|
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||||
|
Instead of materializing (B, 4096, D, H, W), we process temporal chunks and
|
||||||
|
immediately apply depth_to_space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (B, C, D, H, W)
|
||||||
|
causal: Whether to use causal convolutions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor after conv + depth_to_space
|
||||||
|
"""
|
||||||
|
b, c, d, h, w = x.shape
|
||||||
|
st, sh, sw = self.stride
|
||||||
|
out_c = self.out_channels
|
||||||
|
|
||||||
|
# Output dimensions
|
||||||
|
out_d = d * st
|
||||||
|
out_h = h * sh
|
||||||
|
out_w = w * sw
|
||||||
|
|
||||||
|
# Chunk size in temporal dimension (process 4 frames at a time)
|
||||||
|
chunk_size = 4
|
||||||
|
kernel_t = 3 # Temporal kernel size
|
||||||
|
|
||||||
|
# For causal conv, we need (kernel_t - 1) frames of padding at the start
|
||||||
|
# For non-causal, we need (kernel_t - 1) // 2 on each side
|
||||||
|
if causal:
|
||||||
|
# Pad start with first frame repeated
|
||||||
|
pad_start = kernel_t - 1
|
||||||
|
pad_end = 0
|
||||||
|
else:
|
||||||
|
pad_start = (kernel_t - 1) // 2
|
||||||
|
pad_end = (kernel_t - 1) // 2
|
||||||
|
|
||||||
|
# Allocate output
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
# Process in chunks with overlap for conv kernel
|
||||||
|
t_pos = 0
|
||||||
|
while t_pos < d:
|
||||||
|
t_end = min(t_pos + chunk_size, d)
|
||||||
|
|
||||||
|
# Calculate input range with padding for kernel
|
||||||
|
in_start = max(0, t_pos - pad_start)
|
||||||
|
in_end = min(d, t_end + pad_end)
|
||||||
|
|
||||||
|
# Extract chunk
|
||||||
|
chunk = x[:, :, in_start:in_end, :, :]
|
||||||
|
|
||||||
|
# Apply conv to chunk
|
||||||
|
chunk_conv = self.conv(chunk, causal=causal)
|
||||||
|
|
||||||
|
# Apply depth_to_space
|
||||||
|
chunk_out = self._depth_to_space(chunk_conv)
|
||||||
|
|
||||||
|
# Calculate valid output range (excluding padding effects)
|
||||||
|
# Each input frame produces st output frames
|
||||||
|
out_start = (t_pos - in_start) * st
|
||||||
|
out_end = out_start + (t_end - t_pos) * st
|
||||||
|
|
||||||
|
# Extract valid portion
|
||||||
|
chunk_out = chunk_out[:, :, out_start:out_end, :, :]
|
||||||
|
|
||||||
|
outputs.append(chunk_out)
|
||||||
|
|
||||||
|
# Evaluate to free intermediate memory
|
||||||
|
mx.eval(outputs[-1])
|
||||||
|
|
||||||
|
t_pos = t_end
|
||||||
|
|
||||||
|
# Concatenate all chunks
|
||||||
|
if len(outputs) == 1:
|
||||||
|
return outputs[0]
|
||||||
|
return mx.concatenate(outputs, axis=2)
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ Default configuration (from PyTorch):
|
|||||||
- Temporal: 64 frames with 24 frame overlap
|
- Temporal: 64 frames with 24 frame overlap
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
@@ -285,6 +285,8 @@ def decode_with_tiling(
|
|||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
timestep: Optional[mx.array] = None,
|
timestep: Optional[mx.array] = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
|
chunked_conv: bool = False,
|
||||||
|
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Decode latents using tiling to reduce memory usage.
|
"""Decode latents using tiling to reduce memory usage.
|
||||||
|
|
||||||
@@ -297,6 +299,10 @@ def decode_with_tiling(
|
|||||||
causal: Whether to use causal convolutions.
|
causal: Whether to use causal convolutions.
|
||||||
timestep: Optional timestep for conditioning.
|
timestep: Optional timestep for conditioning.
|
||||||
debug: Whether to print debug info.
|
debug: Whether to print debug info.
|
||||||
|
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||||
|
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||||
|
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||||
|
start_idx: Starting frame index in the full video.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decoded video.
|
Decoded video.
|
||||||
@@ -383,7 +389,7 @@ def decode_with_tiling(
|
|||||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||||
|
|
||||||
# Decode tile
|
# Decode tile
|
||||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
|
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||||
mx.eval(tile_output)
|
mx.eval(tile_output)
|
||||||
|
|
||||||
# Clear tile_latents reference
|
# Clear tile_latents reference
|
||||||
@@ -454,11 +460,62 @@ def decode_with_tiling(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # May not be available on all platforms
|
pass # May not be available on all platforms
|
||||||
|
|
||||||
|
# After completing all spatial tiles for this temporal tile,
|
||||||
|
# check if any frames are now finalized (no future tiles will contribute)
|
||||||
|
if on_frames_ready is not None and num_t_tiles > 1:
|
||||||
|
# Determine the finalized frame boundary
|
||||||
|
# Frames before the start of the next tile's output region are finalized
|
||||||
|
if t_idx < num_t_tiles - 1:
|
||||||
|
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||||
|
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||||
|
# Map to output frame index (first frame of next tile's contribution)
|
||||||
|
if next_tile_start_latent == 0:
|
||||||
|
next_tile_start_out = 0
|
||||||
|
else:
|
||||||
|
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||||
|
|
||||||
|
# We need to track how many frames we've already emitted
|
||||||
|
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||||
|
decode_with_tiling._emitted_frames = 0
|
||||||
|
emitted = decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
|
if next_tile_start_out > emitted:
|
||||||
|
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||||
|
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||||
|
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||||
|
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||||
|
finalized_output = finalized_output.astype(latents.dtype)
|
||||||
|
mx.eval(finalized_output)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
|
||||||
|
|
||||||
|
on_frames_ready(finalized_output, emitted)
|
||||||
|
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||||
|
|
||||||
|
del finalized_output, finalized_weights
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
# Normalize by weights
|
# Normalize by weights
|
||||||
weights = mx.maximum(weights, 1e-8)
|
weights = mx.maximum(weights, 1e-8)
|
||||||
output = output / weights
|
output = output / weights
|
||||||
mx.eval(output)
|
mx.eval(output)
|
||||||
|
|
||||||
|
# Emit remaining frames if callback provided
|
||||||
|
if on_frames_ready is not None:
|
||||||
|
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||||
|
if emitted < out_f:
|
||||||
|
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||||
|
mx.eval(remaining_output)
|
||||||
|
if debug:
|
||||||
|
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
|
||||||
|
on_frames_ready(remaining_output, emitted)
|
||||||
|
del remaining_output
|
||||||
|
|
||||||
|
# Reset emitted frames counter for next call
|
||||||
|
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||||
|
del decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
# Clean up weights
|
# Clean up weights
|
||||||
del weights
|
del weights
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
Reference in New Issue
Block a user