Enhance video generation with progress bar for streaming and remove debug prints from tiling decoder

This commit is contained in:
Prince Canuma
2026-01-17 23:53:53 +01:00
parent f256c5fb25
commit b1bf9e2dc0
3 changed files with 5 additions and 30 deletions

View File

@@ -488,12 +488,13 @@ def generate_video(
# Stream mode: write frames as they're decoded
video_writer = None
frames_written = [0] # Use list to allow mutation in closure
stream_pbar = None
if stream and tiling_config is not None:
import cv2
fourcc = cv2.VideoWriter_fourcc(*'avc1')
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
def on_frames_ready(frames: mx.array, start_idx: int):
"""Callback to write frames as they're finalized."""
@@ -506,11 +507,7 @@ def generate_video(
for frame in frames_np:
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
frames_written[0] += 1
print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}")
print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}")
stream_pbar.update(1)
else:
on_frames_ready = None
@@ -528,6 +525,8 @@ def generate_video(
# Close progressive video writer if used
if video_writer is not None:
video_writer.release()
if stream_pbar is not None:
stream_pbar.close()
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
# Still need video_np for save_frames option
video = mx.squeeze(video, axis=0)

View File

@@ -506,13 +506,8 @@ class LTX2VideoDecoder(nn.Module):
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
if debug:
print("[Tiling] Input fits within tile size, using regular decode")
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
if debug:
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
return decode_with_tiling(
decoder_fn=self,
latents=sample,
@@ -521,7 +516,6 @@ class LTX2VideoDecoder(nn.Module):
temporal_scale=8, # VAE temporal upsampling factor
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)

View File

@@ -284,7 +284,6 @@ def decode_with_tiling(
temporal_scale: int = 8,
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
@@ -298,7 +297,6 @@ def decode_with_tiling(
temporal_scale: Temporal scale factor (8 for LTX VAE).
causal: Whether to use causal convolutions.
timestep: Optional timestep for conditioning.
debug: Whether to print debug info.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
@@ -343,10 +341,6 @@ def decode_with_tiling(
num_w_tiles = len(width_intervals.starts)
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
if debug:
print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
# Initialize output and weight accumulator
# Use float32 for accumulation to avoid precision issues
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
@@ -381,10 +375,6 @@ def decode_with_tiling(
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
if debug:
print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
@@ -487,9 +477,6 @@ def decode_with_tiling(
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
if debug:
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
on_frames_ready(finalized_output, emitted)
decode_with_tiling._emitted_frames = next_tile_start_out
@@ -507,8 +494,6 @@ def decode_with_tiling(
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
if debug:
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
on_frames_ready(remaining_output, emitted)
del remaining_output
@@ -520,8 +505,5 @@ def decode_with_tiling(
del weights
gc.collect()
if debug:
print(f"[Tiling] Done. Final shape: {output.shape}")
# Convert back to original dtype if needed
return output.astype(latents.dtype)