Add Adaptive Projected Guidance (APG) support to denoising functions. Introduce apg_delta function for stable guidance by decomposing into parallel and orthogonal components. Update denoise_dev and generate_video functions to accept APG parameters, enhancing flexibility in video generation. Modify command-line arguments for APG integration.
This commit is contained in:
@@ -75,6 +75,59 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
||||
return (scale - 1.0) * (cond - uncond)
|
||||
|
||||
|
||||
def apg_delta(
|
||||
cond: mx.array,
|
||||
uncond: mx.array,
|
||||
scale: float,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
) -> mx.array:
|
||||
"""Compute APG (Adaptive Projected Guidance) delta.
|
||||
|
||||
Decomposes guidance into parallel and orthogonal components relative to
|
||||
the conditional prediction, providing more stable guidance for I2V.
|
||||
|
||||
Based on: https://arxiv.org/abs/2407.12173
|
||||
|
||||
Args:
|
||||
cond: Conditional prediction (x0_pos)
|
||||
uncond: Unconditional prediction (x0_neg)
|
||||
scale: Guidance strength (same as CFG scale)
|
||||
eta: Weight for parallel component (1.0 = keep full parallel)
|
||||
norm_threshold: Clamp guidance norm to this value (0 = no clamping)
|
||||
|
||||
Returns:
|
||||
Delta to add to unconditional for APG guidance
|
||||
"""
|
||||
guidance = cond - uncond
|
||||
|
||||
# Optionally clamp guidance norm for stability
|
||||
if norm_threshold > 0:
|
||||
guidance_norm = mx.sqrt(mx.sum(guidance ** 2, axis=(-1, -2, -3), keepdims=True) + 1e-8)
|
||||
scale_factor = mx.minimum(mx.ones_like(guidance_norm), norm_threshold / guidance_norm)
|
||||
guidance = guidance * scale_factor
|
||||
|
||||
# Project guidance onto cond direction
|
||||
batch_size = cond.shape[0]
|
||||
cond_flat = mx.reshape(cond, (batch_size, -1))
|
||||
guidance_flat = mx.reshape(guidance, (batch_size, -1))
|
||||
|
||||
# Projection coefficient: (guidance · cond) / (cond · cond)
|
||||
dot_product = mx.sum(guidance_flat * cond_flat, axis=1, keepdims=True)
|
||||
squared_norm = mx.sum(cond_flat ** 2, axis=1, keepdims=True) + 1e-8
|
||||
proj_coeff = dot_product / squared_norm
|
||||
|
||||
# Reshape back and compute parallel/orthogonal components
|
||||
proj_coeff = mx.reshape(proj_coeff, (batch_size,) + (1,) * (cond.ndim - 1))
|
||||
g_parallel = proj_coeff * cond
|
||||
g_orth = guidance - g_parallel
|
||||
|
||||
# Combine with eta weighting parallel component
|
||||
g_apg = g_parallel * eta + g_orth
|
||||
|
||||
return g_apg * (scale - 1.0)
|
||||
|
||||
|
||||
def ltx2_scheduler(
|
||||
steps: int,
|
||||
num_tokens: Optional[int] = None,
|
||||
@@ -371,8 +424,19 @@ def denoise_dev(
|
||||
cfg_scale: float = 4.0,
|
||||
verbose: bool = True,
|
||||
state: Optional[LatentState] = None,
|
||||
use_apg: bool = False,
|
||||
apg_eta: float = 1.0,
|
||||
apg_norm_threshold: float = 0.0,
|
||||
) -> mx.array:
|
||||
"""Run denoising loop for dev pipeline with CFG."""
|
||||
"""Run denoising loop for dev pipeline with CFG or APG guidance.
|
||||
|
||||
Args:
|
||||
use_apg: Use Adaptive Projected Guidance instead of standard CFG.
|
||||
APG decomposes guidance into parallel/orthogonal components
|
||||
for more stable I2V generation.
|
||||
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
||||
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
||||
"""
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
|
||||
dtype = latents.dtype
|
||||
@@ -465,9 +529,17 @@ def denoise_dev(
|
||||
# Convert negative velocity to x0 using per-token timesteps
|
||||
x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
|
||||
|
||||
# Apply CFG to x0 predictions (matches PyTorch CFGGuider)
|
||||
# Apply guidance to x0 predictions
|
||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
|
||||
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
||||
if use_apg:
|
||||
# APG: decompose into parallel/orthogonal components for stability
|
||||
x0_guided_f32 = x0_pos_f32 + apg_delta(
|
||||
x0_pos_f32, x0_neg_f32, cfg_scale,
|
||||
eta=apg_eta, norm_threshold=apg_norm_threshold
|
||||
)
|
||||
else:
|
||||
# Standard CFG
|
||||
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
||||
else:
|
||||
x0_guided_f32 = x0_pos_f32
|
||||
|
||||
@@ -507,13 +579,19 @@ def denoise_dev_av(
|
||||
cfg_rescale: float = 0.0,
|
||||
verbose: bool = True,
|
||||
video_state: Optional[LatentState] = None,
|
||||
use_apg: bool = False,
|
||||
apg_eta: float = 1.0,
|
||||
apg_norm_threshold: float = 0.0,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run denoising loop for dev pipeline with CFG and audio.
|
||||
"""Run denoising loop for dev pipeline with CFG/APG and audio.
|
||||
|
||||
Args:
|
||||
cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result
|
||||
towards the positive-only prediction, helping reduce artifacts.
|
||||
Default 0.0 means no rescaling (standard CFG).
|
||||
use_apg: Use Adaptive Projected Guidance instead of standard CFG for video.
|
||||
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
||||
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
||||
"""
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
|
||||
@@ -638,10 +716,17 @@ def denoise_dev_av(
|
||||
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
|
||||
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
|
||||
|
||||
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
|
||||
# delta = (scale - 1) * (x0_pos - x0_neg)
|
||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||
# Apply guidance to x0 (denoised) predictions
|
||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no effect)
|
||||
if use_apg:
|
||||
# APG for video (more stable for I2V), standard CFG for audio
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + apg_delta(
|
||||
video_x0_pos_f32, video_x0_neg_f32, cfg_scale,
|
||||
eta=apg_eta, norm_threshold=apg_norm_threshold
|
||||
)
|
||||
else:
|
||||
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||
# Always use standard CFG for audio
|
||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||
|
||||
# Apply CFG rescale if enabled
|
||||
@@ -788,6 +873,9 @@ def generate_video(
|
||||
stream: bool = False,
|
||||
audio: bool = False,
|
||||
output_audio_path: Optional[str] = None,
|
||||
use_apg: bool = False,
|
||||
apg_eta: float = 1.0,
|
||||
apg_norm_threshold: float = 0.0,
|
||||
):
|
||||
"""Generate video using LTX-2 models.
|
||||
|
||||
@@ -821,6 +909,9 @@ def generate_video(
|
||||
stream: Stream frames to output as they're decoded
|
||||
audio: Enable synchronized audio generation
|
||||
output_audio_path: Path to save audio file
|
||||
use_apg: Use Adaptive Projected Guidance instead of CFG (more stable for I2V)
|
||||
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
||||
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -1080,9 +1171,9 @@ def generate_video(
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||
|
||||
# Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation)
|
||||
# Generate sigma schedule with token-count-dependent shifting
|
||||
num_tokens = latent_frames * latent_h * latent_w
|
||||
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
||||
sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens)
|
||||
mx.eval(sigmas)
|
||||
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||
|
||||
@@ -1125,7 +1216,7 @@ def generate_video(
|
||||
latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
||||
mx.eval(latents)
|
||||
|
||||
# Denoise with CFG
|
||||
# Denoise with CFG/APG
|
||||
if audio:
|
||||
latents, audio_latents = denoise_dev_av(
|
||||
latents, audio_latents,
|
||||
@@ -1133,7 +1224,8 @@ def generate_video(
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state
|
||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
||||
)
|
||||
else:
|
||||
# Use original denoise_dev with computed sigmas
|
||||
@@ -1141,7 +1233,8 @@ def generate_video(
|
||||
latents, video_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
verbose=verbose, state=video_state
|
||||
verbose=verbose, state=video_state,
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
@@ -1371,6 +1464,9 @@ Examples:
|
||||
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
||||
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
||||
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
||||
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
|
||||
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
|
||||
parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)")
|
||||
args = parser.parse_args()
|
||||
|
||||
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED
|
||||
@@ -1402,6 +1498,9 @@ Examples:
|
||||
stream=args.stream,
|
||||
audio=args.audio,
|
||||
output_audio_path=args.output_audio,
|
||||
use_apg=args.apg,
|
||||
apg_eta=args.apg_eta,
|
||||
apg_norm_threshold=args.apg_norm_threshold,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user