diff --git a/mlx_video/generate.py b/mlx_video/generate.py index a146031..43bdb70 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -75,6 +75,59 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: return (scale - 1.0) * (cond - uncond) +def apg_delta( + cond: mx.array, + uncond: mx.array, + scale: float, + eta: float = 1.0, + norm_threshold: float = 0.0, +) -> mx.array: + """Compute APG (Adaptive Projected Guidance) delta. + + Decomposes guidance into parallel and orthogonal components relative to + the conditional prediction, providing more stable guidance for I2V. + + Based on: https://arxiv.org/abs/2407.12173 + + Args: + cond: Conditional prediction (x0_pos) + uncond: Unconditional prediction (x0_neg) + scale: Guidance strength (same as CFG scale) + eta: Weight for parallel component (1.0 = keep full parallel) + norm_threshold: Clamp guidance norm to this value (0 = no clamping) + + Returns: + Delta to add to unconditional for APG guidance + """ + guidance = cond - uncond + + # Optionally clamp guidance norm for stability + if norm_threshold > 0: + guidance_norm = mx.sqrt(mx.sum(guidance ** 2, axis=(-1, -2, -3), keepdims=True) + 1e-8) + scale_factor = mx.minimum(mx.ones_like(guidance_norm), norm_threshold / guidance_norm) + guidance = guidance * scale_factor + + # Project guidance onto cond direction + batch_size = cond.shape[0] + cond_flat = mx.reshape(cond, (batch_size, -1)) + guidance_flat = mx.reshape(guidance, (batch_size, -1)) + + # Projection coefficient: (guidance · cond) / (cond · cond) + dot_product = mx.sum(guidance_flat * cond_flat, axis=1, keepdims=True) + squared_norm = mx.sum(cond_flat ** 2, axis=1, keepdims=True) + 1e-8 + proj_coeff = dot_product / squared_norm + + # Reshape back and compute parallel/orthogonal components + proj_coeff = mx.reshape(proj_coeff, (batch_size,) + (1,) * (cond.ndim - 1)) + g_parallel = proj_coeff * cond + g_orth = guidance - g_parallel + + # Combine with eta weighting parallel component + g_apg = g_parallel * eta + g_orth + + return g_apg * (scale - 1.0) + + def ltx2_scheduler( steps: int, num_tokens: Optional[int] = None, @@ -371,8 +424,19 @@ def denoise_dev( cfg_scale: float = 4.0, verbose: bool = True, state: Optional[LatentState] = None, + use_apg: bool = False, + apg_eta: float = 1.0, + apg_norm_threshold: float = 0.0, ) -> mx.array: - """Run denoising loop for dev pipeline with CFG.""" + """Run denoising loop for dev pipeline with CFG or APG guidance. + + Args: + use_apg: Use Adaptive Projected Guidance instead of standard CFG. + APG decomposes guidance into parallel/orthogonal components + for more stable I2V generation. + apg_eta: APG parallel component weight (1.0 = keep full parallel) + apg_norm_threshold: APG guidance norm clamp (0 = no clamping) + """ from mlx_video.models.ltx.rope import precompute_freqs_cis dtype = latents.dtype @@ -465,9 +529,17 @@ def denoise_dev( # Convert negative velocity to x0 using per-token timesteps x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32) - # Apply CFG to x0 predictions (matches PyTorch CFGGuider) + # Apply guidance to x0 predictions # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 - x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) + if use_apg: + # APG: decompose into parallel/orthogonal components for stability + x0_guided_f32 = x0_pos_f32 + apg_delta( + x0_pos_f32, x0_neg_f32, cfg_scale, + eta=apg_eta, norm_threshold=apg_norm_threshold + ) + else: + # Standard CFG + x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) else: x0_guided_f32 = x0_pos_f32 @@ -507,13 +579,19 @@ def denoise_dev_av( cfg_rescale: float = 0.0, verbose: bool = True, video_state: Optional[LatentState] = None, + use_apg: bool = False, + apg_eta: float = 1.0, + apg_norm_threshold: float = 0.0, ) -> tuple[mx.array, mx.array]: - """Run denoising loop for dev pipeline with CFG and audio. + """Run denoising loop for dev pipeline with CFG/APG and audio. Args: cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result towards the positive-only prediction, helping reduce artifacts. Default 0.0 means no rescaling (standard CFG). + use_apg: Use Adaptive Projected Guidance instead of standard CFG for video. + apg_eta: APG parallel component weight (1.0 = keep full parallel) + apg_norm_threshold: APG guidance norm clamp (0 = no clamping) """ from mlx_video.models.ltx.rope import precompute_freqs_cis @@ -638,10 +716,17 @@ def denoise_dev_av( video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32) audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32) - # Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider - # delta = (scale - 1) * (x0_pos - x0_neg) - # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect) - video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) + # Apply guidance to x0 (denoised) predictions + # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no effect) + if use_apg: + # APG for video (more stable for I2V), standard CFG for audio + video_x0_guided_f32 = video_x0_pos_f32 + apg_delta( + video_x0_pos_f32, video_x0_neg_f32, cfg_scale, + eta=apg_eta, norm_threshold=apg_norm_threshold + ) + else: + video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) + # Always use standard CFG for audio audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) # Apply CFG rescale if enabled @@ -788,6 +873,9 @@ def generate_video( stream: bool = False, audio: bool = False, output_audio_path: Optional[str] = None, + use_apg: bool = False, + apg_eta: float = 1.0, + apg_norm_threshold: float = 0.0, ): """Generate video using LTX-2 models. @@ -821,6 +909,9 @@ def generate_video( stream: Stream frames to output as they're decoded audio: Enable synchronized audio generation output_audio_path: Path to save audio file + use_apg: Use Adaptive Projected Guidance instead of CFG (more stable for I2V) + apg_eta: APG parallel component weight (1.0 = keep full parallel) + apg_norm_threshold: APG guidance norm clamp (0 = no clamping) """ start_time = time.time() @@ -1080,9 +1171,9 @@ def generate_video( mx.clear_cache() console.print("[green]✓[/] VAE encoder loaded and image encoded") - # Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation) + # Generate sigma schedule with token-count-dependent shifting num_tokens = latent_frames * latent_h * latent_w - sigmas = ltx2_scheduler(steps=num_inference_steps) + sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens) mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") @@ -1125,7 +1216,7 @@ def generate_video( latents = mx.random.normal(video_latent_shape, dtype=model_dtype) mx.eval(latents) - # Denoise with CFG + # Denoise with CFG/APG if audio: latents, audio_latents = denoise_dev_av( latents, audio_latents, @@ -1133,7 +1224,8 @@ def generate_video( video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, - cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state + cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold ) else: # Use original denoise_dev with computed sigmas @@ -1141,7 +1233,8 @@ def generate_video( latents, video_positions, video_embeddings_pos, video_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, - verbose=verbose, state=video_state + verbose=verbose, state=video_state, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) @@ -1371,6 +1464,9 @@ Examples: parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") + parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") + parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") + parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") args = parser.parse_args() pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED @@ -1402,6 +1498,9 @@ Examples: stream=args.stream, audio=args.audio, output_audio_path=args.output_audio, + use_apg=args.apg, + apg_eta=args.apg_eta, + apg_norm_threshold=args.apg_norm_threshold, )