Enhance video generation with progress bar for streaming and remove debug prints from tiling decoder
This commit is contained in:
@@ -488,12 +488,13 @@ def generate_video(
|
||||
|
||||
# Stream mode: write frames as they're decoded
|
||||
video_writer = None
|
||||
frames_written = [0] # Use list to allow mutation in closure
|
||||
stream_pbar = None
|
||||
|
||||
if stream and tiling_config is not None:
|
||||
import cv2
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
"""Callback to write frames as they're finalized."""
|
||||
@@ -506,11 +507,7 @@ def generate_video(
|
||||
|
||||
for frame in frames_np:
|
||||
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
frames_written[0] += 1
|
||||
|
||||
print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}")
|
||||
|
||||
print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}")
|
||||
stream_pbar.update(1)
|
||||
else:
|
||||
on_frames_ready = None
|
||||
|
||||
@@ -528,6 +525,8 @@ def generate_video(
|
||||
# Close progressive video writer if used
|
||||
if video_writer is not None:
|
||||
video_writer.release()
|
||||
if stream_pbar is not None:
|
||||
stream_pbar.close()
|
||||
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
|
||||
# Still need video_np for save_frames option
|
||||
video = mx.squeeze(video, axis=0)
|
||||
|
||||
@@ -506,13 +506,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
if debug:
|
||||
print("[Tiling] Input fits within tile size, using regular decode")
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
latents=sample,
|
||||
@@ -521,7 +516,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
temporal_scale=8, # VAE temporal upsampling factor
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
@@ -284,7 +284,6 @@ def decode_with_tiling(
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
@@ -298,7 +297,6 @@ def decode_with_tiling(
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
debug: Whether to print debug info.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
@@ -343,10 +341,6 @@ def decode_with_tiling(
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
|
||||
print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
@@ -381,10 +375,6 @@ def decode_with_tiling(
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
|
||||
f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
|
||||
@@ -487,9 +477,6 @@ def decode_with_tiling(
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
@@ -507,8 +494,6 @@ def decode_with_tiling(
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
if debug:
|
||||
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
@@ -520,8 +505,5 @@ def decode_with_tiling(
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Done. Final shape: {output.shape}")
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
|
||||
Reference in New Issue
Block a user