Update .gitignore to exclude additional configuration and model files. Modify generate.py to enhance console output with rescale parameter and adjust default values for inference steps and CFG scale. Refactor text encoder to align positional embedding max position with PyTorch defaults, improving compatibility and performance.

This commit is contained in:
Prince Canuma
2026-03-12 17:13:43 +01:00
parent d1fa47722b
commit b07b1e3213
3 changed files with 43 additions and 33 deletions

View File

@@ -938,7 +938,7 @@ def generate_video(
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
if pipeline == PipelineType.DEV:
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}[/]")
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]")
if is_i2v:
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
@@ -1188,7 +1188,7 @@ def generate_video(
mx.eval(sigmas)
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})")
console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
mx.random.seed(seed)
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
@@ -1432,8 +1432,8 @@ Examples:
python -m mlx_video.generate --prompt "Ocean waves" --pipeline distilled
# Dev pipeline (single-stage, CFG, higher quality)
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 4.0
python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 50
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0
python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40
# Image-to-Video (works with both pipelines)
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
@@ -1453,9 +1453,9 @@ Examples:
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
parser.add_argument("--width", "-W", type=int, default=512, help="Output video width")
parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames")
parser.add_argument("--steps", type=int, default=40, help="Number of inference steps (dev pipeline only)")
parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG guidance scale (dev pipeline only)")
parser.add_argument("--cfg-rescale", type=float, default=0.0, help="CFG rescale factor (0.0-1.0). Higher values reduce artifacts by blending towards positive-only prediction (dev pipeline only)")
parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)")
parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale (dev pipeline only, default 3.0)")
parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)")
parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed")
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path")