From 5644492f7d71749daf683a8723f5714e2dd9e7ab Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Mar 2026 20:02:42 +0100 Subject: [PATCH] Update generate.py to enhance denoising functionality with optional Spatiotemporal Guidance (STG) support. Modify DEFAULT_NEGATIVE_PROMPT for improved clarity and detail. Implement auto-detection of STG blocks based on transformer configuration. Refactor denoise_dev function to incorporate STG parameters, allowing for more flexible audio-visual integration during video generation. --- mlx_video/generate.py | 73 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 15 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 8253b57..5a7e2fe 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -58,8 +58,20 @@ AUDIO_MEL_BINS = 16 AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 # Default negative prompt for CFG (dev pipeline) -# Matches PyTorch LTX-2 reference InferenceConfig default -DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted" +# Matches PyTorch LTX-2 reference DEFAULT_NEGATIVE_PROMPT from constants.py +DEFAULT_NEGATIVE_PROMPT = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) def load_and_merge_lora( @@ -564,8 +576,10 @@ def denoise_dev( use_apg: bool = False, apg_eta: float = 1.0, apg_norm_threshold: float = 0.0, + stg_scale: float = 0.0, + stg_blocks: Optional[list] = None, ) -> mx.array: - """Run denoising loop for dev pipeline with CFG or APG guidance. + """Run denoising loop for dev pipeline with CFG/APG and optional STG guidance. Args: cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction @@ -576,6 +590,8 @@ def denoise_dev( for more stable I2V generation. apg_eta: APG parallel component weight (1.0 = keep full parallel) apg_norm_threshold: APG guidance norm clamp (0 = no clamping) + stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled. + stg_blocks: Transformer block indices for STG perturbation. """ from mlx_video.models.ltx.rope import precompute_freqs_cis @@ -590,6 +606,7 @@ def denoise_dev( sigmas_list = sigmas.tolist() use_cfg = cfg_scale != 1.0 + use_stg = stg_scale != 0.0 and stg_blocks is not None num_steps = len(sigmas_list) - 1 # Precompute RoPE once @@ -614,7 +631,10 @@ def denoise_dev( console=console, disable=not verbose, ) as progress: - task = progress.add_task("[cyan]Denoising (CFG)[/]", total=num_steps) + passes = ["CFG"] if use_cfg else [] + if use_stg: passes.append("STG") + label = "+".join(passes) if passes else "uncond" + task = progress.add_task(f"[cyan]Denoising ({label})[/]", total=num_steps) for i in range(num_steps): sigma = sigmas_list[i] @@ -656,6 +676,9 @@ def denoise_dev( timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32) + # Start with positive prediction + x0_guided_f32 = x0_pos_f32 + if use_cfg: # Negative conditioning pass video_modality_neg = Modality( @@ -685,15 +708,24 @@ def denoise_dev( # Standard CFG x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) - # Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) - # factor = rescale * (cond_std / pred_std) + (1 - rescale) - # pred = pred * factor - if cfg_rescale > 0.0: - v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8) - v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) - x0_guided_f32 = x0_guided_f32 * v_factor - else: - x0_guided_f32 = x0_pos_f32 + # STG pass: skip self-attention at specified blocks + if use_stg: + velocity_ptb, _ = transformer( + video=video_modality_pos, audio=None, + stg_video_blocks=stg_blocks, + ) + mx.eval(velocity_ptb) + + x0_ptb_f32 = latents_flat_f32 - timesteps_f32 * velocity_ptb.astype(mx.float32) + x0_guided_f32 = x0_guided_f32 + stg_scale * (x0_pos_f32 - x0_ptb_f32) + + # Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) + # factor = rescale * (cond_std / pred_std) + (1 - rescale) + # pred = pred * factor + if cfg_rescale > 0.0 and (use_cfg or use_stg): + v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8) + v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) + x0_guided_f32 = x0_guided_f32 * v_factor # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) @@ -1225,6 +1257,15 @@ def generate_video( console.print("[green]✓[/] Transformer loaded") + # Auto-detect stg_blocks from transformer config if not explicitly provided. + # LTX-2.3 (has_prompt_adaln=True) uses block 28; LTX-2 uses block 29. + if stg_blocks is None and stg_scale != 0.0: + if transformer.config.has_prompt_adaln: + stg_blocks = [28] + else: + stg_blocks = [29] + console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]") + # ========================================================================== # Pipeline-specific generation logic # ========================================================================== @@ -1451,7 +1492,8 @@ def generate_video( transformer, sigmas, cfg_scale=cfg_scale, cfg_rescale=cfg_rescale, verbose=verbose, state=video_state, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, stg_blocks=stg_blocks, ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) @@ -1551,7 +1593,8 @@ def generate_video( transformer, sigmas, cfg_scale=cfg_scale, cfg_rescale=cfg_rescale, verbose=verbose, state=state1, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, stg_blocks=stg_blocks, ) if audio and audio_latents is not None: