diff --git a/mlx_video/generate.py b/mlx_video/generate.py index c400081..9a72fe9 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -488,12 +488,13 @@ def generate_video( # Stream mode: write frames as they're decoded video_writer = None - frames_written = [0] # Use list to allow mutation in closure + stream_pbar = None if stream and tiling_config is not None: import cv2 fourcc = cv2.VideoWriter_fourcc(*'avc1') video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) + stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame") def on_frames_ready(frames: mx.array, start_idx: int): """Callback to write frames as they're finalized.""" @@ -506,11 +507,7 @@ def generate_video( for frame in frames_np: video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - frames_written[0] += 1 - - print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}") - - print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}") + stream_pbar.update(1) else: on_frames_ready = None @@ -528,6 +525,8 @@ def generate_video( # Close progressive video writer if used if video_writer is not None: video_writer.release() + if stream_pbar is not None: + stream_pbar.close() print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}") # Still need video_np for save_frames option video = mx.squeeze(video, axis=0) diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 464b873..9a6cbb3 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -506,13 +506,8 @@ class LTX2VideoDecoder(nn.Module): if not needs_spatial_tiling and not needs_temporal_tiling: # No tiling needed, use regular decode - if debug: - print("[Tiling] Input fits within tile size, using regular decode") return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv) - if debug: - print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})") - return decode_with_tiling( decoder_fn=self, latents=sample, @@ -521,7 +516,6 @@ class LTX2VideoDecoder(nn.Module): temporal_scale=8, # VAE temporal upsampling factor causal=causal, timestep=timestep, - debug=debug, chunked_conv=use_chunked_conv, on_frames_ready=on_frames_ready, ) diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py index ee55026..72d32e4 100644 --- a/mlx_video/models/ltx/video_vae/tiling.py +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -284,7 +284,6 @@ def decode_with_tiling( temporal_scale: int = 8, causal: bool = False, timestep: Optional[mx.array] = None, - debug: bool = False, chunked_conv: bool = False, on_frames_ready: Optional[Callable[[mx.array, int], None]] = None, ) -> mx.array: @@ -298,7 +297,6 @@ def decode_with_tiling( temporal_scale: Temporal scale factor (8 for LTX VAE). causal: Whether to use causal convolutions. timestep: Optional timestep for conditioning. - debug: Whether to print debug info. chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory). on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized. frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames. @@ -343,10 +341,6 @@ def decode_with_tiling( num_w_tiles = len(width_intervals.starts) total_tiles = num_t_tiles * num_h_tiles * num_w_tiles - if debug: - print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})") - print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}") - # Initialize output and weight accumulator # Use float32 for accumulation to avoid precision issues output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32) @@ -381,10 +375,6 @@ def decode_with_tiling( # Map width coordinates out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) - if debug: - print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: " - f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})") - # Extract tile latents (small slice) tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] @@ -487,9 +477,6 @@ def decode_with_tiling( finalized_output = finalized_output.astype(latents.dtype) mx.eval(finalized_output) - if debug: - print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}") - on_frames_ready(finalized_output, emitted) decode_with_tiling._emitted_frames = next_tile_start_out @@ -507,8 +494,6 @@ def decode_with_tiling( if emitted < out_f: remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) mx.eval(remaining_output) - if debug: - print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}") on_frames_ready(remaining_output, emitted) del remaining_output @@ -520,8 +505,5 @@ def decode_with_tiling( del weights gc.collect() - if debug: - print(f"[Tiling] Done. Final shape: {output.shape}") - # Convert back to original dtype if needed return output.astype(latents.dtype)