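Update generate.py to add optional Spatiotemporal Guidance (STG) support to the dev-pipeline denoising loop. Replace DEFAULT_NEGATIVE_PROMPT with a longer, more detailed prompt that matches DEFAULT_NEGATIVE_PROMPT in the PyTorch LTX-2 reference (constants.py). Auto-detect the STG blocks from the transformer configuration when none are provided and stg_scale is non-zero (block 28 when has_prompt_adaln is set, i.e. LTX-2.3; block 29 otherwise, i.e. LTX-2). Refactor the denoise_dev function to accept the new stg_scale and stg_blocks parameters and to apply the CFG rescale step whenever CFG or STG guidance is active, allowing for more flexible audio-visual integration during video generation.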
This commit is contained in:
@@ -58,8 +58,20 @@ AUDIO_MEL_BINS = 16
|
|||||||
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
|
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
|
||||||
|
|
||||||
# Default negative prompt for CFG (dev pipeline)
|
# Default negative prompt for CFG (dev pipeline)
|
||||||
# Matches PyTorch LTX-2 reference InferenceConfig default
|
# Matches PyTorch LTX-2 reference DEFAULT_NEGATIVE_PROMPT from constants.py
|
||||||
DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
DEFAULT_NEGATIVE_PROMPT = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_and_merge_lora(
|
def load_and_merge_lora(
|
||||||
@@ -564,8 +576,10 @@ def denoise_dev(
|
|||||||
use_apg: bool = False,
|
use_apg: bool = False,
|
||||||
apg_eta: float = 1.0,
|
apg_eta: float = 1.0,
|
||||||
apg_norm_threshold: float = 0.0,
|
apg_norm_threshold: float = 0.0,
|
||||||
|
stg_scale: float = 0.0,
|
||||||
|
stg_blocks: Optional[list] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Run denoising loop for dev pipeline with CFG or APG guidance.
|
"""Run denoising loop for dev pipeline with CFG/APG and optional STG guidance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
|
cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
|
||||||
@@ -576,6 +590,8 @@ def denoise_dev(
|
|||||||
for more stable I2V generation.
|
for more stable I2V generation.
|
||||||
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
||||||
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
||||||
|
stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled.
|
||||||
|
stg_blocks: Transformer block indices for STG perturbation.
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||||
|
|
||||||
@@ -590,6 +606,7 @@ def denoise_dev(
|
|||||||
|
|
||||||
sigmas_list = sigmas.tolist()
|
sigmas_list = sigmas.tolist()
|
||||||
use_cfg = cfg_scale != 1.0
|
use_cfg = cfg_scale != 1.0
|
||||||
|
use_stg = stg_scale != 0.0 and stg_blocks is not None
|
||||||
num_steps = len(sigmas_list) - 1
|
num_steps = len(sigmas_list) - 1
|
||||||
|
|
||||||
# Precompute RoPE once
|
# Precompute RoPE once
|
||||||
@@ -614,7 +631,10 @@ def denoise_dev(
|
|||||||
console=console,
|
console=console,
|
||||||
disable=not verbose,
|
disable=not verbose,
|
||||||
) as progress:
|
) as progress:
|
||||||
task = progress.add_task("[cyan]Denoising (CFG)[/]", total=num_steps)
|
passes = ["CFG"] if use_cfg else []
|
||||||
|
if use_stg: passes.append("STG")
|
||||||
|
label = "+".join(passes) if passes else "uncond"
|
||||||
|
task = progress.add_task(f"[cyan]Denoising ({label})[/]", total=num_steps)
|
||||||
|
|
||||||
for i in range(num_steps):
|
for i in range(num_steps):
|
||||||
sigma = sigmas_list[i]
|
sigma = sigmas_list[i]
|
||||||
@@ -656,6 +676,9 @@ def denoise_dev(
|
|||||||
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
||||||
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
|
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
|
||||||
|
|
||||||
|
# Start with positive prediction
|
||||||
|
x0_guided_f32 = x0_pos_f32
|
||||||
|
|
||||||
if use_cfg:
|
if use_cfg:
|
||||||
# Negative conditioning pass
|
# Negative conditioning pass
|
||||||
video_modality_neg = Modality(
|
video_modality_neg = Modality(
|
||||||
@@ -685,15 +708,24 @@ def denoise_dev(
|
|||||||
# Standard CFG
|
# Standard CFG
|
||||||
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
||||||
|
|
||||||
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation)
|
# STG pass: skip self-attention at specified blocks
|
||||||
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
|
if use_stg:
|
||||||
# pred = pred * factor
|
velocity_ptb, _ = transformer(
|
||||||
if cfg_rescale > 0.0:
|
video=video_modality_pos, audio=None,
|
||||||
v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8)
|
stg_video_blocks=stg_blocks,
|
||||||
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
)
|
||||||
x0_guided_f32 = x0_guided_f32 * v_factor
|
mx.eval(velocity_ptb)
|
||||||
else:
|
|
||||||
x0_guided_f32 = x0_pos_f32
|
x0_ptb_f32 = latents_flat_f32 - timesteps_f32 * velocity_ptb.astype(mx.float32)
|
||||||
|
x0_guided_f32 = x0_guided_f32 + stg_scale * (x0_pos_f32 - x0_ptb_f32)
|
||||||
|
|
||||||
|
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation)
|
||||||
|
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
|
||||||
|
# pred = pred * factor
|
||||||
|
if cfg_rescale > 0.0 and (use_cfg or use_stg):
|
||||||
|
v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8)
|
||||||
|
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
||||||
|
x0_guided_f32 = x0_guided_f32 * v_factor
|
||||||
|
|
||||||
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
|
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
|
||||||
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
|
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
|
||||||
@@ -1225,6 +1257,15 @@ def generate_video(
|
|||||||
|
|
||||||
console.print("[green]✓[/] Transformer loaded")
|
console.print("[green]✓[/] Transformer loaded")
|
||||||
|
|
||||||
|
# Auto-detect stg_blocks from transformer config if not explicitly provided.
|
||||||
|
# LTX-2.3 (has_prompt_adaln=True) uses block 28; LTX-2 uses block 29.
|
||||||
|
if stg_blocks is None and stg_scale != 0.0:
|
||||||
|
if transformer.config.has_prompt_adaln:
|
||||||
|
stg_blocks = [28]
|
||||||
|
else:
|
||||||
|
stg_blocks = [29]
|
||||||
|
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
# Pipeline-specific generation logic
|
# Pipeline-specific generation logic
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
@@ -1451,7 +1492,8 @@ def generate_video(
|
|||||||
transformer, sigmas, cfg_scale=cfg_scale,
|
transformer, sigmas, cfg_scale=cfg_scale,
|
||||||
cfg_rescale=cfg_rescale,
|
cfg_rescale=cfg_rescale,
|
||||||
verbose=verbose, state=video_state,
|
verbose=verbose, state=video_state,
|
||||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||||
|
stg_scale=stg_scale, stg_blocks=stg_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||||
@@ -1551,7 +1593,8 @@ def generate_video(
|
|||||||
transformer, sigmas, cfg_scale=cfg_scale,
|
transformer, sigmas, cfg_scale=cfg_scale,
|
||||||
cfg_rescale=cfg_rescale,
|
cfg_rescale=cfg_rescale,
|
||||||
verbose=verbose, state=state1,
|
verbose=verbose, state=state1,
|
||||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||||
|
stg_scale=stg_scale, stg_blocks=stg_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if audio and audio_latents is not None:
|
if audio and audio_latents is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user