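Add Adaptive Projected Guidance (APG) support to the denoising functions. Introduce an apg_delta function that stabilizes guidance by decomposing it into components parallel and orthogonal to the conditional prediction. Update the denoise_dev, denoise_dev_av and generate_video functions to accept APG parameters (in denoise_dev_av, APG applies to video only; audio keeps standard CFG), enhancing flexibility in video generation. Add --apg, --apg-eta and --apg-norm-threshold command-line arguments for APG integration. Also pass num_tokens to ltx2_scheduler so the sigma schedule uses token-count-dependent shifting.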

This commit is contained in:
Prince Canuma
2026-01-26 21:35:58 +01:00
parent 87962c7f83
commit d1dd30cbac

View File

@@ -75,6 +75,59 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
return (scale - 1.0) * (cond - uncond) return (scale - 1.0) * (cond - uncond)
def apg_delta(
    cond: mx.array,
    uncond: mx.array,
    scale: float,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
) -> mx.array:
    """Compute the APG (Adaptive Projected Guidance) delta.

    The guidance vector ``cond - uncond`` is split, per sample, into a
    component parallel to ``cond`` and the orthogonal remainder. The parallel
    part is re-weighted by ``eta`` and the sum is scaled by ``scale - 1``.
    With ``eta == 1`` and ``norm_threshold == 0`` the result equals the
    standard CFG delta ``(scale - 1) * (cond - uncond)`` up to rounding.

    Reference given by the original author: https://arxiv.org/abs/2407.12173
    NOTE(review): the APG paper (Sadat et al., "Eliminating Oversaturation and
    Artifacts of High Guidance Scales in Diffusion Models") is, as far as I
    know, arXiv:2410.02416 -- please confirm and fix the link.

    NOTE(review): the parallel term is a multiple of ``cond`` itself, so for
    ``eta != 1`` the delta is generally non-zero even at positions where
    ``cond == uncond``. Callers that rely on "identical predictions give a
    zero delta" only get that guarantee with ``eta == 1``.

    Args:
        cond: Conditional prediction (x0_pos). The leading axis is treated as
            the batch axis; all remaining axes are flattened per sample.
        uncond: Unconditional prediction (x0_neg); assumed to have the same
            shape as ``cond`` -- TODO confirm against callers.
        scale: Guidance strength (same meaning as the CFG scale).
        eta: Weight of the parallel component (1.0 keeps it fully, 0.0 drops
            it).
        norm_threshold: If > 0, each sample's guidance vector is rescaled so
            its L2 norm does not exceed this value (0 disables clamping).

    Returns:
        Delta to add to the *conditional* prediction, same shape as ``cond``.
    """
    batch_size = cond.shape[0]
    # Shape that broadcasts one scalar per sample over all non-batch axes.
    per_sample_shape = (batch_size,) + (1,) * (cond.ndim - 1)

    guidance = cond - uncond

    # Optionally clamp the guidance norm for stability. The norm is taken per
    # sample over every non-batch axis so that (a) samples in a batch never
    # influence each other and (b) inputs of any rank >= 1 are supported; this
    # matches the per-sample reduction used for the projection below.
    if norm_threshold > 0:
        squared_sum = mx.sum(
            mx.reshape(guidance ** 2, (batch_size, -1)), axis=1, keepdims=True
        )
        # Epsilon keeps the division finite when guidance is all zeros.
        guidance_norm = mx.reshape(mx.sqrt(squared_sum + 1e-8), per_sample_shape)
        scale_factor = mx.minimum(
            mx.ones_like(guidance_norm), norm_threshold / guidance_norm
        )
        guidance = guidance * scale_factor

    # Projection coefficient per sample: (guidance . cond) / (cond . cond)
    cond_flat = mx.reshape(cond, (batch_size, -1))
    guidance_flat = mx.reshape(guidance, (batch_size, -1))
    dot_product = mx.sum(guidance_flat * cond_flat, axis=1, keepdims=True)
    # Epsilon guards against a zero conditional prediction.
    squared_norm = mx.sum(cond_flat ** 2, axis=1, keepdims=True) + 1e-8
    proj_coeff = mx.reshape(dot_product / squared_norm, per_sample_shape)

    # Split into parallel / orthogonal parts and re-weight the parallel one.
    g_parallel = proj_coeff * cond
    g_orth = guidance - g_parallel
    g_apg = g_parallel * eta + g_orth

    return g_apg * (scale - 1.0)
def ltx2_scheduler( def ltx2_scheduler(
steps: int, steps: int,
num_tokens: Optional[int] = None, num_tokens: Optional[int] = None,
@@ -371,8 +424,19 @@ def denoise_dev(
cfg_scale: float = 4.0, cfg_scale: float = 4.0,
verbose: bool = True, verbose: bool = True,
state: Optional[LatentState] = None, state: Optional[LatentState] = None,
use_apg: bool = False,
apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0,
) -> mx.array: ) -> mx.array:
"""Run denoising loop for dev pipeline with CFG.""" """Run denoising loop for dev pipeline with CFG or APG guidance.
Args:
use_apg: Use Adaptive Projected Guidance instead of standard CFG.
APG decomposes guidance into parallel/orthogonal components
for more stable I2V generation.
apg_eta: APG parallel component weight (1.0 = keep full parallel)
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
"""
from mlx_video.models.ltx.rope import precompute_freqs_cis from mlx_video.models.ltx.rope import precompute_freqs_cis
dtype = latents.dtype dtype = latents.dtype
@@ -465,8 +529,16 @@ def denoise_dev(
# Convert negative velocity to x0 using per-token timesteps # Convert negative velocity to x0 using per-token timesteps
x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32) x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
# Apply CFG to x0 predictions (matches PyTorch CFGGuider) # Apply guidance to x0 predictions
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
if use_apg:
# APG: decompose into parallel/orthogonal components for stability
x0_guided_f32 = x0_pos_f32 + apg_delta(
x0_pos_f32, x0_neg_f32, cfg_scale,
eta=apg_eta, norm_threshold=apg_norm_threshold
)
else:
# Standard CFG
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
else: else:
x0_guided_f32 = x0_pos_f32 x0_guided_f32 = x0_pos_f32
@@ -507,13 +579,19 @@ def denoise_dev_av(
cfg_rescale: float = 0.0, cfg_rescale: float = 0.0,
verbose: bool = True, verbose: bool = True,
video_state: Optional[LatentState] = None, video_state: Optional[LatentState] = None,
use_apg: bool = False,
apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0,
) -> tuple[mx.array, mx.array]: ) -> tuple[mx.array, mx.array]:
"""Run denoising loop for dev pipeline with CFG and audio. """Run denoising loop for dev pipeline with CFG/APG and audio.
Args: Args:
cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result
towards the positive-only prediction, helping reduce artifacts. towards the positive-only prediction, helping reduce artifacts.
Default 0.0 means no rescaling (standard CFG). Default 0.0 means no rescaling (standard CFG).
use_apg: Use Adaptive Projected Guidance instead of standard CFG for video.
apg_eta: APG parallel component weight (1.0 = keep full parallel)
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
""" """
from mlx_video.models.ltx.rope import precompute_freqs_cis from mlx_video.models.ltx.rope import precompute_freqs_cis
@@ -638,10 +716,17 @@ def denoise_dev_av(
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32) video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32) audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider # Apply guidance to x0 (denoised) predictions
# delta = (scale - 1) * (x0_pos - x0_neg) # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no effect)
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect) if use_apg:
# APG for video (more stable for I2V), standard CFG for audio
video_x0_guided_f32 = video_x0_pos_f32 + apg_delta(
video_x0_pos_f32, video_x0_neg_f32, cfg_scale,
eta=apg_eta, norm_threshold=apg_norm_threshold
)
else:
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
# Always use standard CFG for audio
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
# Apply CFG rescale if enabled # Apply CFG rescale if enabled
@@ -788,6 +873,9 @@ def generate_video(
stream: bool = False, stream: bool = False,
audio: bool = False, audio: bool = False,
output_audio_path: Optional[str] = None, output_audio_path: Optional[str] = None,
use_apg: bool = False,
apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0,
): ):
"""Generate video using LTX-2 models. """Generate video using LTX-2 models.
@@ -821,6 +909,9 @@ def generate_video(
stream: Stream frames to output as they're decoded stream: Stream frames to output as they're decoded
audio: Enable synchronized audio generation audio: Enable synchronized audio generation
output_audio_path: Path to save audio file output_audio_path: Path to save audio file
use_apg: Use Adaptive Projected Guidance instead of CFG (more stable for I2V)
apg_eta: APG parallel component weight (1.0 = keep full parallel)
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
""" """
start_time = time.time() start_time = time.time()
@@ -1080,9 +1171,9 @@ def generate_video(
mx.clear_cache() mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded") console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation) # Generate sigma schedule with token-count-dependent shifting
num_tokens = latent_frames * latent_h * latent_w num_tokens = latent_frames * latent_h * latent_w
sigmas = ltx2_scheduler(steps=num_inference_steps) sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens)
mx.eval(sigmas) mx.eval(sigmas)
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]") console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
@@ -1125,7 +1216,7 @@ def generate_video(
latents = mx.random.normal(video_latent_shape, dtype=model_dtype) latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
# Denoise with CFG # Denoise with CFG/APG
if audio: if audio:
latents, audio_latents = denoise_dev_av( latents, audio_latents = denoise_dev_av(
latents, audio_latents, latents, audio_latents,
@@ -1133,7 +1224,8 @@ def generate_video(
video_embeddings_pos, video_embeddings_neg, video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
) )
else: else:
# Use original denoise_dev with computed sigmas # Use original denoise_dev with computed sigmas
@@ -1141,7 +1233,8 @@ def generate_video(
latents, video_positions, latents, video_positions,
video_embeddings_pos, video_embeddings_neg, video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, transformer, sigmas, cfg_scale=cfg_scale,
verbose=verbose, state=video_state verbose=verbose, state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
) )
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
@@ -1371,6 +1464,9 @@ Examples:
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)")
args = parser.parse_args() args = parser.parse_args()
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED
@@ -1402,6 +1498,9 @@ Examples:
stream=args.stream, stream=args.stream,
audio=args.audio, audio=args.audio,
output_audio_path=args.output_audio, output_audio_path=args.output_audio,
use_apg=args.apg,
apg_eta=args.apg_eta,
apg_norm_threshold=args.apg_norm_threshold,
) )