Enhance video generation with a progress bar for streaming and remove debug prints from the tiling decoder
This commit is contained in:
@@ -488,12 +488,13 @@ def generate_video(
|
|||||||
|
|
||||||
# Stream mode: write frames as they're decoded
|
# Stream mode: write frames as they're decoded
|
||||||
video_writer = None
|
video_writer = None
|
||||||
frames_written = [0] # Use list to allow mutation in closure
|
stream_pbar = None
|
||||||
|
|
||||||
if stream and tiling_config is not None:
|
if stream and tiling_config is not None:
|
||||||
import cv2
|
import cv2
|
||||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||||
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||||
|
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
|
||||||
|
|
||||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||||
"""Callback to write frames as they're finalized."""
|
"""Callback to write frames as they're finalized."""
|
||||||
@@ -506,11 +507,7 @@ def generate_video(
|
|||||||
|
|
||||||
for frame in frames_np:
|
for frame in frames_np:
|
||||||
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
frames_written[0] += 1
|
stream_pbar.update(1)
|
||||||
|
|
||||||
print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}")
|
|
||||||
|
|
||||||
print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}")
|
|
||||||
else:
|
else:
|
||||||
on_frames_ready = None
|
on_frames_ready = None
|
||||||
|
|
||||||
@@ -528,6 +525,8 @@ def generate_video(
|
|||||||
# Close progressive video writer if used
|
# Close progressive video writer if used
|
||||||
if video_writer is not None:
|
if video_writer is not None:
|
||||||
video_writer.release()
|
video_writer.release()
|
||||||
|
if stream_pbar is not None:
|
||||||
|
stream_pbar.close()
|
||||||
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
|
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
|
||||||
# Still need video_np for save_frames option
|
# Still need video_np for save_frames option
|
||||||
video = mx.squeeze(video, axis=0)
|
video = mx.squeeze(video, axis=0)
|
||||||
|
|||||||
@@ -506,13 +506,8 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||||
# No tiling needed, use regular decode
|
# No tiling needed, use regular decode
|
||||||
if debug:
|
|
||||||
print("[Tiling] Input fits within tile size, using regular decode")
|
|
||||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
|
|
||||||
|
|
||||||
return decode_with_tiling(
|
return decode_with_tiling(
|
||||||
decoder_fn=self,
|
decoder_fn=self,
|
||||||
latents=sample,
|
latents=sample,
|
||||||
@@ -521,7 +516,6 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
temporal_scale=8, # VAE temporal upsampling factor
|
temporal_scale=8, # VAE temporal upsampling factor
|
||||||
causal=causal,
|
causal=causal,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
debug=debug,
|
|
||||||
chunked_conv=use_chunked_conv,
|
chunked_conv=use_chunked_conv,
|
||||||
on_frames_ready=on_frames_ready,
|
on_frames_ready=on_frames_ready,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -284,7 +284,6 @@ def decode_with_tiling(
|
|||||||
temporal_scale: int = 8,
|
temporal_scale: int = 8,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
timestep: Optional[mx.array] = None,
|
timestep: Optional[mx.array] = None,
|
||||||
debug: bool = False,
|
|
||||||
chunked_conv: bool = False,
|
chunked_conv: bool = False,
|
||||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
@@ -298,7 +297,6 @@ def decode_with_tiling(
|
|||||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||||
causal: Whether to use causal convolutions.
|
causal: Whether to use causal convolutions.
|
||||||
timestep: Optional timestep for conditioning.
|
timestep: Optional timestep for conditioning.
|
||||||
debug: Whether to print debug info.
|
|
||||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||||
@@ -343,10 +341,6 @@ def decode_with_tiling(
|
|||||||
num_w_tiles = len(width_intervals.starts)
|
num_w_tiles = len(width_intervals.starts)
|
||||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
|
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
|
|
||||||
print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
|
|
||||||
|
|
||||||
# Initialize output and weight accumulator
|
# Initialize output and weight accumulator
|
||||||
# Use float32 for accumulation to avoid precision issues
|
# Use float32 for accumulation to avoid precision issues
|
||||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||||
@@ -381,10 +375,6 @@ def decode_with_tiling(
|
|||||||
# Map width coordinates
|
# Map width coordinates
|
||||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
|
|
||||||
f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
|
|
||||||
|
|
||||||
# Extract tile latents (small slice)
|
# Extract tile latents (small slice)
|
||||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||||
|
|
||||||
@@ -487,9 +477,6 @@ def decode_with_tiling(
|
|||||||
finalized_output = finalized_output.astype(latents.dtype)
|
finalized_output = finalized_output.astype(latents.dtype)
|
||||||
mx.eval(finalized_output)
|
mx.eval(finalized_output)
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
|
|
||||||
|
|
||||||
on_frames_ready(finalized_output, emitted)
|
on_frames_ready(finalized_output, emitted)
|
||||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||||
|
|
||||||
@@ -507,8 +494,6 @@ def decode_with_tiling(
|
|||||||
if emitted < out_f:
|
if emitted < out_f:
|
||||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||||
mx.eval(remaining_output)
|
mx.eval(remaining_output)
|
||||||
if debug:
|
|
||||||
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
|
|
||||||
on_frames_ready(remaining_output, emitted)
|
on_frames_ready(remaining_output, emitted)
|
||||||
del remaining_output
|
del remaining_output
|
||||||
|
|
||||||
@@ -520,8 +505,5 @@ def decode_with_tiling(
|
|||||||
del weights
|
del weights
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"[Tiling] Done. Final shape: {output.shape}")
|
|
||||||
|
|
||||||
# Convert back to original dtype if needed
|
# Convert back to original dtype if needed
|
||||||
return output.astype(latents.dtype)
|
return output.astype(latents.dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user