Add streaming support to video generation
This commit is contained in:
@@ -215,6 +215,7 @@ def generate_video(
|
||||
image_strength: float = 1.0,
|
||||
image_frame_idx: int = 0,
|
||||
tiling: str = "auto",
|
||||
stream: bool = False,
|
||||
):
|
||||
"""Generate video from text prompt, optionally conditioned on an image.
|
||||
|
||||
@@ -481,39 +482,79 @@ def generate_video(
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Stream mode: write frames as they're decoded
|
||||
video_writer = None
|
||||
frames_written = [0] # Use list to allow mutation in closure
|
||||
|
||||
if stream and tiling_config is not None:
|
||||
import cv2
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
"""Callback to write frames as they're finalized."""
|
||||
# frames: (B, 3, num_frames, H, W); the squeeze below assumes B == 1
|
||||
frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W)
|
||||
frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3)
|
||||
frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0)
|
||||
frames = (frames * 255).astype(mx.uint8)
|
||||
frames_np = np.array(frames)
|
||||
|
||||
for frame in frames_np:
|
||||
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
frames_written[0] += 1
|
||||
|
||||
print(f"{Colors.DIM} Progressive: wrote frames {start_idx}-{start_idx + len(frames_np) - 1} ({frames_written[0]} total){Colors.RESET}")
|
||||
|
||||
print(f"{Colors.MAGENTA}📹 Streaming enabled - frames will be written as decoded{Colors.RESET}")
|
||||
else:
|
||||
on_frames_ready = None
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose)
|
||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
|
||||
else:
|
||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
mx.clear_cache()
|
||||
|
||||
# Convert to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
# Close progressive video writer if used
|
||||
if video_writer is not None:
|
||||
video_writer.release()
|
||||
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
|
||||
# Still need video_np for save_frames option
|
||||
video = mx.squeeze(video, axis=0)
|
||||
video = mx.transpose(video, (1, 2, 3, 0))
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
else:
|
||||
# Convert to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
height, width = video_np.shape[1], video_np.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||
for frame in video_np:
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||
# Save video normally
|
||||
try:
|
||||
import cv2
|
||||
h, w = video_np.shape[1], video_np.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
|
||||
for frame in video_np:
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||
|
||||
if save_frames:
|
||||
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||
@@ -654,6 +695,11 @@ Examples:
|
||||
"auto=based on video size, none=disabled, default=512px/64f, "
|
||||
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
|
||||
Reference in New Issue
Block a user