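Update generate.py to add optional Spatiotemporal Guidance (STG) support to the dev-pipeline denoising loop. Replace DEFAULT_NEGATIVE_PROMPT with a longer, more detailed prompt that matches DEFAULT_NEGATIVE_PROMPT in the PyTorch LTX-2 reference constants.py. Auto-detect the STG blocks from the transformer configuration when none are provided and stg_scale is non-zero (block 28 when has_prompt_adaln is set, i.e. LTX-2.3; block 29 otherwise, i.e. LTX-2). Refactor denoise_dev to accept the new stg_scale and stg_blocks parameters, run an additional perturbed pass when STG is enabled, and apply cfg_rescale whenever CFG or STG guidance is active, allowing for more flexible guidance during audio-visual video generation.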

This commit is contained in:
Prince Canuma
2026-03-14 20:02:42 +01:00
parent ffe271699a
commit 5644492f7d

View File

@@ -58,8 +58,20 @@ AUDIO_MEL_BINS = 16
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
# Default negative prompt for CFG (dev pipeline) # Default negative prompt for CFG (dev pipeline)
# Matches PyTorch LTX-2 reference InferenceConfig default # Matches PyTorch LTX-2 reference DEFAULT_NEGATIVE_PROMPT from constants.py
DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted" DEFAULT_NEGATIVE_PROMPT = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
def load_and_merge_lora( def load_and_merge_lora(
@@ -564,8 +576,10 @@ def denoise_dev(
use_apg: bool = False, use_apg: bool = False,
apg_eta: float = 1.0, apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0, apg_norm_threshold: float = 0.0,
stg_scale: float = 0.0,
stg_blocks: Optional[list] = None,
) -> mx.array: ) -> mx.array:
"""Run denoising loop for dev pipeline with CFG or APG guidance. """Run denoising loop for dev pipeline with CFG/APG and optional STG guidance.
Args: Args:
cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
@@ -576,6 +590,8 @@ def denoise_dev(
for more stable I2V generation. for more stable I2V generation.
apg_eta: APG parallel component weight (1.0 = keep full parallel) apg_eta: APG parallel component weight (1.0 = keep full parallel)
apg_norm_threshold: APG guidance norm clamp (0 = no clamping) apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled.
stg_blocks: Transformer block indices for STG perturbation.
""" """
from mlx_video.models.ltx.rope import precompute_freqs_cis from mlx_video.models.ltx.rope import precompute_freqs_cis
@@ -590,6 +606,7 @@ def denoise_dev(
sigmas_list = sigmas.tolist() sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0 use_cfg = cfg_scale != 1.0
use_stg = stg_scale != 0.0 and stg_blocks is not None
num_steps = len(sigmas_list) - 1 num_steps = len(sigmas_list) - 1
# Precompute RoPE once # Precompute RoPE once
@@ -614,7 +631,10 @@ def denoise_dev(
console=console, console=console,
disable=not verbose, disable=not verbose,
) as progress: ) as progress:
task = progress.add_task("[cyan]Denoising (CFG)[/]", total=num_steps) passes = ["CFG"] if use_cfg else []
if use_stg: passes.append("STG")
label = "+".join(passes) if passes else "uncond"
task = progress.add_task(f"[cyan]Denoising ({label})[/]", total=num_steps)
for i in range(num_steps): for i in range(num_steps):
sigma = sigmas_list[i] sigma = sigmas_list[i]
@@ -656,6 +676,9 @@ def denoise_dev(
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32) x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
# Start with positive prediction
x0_guided_f32 = x0_pos_f32
if use_cfg: if use_cfg:
# Negative conditioning pass # Negative conditioning pass
video_modality_neg = Modality( video_modality_neg = Modality(
@@ -685,15 +708,24 @@ def denoise_dev(
# Standard CFG # Standard CFG
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) # STG pass: skip self-attention at specified blocks
# factor = rescale * (cond_std / pred_std) + (1 - rescale) if use_stg:
# pred = pred * factor velocity_ptb, _ = transformer(
if cfg_rescale > 0.0: video=video_modality_pos, audio=None,
v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8) stg_video_blocks=stg_blocks,
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) )
x0_guided_f32 = x0_guided_f32 * v_factor mx.eval(velocity_ptb)
else:
x0_guided_f32 = x0_pos_f32 x0_ptb_f32 = latents_flat_f32 - timesteps_f32 * velocity_ptb.astype(mx.float32)
x0_guided_f32 = x0_guided_f32 + stg_scale * (x0_pos_f32 - x0_ptb_f32)
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation)
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
# pred = pred * factor
if cfg_rescale > 0.0 and (use_cfg or use_stg):
v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8)
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
x0_guided_f32 = x0_guided_f32 * v_factor
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
@@ -1225,6 +1257,15 @@ def generate_video(
console.print("[green]✓[/] Transformer loaded") console.print("[green]✓[/] Transformer loaded")
# Auto-detect stg_blocks from transformer config if not explicitly provided.
# LTX-2.3 (has_prompt_adaln=True) uses block 28; LTX-2 uses block 29.
if stg_blocks is None and stg_scale != 0.0:
if transformer.config.has_prompt_adaln:
stg_blocks = [28]
else:
stg_blocks = [29]
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
# ========================================================================== # ==========================================================================
# Pipeline-specific generation logic # Pipeline-specific generation logic
# ========================================================================== # ==========================================================================
@@ -1451,7 +1492,8 @@ def generate_video(
transformer, sigmas, cfg_scale=cfg_scale, transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, cfg_rescale=cfg_rescale,
verbose=verbose, state=video_state, verbose=verbose, state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
) )
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
@@ -1551,7 +1593,8 @@ def generate_video(
transformer, sigmas, cfg_scale=cfg_scale, transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, cfg_rescale=cfg_rescale,
verbose=verbose, state=state1, verbose=verbose, state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
) )
if audio and audio_latents is not None: if audio and audio_latents is not None: