Enhance video generation with progress bar for streaming and remove debug prints from tiling decoder
This commit is contained in:
@@ -284,7 +284,6 @@ def decode_with_tiling(
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
@@ -298,7 +297,6 @@ def decode_with_tiling(
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
debug: Whether to print debug info.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
@@ -343,10 +341,6 @@ def decode_with_tiling(
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
|
||||
print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
@@ -381,10 +375,6 @@ def decode_with_tiling(
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
|
||||
f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
|
||||
@@ -487,9 +477,6 @@ def decode_with_tiling(
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Emitting finalized frames {emitted}-{next_tile_start_out - 1}")
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
@@ -507,8 +494,6 @@ def decode_with_tiling(
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
if debug:
|
||||
print(f"[Tiling] Emitting remaining frames {emitted}-{out_f - 1}")
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
@@ -520,8 +505,5 @@ def decode_with_tiling(
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
if debug:
|
||||
print(f"[Tiling] Done. Final shape: {output.shape}")
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
|
||||
Reference in New Issue
Block a user