Add Adaptive Projected Guidance (APG) support to denoising functions. Introduce apg_delta function for stable guidance by decomposing into parallel and orthogonal components. Update denoise_dev and generate_video functions to accept APG parameters, enhancing flexibility in video generation. Modify command-line arguments for APG integration.
This commit is contained in:
@@ -75,6 +75,59 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
|||||||
return (scale - 1.0) * (cond - uncond)
|
return (scale - 1.0) * (cond - uncond)
|
||||||
|
|
||||||
|
|
||||||
|
def apg_delta(
|
||||||
|
cond: mx.array,
|
||||||
|
uncond: mx.array,
|
||||||
|
scale: float,
|
||||||
|
eta: float = 1.0,
|
||||||
|
norm_threshold: float = 0.0,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Compute APG (Adaptive Projected Guidance) delta.
|
||||||
|
|
||||||
|
Decomposes guidance into parallel and orthogonal components relative to
|
||||||
|
the conditional prediction, providing more stable guidance for I2V.
|
||||||
|
|
||||||
|
Based on: https://arxiv.org/abs/2407.12173
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cond: Conditional prediction (x0_pos)
|
||||||
|
uncond: Unconditional prediction (x0_neg)
|
||||||
|
scale: Guidance strength (same as CFG scale)
|
||||||
|
eta: Weight for parallel component (1.0 = keep full parallel)
|
||||||
|
norm_threshold: Clamp guidance norm to this value (0 = no clamping)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delta to add to unconditional for APG guidance
|
||||||
|
"""
|
||||||
|
guidance = cond - uncond
|
||||||
|
|
||||||
|
# Optionally clamp guidance norm for stability
|
||||||
|
if norm_threshold > 0:
|
||||||
|
guidance_norm = mx.sqrt(mx.sum(guidance ** 2, axis=(-1, -2, -3), keepdims=True) + 1e-8)
|
||||||
|
scale_factor = mx.minimum(mx.ones_like(guidance_norm), norm_threshold / guidance_norm)
|
||||||
|
guidance = guidance * scale_factor
|
||||||
|
|
||||||
|
# Project guidance onto cond direction
|
||||||
|
batch_size = cond.shape[0]
|
||||||
|
cond_flat = mx.reshape(cond, (batch_size, -1))
|
||||||
|
guidance_flat = mx.reshape(guidance, (batch_size, -1))
|
||||||
|
|
||||||
|
# Projection coefficient: (guidance · cond) / (cond · cond)
|
||||||
|
dot_product = mx.sum(guidance_flat * cond_flat, axis=1, keepdims=True)
|
||||||
|
squared_norm = mx.sum(cond_flat ** 2, axis=1, keepdims=True) + 1e-8
|
||||||
|
proj_coeff = dot_product / squared_norm
|
||||||
|
|
||||||
|
# Reshape back and compute parallel/orthogonal components
|
||||||
|
proj_coeff = mx.reshape(proj_coeff, (batch_size,) + (1,) * (cond.ndim - 1))
|
||||||
|
g_parallel = proj_coeff * cond
|
||||||
|
g_orth = guidance - g_parallel
|
||||||
|
|
||||||
|
# Combine with eta weighting parallel component
|
||||||
|
g_apg = g_parallel * eta + g_orth
|
||||||
|
|
||||||
|
return g_apg * (scale - 1.0)
|
||||||
|
|
||||||
|
|
||||||
def ltx2_scheduler(
|
def ltx2_scheduler(
|
||||||
steps: int,
|
steps: int,
|
||||||
num_tokens: Optional[int] = None,
|
num_tokens: Optional[int] = None,
|
||||||
@@ -371,8 +424,19 @@ def denoise_dev(
|
|||||||
cfg_scale: float = 4.0,
|
cfg_scale: float = 4.0,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
state: Optional[LatentState] = None,
|
state: Optional[LatentState] = None,
|
||||||
|
use_apg: bool = False,
|
||||||
|
apg_eta: float = 1.0,
|
||||||
|
apg_norm_threshold: float = 0.0,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Run denoising loop for dev pipeline with CFG."""
|
"""Run denoising loop for dev pipeline with CFG or APG guidance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_apg: Use Adaptive Projected Guidance instead of standard CFG.
|
||||||
|
APG decomposes guidance into parallel/orthogonal components
|
||||||
|
for more stable I2V generation.
|
||||||
|
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
||||||
|
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
||||||
|
"""
|
||||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||||
|
|
||||||
dtype = latents.dtype
|
dtype = latents.dtype
|
||||||
@@ -465,9 +529,17 @@ def denoise_dev(
|
|||||||
# Convert negative velocity to x0 using per-token timesteps
|
# Convert negative velocity to x0 using per-token timesteps
|
||||||
x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
|
x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
|
||||||
|
|
||||||
# Apply CFG to x0 predictions (matches PyTorch CFGGuider)
|
# Apply guidance to x0 predictions
|
||||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
|
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
|
||||||
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
if use_apg:
|
||||||
|
# APG: decompose into parallel/orthogonal components for stability
|
||||||
|
x0_guided_f32 = x0_pos_f32 + apg_delta(
|
||||||
|
x0_pos_f32, x0_neg_f32, cfg_scale,
|
||||||
|
eta=apg_eta, norm_threshold=apg_norm_threshold
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Standard CFG
|
||||||
|
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
||||||
else:
|
else:
|
||||||
x0_guided_f32 = x0_pos_f32
|
x0_guided_f32 = x0_pos_f32
|
||||||
|
|
||||||
@@ -507,13 +579,19 @@ def denoise_dev_av(
|
|||||||
cfg_rescale: float = 0.0,
|
cfg_rescale: float = 0.0,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
video_state: Optional[LatentState] = None,
|
video_state: Optional[LatentState] = None,
|
||||||
|
use_apg: bool = False,
|
||||||
|
apg_eta: float = 1.0,
|
||||||
|
apg_norm_threshold: float = 0.0,
|
||||||
) -> tuple[mx.array, mx.array]:
|
) -> tuple[mx.array, mx.array]:
|
||||||
"""Run denoising loop for dev pipeline with CFG and audio.
|
"""Run denoising loop for dev pipeline with CFG/APG and audio.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result
|
cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result
|
||||||
towards the positive-only prediction, helping reduce artifacts.
|
towards the positive-only prediction, helping reduce artifacts.
|
||||||
Default 0.0 means no rescaling (standard CFG).
|
Default 0.0 means no rescaling (standard CFG).
|
||||||
|
use_apg: Use Adaptive Projected Guidance instead of standard CFG for video.
|
||||||
|
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
||||||
|
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||||
|
|
||||||
@@ -638,10 +716,17 @@ def denoise_dev_av(
|
|||||||
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
|
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
|
||||||
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
|
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
|
||||||
|
|
||||||
# Apply CFG to x0 (denoised) predictions - matches PyTorch CFGGuider
|
# Apply guidance to x0 (denoised) predictions
|
||||||
# delta = (scale - 1) * (x0_pos - x0_neg)
|
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no effect)
|
||||||
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no CFG effect)
|
if use_apg:
|
||||||
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
# APG for video (more stable for I2V), standard CFG for audio
|
||||||
|
video_x0_guided_f32 = video_x0_pos_f32 + apg_delta(
|
||||||
|
video_x0_pos_f32, video_x0_neg_f32, cfg_scale,
|
||||||
|
eta=apg_eta, norm_threshold=apg_norm_threshold
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
||||||
|
# Always use standard CFG for audio
|
||||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||||
|
|
||||||
# Apply CFG rescale if enabled
|
# Apply CFG rescale if enabled
|
||||||
@@ -788,6 +873,9 @@ def generate_video(
|
|||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
audio: bool = False,
|
audio: bool = False,
|
||||||
output_audio_path: Optional[str] = None,
|
output_audio_path: Optional[str] = None,
|
||||||
|
use_apg: bool = False,
|
||||||
|
apg_eta: float = 1.0,
|
||||||
|
apg_norm_threshold: float = 0.0,
|
||||||
):
|
):
|
||||||
"""Generate video using LTX-2 models.
|
"""Generate video using LTX-2 models.
|
||||||
|
|
||||||
@@ -821,6 +909,9 @@ def generate_video(
|
|||||||
stream: Stream frames to output as they're decoded
|
stream: Stream frames to output as they're decoded
|
||||||
audio: Enable synchronized audio generation
|
audio: Enable synchronized audio generation
|
||||||
output_audio_path: Path to save audio file
|
output_audio_path: Path to save audio file
|
||||||
|
use_apg: Use Adaptive Projected Guidance instead of CFG (more stable for I2V)
|
||||||
|
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
||||||
|
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -1080,9 +1171,9 @@ def generate_video(
|
|||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||||
|
|
||||||
# Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation)
|
# Generate sigma schedule with token-count-dependent shifting
|
||||||
num_tokens = latent_frames * latent_h * latent_w
|
num_tokens = latent_frames * latent_h * latent_w
|
||||||
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens)
|
||||||
mx.eval(sigmas)
|
mx.eval(sigmas)
|
||||||
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||||
|
|
||||||
@@ -1125,7 +1216,7 @@ def generate_video(
|
|||||||
latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
# Denoise with CFG
|
# Denoise with CFG/APG
|
||||||
if audio:
|
if audio:
|
||||||
latents, audio_latents = denoise_dev_av(
|
latents, audio_latents = denoise_dev_av(
|
||||||
latents, audio_latents,
|
latents, audio_latents,
|
||||||
@@ -1133,7 +1224,8 @@ def generate_video(
|
|||||||
video_embeddings_pos, video_embeddings_neg,
|
video_embeddings_pos, video_embeddings_neg,
|
||||||
audio_embeddings_pos, audio_embeddings_neg,
|
audio_embeddings_pos, audio_embeddings_neg,
|
||||||
transformer, sigmas, cfg_scale=cfg_scale,
|
transformer, sigmas, cfg_scale=cfg_scale,
|
||||||
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state
|
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
|
||||||
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use original denoise_dev with computed sigmas
|
# Use original denoise_dev with computed sigmas
|
||||||
@@ -1141,7 +1233,8 @@ def generate_video(
|
|||||||
latents, video_positions,
|
latents, video_positions,
|
||||||
video_embeddings_pos, video_embeddings_neg,
|
video_embeddings_pos, video_embeddings_neg,
|
||||||
transformer, sigmas, cfg_scale=cfg_scale,
|
transformer, sigmas, cfg_scale=cfg_scale,
|
||||||
verbose=verbose, state=video_state
|
verbose=verbose, state=video_state,
|
||||||
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||||
@@ -1371,6 +1464,9 @@ Examples:
|
|||||||
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
||||||
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
||||||
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
||||||
|
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
|
||||||
|
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
|
||||||
|
parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED
|
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED
|
||||||
@@ -1402,6 +1498,9 @@ Examples:
|
|||||||
stream=args.stream,
|
stream=args.stream,
|
||||||
audio=args.audio,
|
audio=args.audio,
|
||||||
output_audio_path=args.output_audio,
|
output_audio_path=args.output_audio,
|
||||||
|
use_apg=args.apg,
|
||||||
|
apg_eta=args.apg_eta,
|
||||||
|
apg_norm_threshold=args.apg_norm_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user