Update generate.py to add optional Spatiotemporal Guidance (STG) support to the dev denoising loop. Replace DEFAULT_NEGATIVE_PROMPT with a longer, more detailed prompt that lists both visual and audio artifacts. Auto-detect the STG blocks from the transformer configuration when they are not provided explicitly. Refactor the denoise_dev function to accept the STG parameters (stg_scale, stg_blocks), so that an STG perturbation pass can be combined with CFG/APG guidance during video generation.

This commit is contained in:
Prince Canuma
2026-03-14 20:02:42 +01:00
parent ffe271699a
commit 5644492f7d

View File

@@ -58,8 +58,20 @@ AUDIO_MEL_BINS = 16
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
# Default negative prompt for CFG (dev pipeline)
# Matches PyTorch LTX-2 reference InferenceConfig default
DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted"
# Matches PyTorch LTX-2 reference DEFAULT_NEGATIVE_PROMPT from constants.py
DEFAULT_NEGATIVE_PROMPT = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
def load_and_merge_lora(
@@ -564,8 +576,10 @@ def denoise_dev(
use_apg: bool = False,
apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0,
stg_scale: float = 0.0,
stg_blocks: Optional[list] = None,
) -> mx.array:
"""Run denoising loop for dev pipeline with CFG or APG guidance.
"""Run denoising loop for dev pipeline with CFG/APG and optional STG guidance.
Args:
cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
@@ -576,6 +590,8 @@ def denoise_dev(
for more stable I2V generation.
apg_eta: APG parallel component weight (1.0 = keep full parallel)
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled.
stg_blocks: Transformer block indices whose self-attention is skipped in the
STG perturbation pass. None = STG disabled.
"""
from mlx_video.models.ltx.rope import precompute_freqs_cis
@@ -590,6 +606,7 @@ def denoise_dev(
sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0
use_stg = stg_scale != 0.0 and stg_blocks is not None
num_steps = len(sigmas_list) - 1
# Precompute RoPE once
@@ -614,7 +631,10 @@ def denoise_dev(
console=console,
disable=not verbose,
) as progress:
task = progress.add_task("[cyan]Denoising (CFG)[/]", total=num_steps)
passes = ["CFG"] if use_cfg else []
if use_stg: passes.append("STG")
label = "+".join(passes) if passes else "uncond"
task = progress.add_task(f"[cyan]Denoising ({label})[/]", total=num_steps)
for i in range(num_steps):
sigma = sigmas_list[i]
@@ -656,6 +676,9 @@ def denoise_dev(
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
# Start with positive prediction
x0_guided_f32 = x0_pos_f32
if use_cfg:
# Negative conditioning pass
video_modality_neg = Modality(
@@ -685,15 +708,24 @@ def denoise_dev(
# Standard CFG
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
# STG pass: skip self-attention at specified blocks
if use_stg:
velocity_ptb, _ = transformer(
video=video_modality_pos, audio=None,
stg_video_blocks=stg_blocks,
)
mx.eval(velocity_ptb)
x0_ptb_f32 = latents_flat_f32 - timesteps_f32 * velocity_ptb.astype(mx.float32)
x0_guided_f32 = x0_guided_f32 + stg_scale * (x0_pos_f32 - x0_ptb_f32)
# Apply guidance rescale if enabled and CFG and/or STG was applied (std-ratio rescaling to reduce over-saturation)
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
# pred = pred * factor
if cfg_rescale > 0.0:
if cfg_rescale > 0.0 and (use_cfg or use_stg):
v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8)
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
x0_guided_f32 = x0_guided_f32 * v_factor
else:
x0_guided_f32 = x0_pos_f32
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
@@ -1225,6 +1257,15 @@ def generate_video(
console.print("[green]✓[/] Transformer loaded")
# Auto-detect stg_blocks from transformer config if not explicitly provided.
# LTX-2.3 (has_prompt_adaln=True) uses block 28; LTX-2 uses block 29.
if stg_blocks is None and stg_scale != 0.0:
if transformer.config.has_prompt_adaln:
stg_blocks = [28]
else:
stg_blocks = [29]
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
# ==========================================================================
# Pipeline-specific generation logic
# ==========================================================================
@@ -1451,7 +1492,8 @@ def generate_video(
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
@@ -1551,7 +1593,8 @@ def generate_video(
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale,
verbose=verbose, state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_blocks=stg_blocks,
)
if audio and audio_latents is not None: