diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 83abd5f..0a945cc 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -276,8 +276,13 @@ def generate_video( output_path.parent.mkdir(parents=True, exist_ok=True) try: - import imageio - imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264') + import cv2 + height, width = video_np.shape[1], video_np.shape[2] + fourcc = cv2.VideoWriter_fourcc(*'avc1') + out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) + for frame in video_np: + out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + out.release() print(f"Saved video to {output_path}") except Exception as e: print(f"Could not save video: {e}") @@ -328,8 +333,8 @@ Examples: parser.add_argument( "--num-frames", "-n", type=int, - default=33, - help="Number of frames (default: 33, must be 1 + 8*k)" + default=100, + help="Number of frames (default: 100)" ) parser.add_argument( "--seed", "-s",