Enhance README.md with new usage examples for STG and modality scale parameters in video generation. Update generate.py to support STG and modality guidance in the denoising process, allowing for improved audio-visual integration. Refactor attention mechanisms in the transformer to include options for skipping self-attention, facilitating STG perturbation and modality isolation. Update LTXModel and transformer block processing to accommodate new parameters for enhanced flexibility in model configurations.

This commit is contained in:
Prince Canuma
2026-03-14 10:26:12 +01:00
parent f346e09de4
commit 9cba2ea7cd
5 changed files with 200 additions and 78 deletions

View File

@@ -78,6 +78,10 @@ uv run mlx_video.generate --pipeline dev --prompt "Waves crashing" --image beach
```bash ```bash
uv run mlx_video.generate --prompt "Ocean waves crashing" --audio uv run mlx_video.generate --prompt "Ocean waves crashing" --audio
uv run mlx_video.generate --pipeline dev --prompt "A jazz band playing" --audio --enhance-prompt uv run mlx_video.generate --pipeline dev --prompt "A jazz band playing" --audio --enhance-prompt
# With full guidance (STG + modality_scale, matches PyTorch defaults)
uv run mlx_video.generate --pipeline dev --prompt "Ocean waves crashing" --audio \
--stg-scale 1.0 --stg-blocks 29 --modality-scale 3.0
``` ```
### LoRA ### LoRA
@@ -146,6 +150,9 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
| `--cfg-rescale` | 0.7 | CFG rescale factor (reduces over-saturation) | | `--cfg-rescale` | 0.7 | CFG rescale factor (reduces over-saturation) |
| `--negative-prompt` | (default) | Negative prompt for CFG | | `--negative-prompt` | (default) | Negative prompt for CFG |
| `--apg` | false | Use Adaptive Projected Guidance (more stable for I2V) | | `--apg` | false | Use Adaptive Projected Guidance (more stable for I2V) |
| `--stg-scale` | 0.0 | STG (Spatiotemporal Guidance) scale; 0.0 disables it (PyTorch default: 1.0, requires `--audio` and `--stg-blocks`) |
| `--stg-blocks` | None | Space-separated transformer block indices for STG (e.g. `29` for LTX-2, `28` for LTX-2.3); STG stays disabled unless this is set |
| `--modality-scale` | 1.0 | Cross-modal guidance scale; 1.0 disables it (PyTorch default: 3.0, requires `--audio`) |
**Dev-Two-Stage LoRA options:** **Dev-Two-Stage LoRA options:**

View File

@@ -715,22 +715,31 @@ def denoise_dev_av(
transformer: LTXModel, transformer: LTXModel,
sigmas: mx.array, sigmas: mx.array,
cfg_scale: float = 4.0, cfg_scale: float = 4.0,
audio_cfg_scale: float = 7.0,
cfg_rescale: float = 0.0, cfg_rescale: float = 0.0,
verbose: bool = True, verbose: bool = True,
video_state: Optional[LatentState] = None, video_state: Optional[LatentState] = None,
use_apg: bool = False, use_apg: bool = False,
apg_eta: float = 1.0, apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0, apg_norm_threshold: float = 0.0,
stg_scale: float = 0.0,
stg_video_blocks: Optional[list] = None,
stg_audio_blocks: Optional[list] = None,
modality_scale: float = 1.0,
) -> tuple[mx.array, mx.array]: ) -> tuple[mx.array, mx.array]:
"""Run denoising loop for dev pipeline with CFG/APG and audio. """Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio.
Args: Args:
cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result audio_cfg_scale: Separate CFG scale for audio (PyTorch default: 7.0).
towards the positive-only prediction, helping reduce artifacts. cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
Default 0.0 means no rescaling (standard CFG). variance to reduce artifacts. Default 0.0 means no rescaling.
use_apg: Use Adaptive Projected Guidance instead of standard CFG for video. use_apg: Use Adaptive Projected Guidance instead of standard CFG for video.
apg_eta: APG parallel component weight (1.0 = keep full parallel) apg_eta: APG parallel component weight (1.0 = keep full parallel)
apg_norm_threshold: APG guidance norm clamp (0 = no clamping) apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled.
stg_video_blocks: Transformer block indices for video STG perturbation.
stg_audio_blocks: Transformer block indices for audio STG perturbation.
modality_scale: Cross-modal guidance scale. 1.0 = disabled.
""" """
from mlx_video.models.ltx.rope import precompute_freqs_cis from mlx_video.models.ltx.rope import precompute_freqs_cis
@@ -738,14 +747,14 @@ def denoise_dev_av(
if video_state is not None: if video_state is not None:
video_latents = video_state.latent video_latents = video_state.latent
# Keep latents in float32 throughout the denoising loop to avoid # Keep latents in float32 throughout the denoising loop for precision.
# bfloat16 quantization noise accumulation over many steps.
# PyTorch keeps latents in float32; model input is cast to model dtype.
video_latents = video_latents.astype(mx.float32) video_latents = video_latents.astype(mx.float32)
audio_latents = audio_latents.astype(mx.float32) audio_latents = audio_latents.astype(mx.float32)
sigmas_list = sigmas.tolist() sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0 use_cfg = cfg_scale != 1.0
use_stg = stg_scale != 0.0 and stg_video_blocks is not None
use_modality = modality_scale != 1.0
num_steps = len(sigmas_list) - 1 num_steps = len(sigmas_list) - 1
# Precompute video RoPE # Precompute video RoPE
@@ -782,7 +791,11 @@ def denoise_dev_av(
console=console, console=console,
disable=not verbose, disable=not verbose,
) as progress: ) as progress:
task = progress.add_task("[cyan]Denoising A/V (CFG)[/]", total=num_steps) passes = ["CFG"] if use_cfg else []
if use_stg: passes.append("STG")
if use_modality: passes.append("Mod")
label = "+".join(passes) if passes else "uncond"
task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=num_steps)
for i in range(num_steps): for i in range(num_steps):
sigma = sigmas_list[i] sigma = sigmas_list[i]
@@ -827,7 +840,6 @@ def denoise_dev_av(
# This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity # This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant) # For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant)
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
# Use the float32 latents (not the bfloat16 model input) for precision
video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af)) audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af))
video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1) video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
@@ -836,8 +848,12 @@ def denoise_dev_av(
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32) video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32) audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
# Start with positive prediction
video_x0_guided_f32 = video_x0_pos_f32
audio_x0_guided_f32 = audio_x0_pos_f32
# Pass 2: CFG (negative conditioning)
if use_cfg: if use_cfg:
# Negative conditioning pass
video_modality_neg = Modality( video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions, latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_neg, context_mask=None, enabled=True, context=video_embeddings_neg, context_mask=None, enabled=True,
@@ -851,36 +867,54 @@ def denoise_dev_av(
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
mx.eval(video_vel_neg, audio_vel_neg) mx.eval(video_vel_neg, audio_vel_neg)
# Convert negative velocity to x0 using per-token timesteps
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32) video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32) audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
# Apply guidance to x0 (denoised) predictions
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no effect)
if use_apg: if use_apg:
# APG for video (more stable for I2V), standard CFG for audio
video_x0_guided_f32 = video_x0_pos_f32 + apg_delta( video_x0_guided_f32 = video_x0_pos_f32 + apg_delta(
video_x0_pos_f32, video_x0_neg_f32, cfg_scale, video_x0_pos_f32, video_x0_neg_f32, cfg_scale,
eta=apg_eta, norm_threshold=apg_norm_threshold eta=apg_eta, norm_threshold=apg_norm_threshold
) )
else: else:
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
# Always use standard CFG for audio audio_x0_guided_f32 = audio_x0_pos_f32 + (audio_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) # Pass 3: STG (self-attention perturbation at specified blocks)
# factor = rescale * (cond_std / pred_std) + (1 - rescale) if use_stg:
# pred = pred * factor video_vel_ptb, audio_vel_ptb = transformer(
if cfg_rescale > 0.0: video=video_modality_pos, audio=audio_modality_pos,
v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8) stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks,
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) )
video_x0_guided_f32 = video_x0_guided_f32 * v_factor mx.eval(video_vel_ptb, audio_vel_ptb)
a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8)
a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale) video_x0_ptb_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_ptb.astype(mx.float32)
audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor audio_x0_ptb_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_ptb.astype(mx.float32)
else:
video_x0_guided_f32 = video_x0_pos_f32 video_x0_guided_f32 = video_x0_guided_f32 + stg_scale * (video_x0_pos_f32 - video_x0_ptb_f32)
audio_x0_guided_f32 = audio_x0_pos_f32 audio_x0_guided_f32 = audio_x0_guided_f32 + stg_scale * (audio_x0_pos_f32 - audio_x0_ptb_f32)
# Pass 4: Modality isolation (skip all cross-modal attention)
if use_modality:
video_vel_iso, audio_vel_iso = transformer(
video=video_modality_pos, audio=audio_modality_pos,
skip_cross_modal=True,
)
mx.eval(video_vel_iso, audio_vel_iso)
video_x0_iso_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_iso.astype(mx.float32)
audio_x0_iso_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_iso.astype(mx.float32)
video_x0_guided_f32 = video_x0_guided_f32 + (modality_scale - 1.0) * (video_x0_pos_f32 - video_x0_iso_f32)
audio_x0_guided_f32 = audio_x0_guided_f32 + (modality_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_iso_f32)
# Apply CFG rescale (std-ratio rescaling to reduce over-saturation)
if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8)
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
video_x0_guided_f32 = video_x0_guided_f32 * v_factor
a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8)
a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale)
audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
@@ -898,8 +932,7 @@ def denoise_dev_av(
mx.eval(video_denoised_f32, audio_denoised_f32) mx.eval(video_denoised_f32, audio_denoised_f32)
# Euler step matching PyTorch: sample + velocity * dt # Euler step: sample + velocity * dt (float32)
# Latents stay in float32 throughout (matching PyTorch behavior)
if sigma_next > 0: if sigma_next > 0:
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
dt_f32 = sigma_next_f32 - sigma_f32 dt_f32 = sigma_next_f32 - sigma_f32
@@ -998,6 +1031,7 @@ def generate_video(
num_frames: int = 33, num_frames: int = 33,
num_inference_steps: int = 40, num_inference_steps: int = 40,
cfg_scale: float = 4.0, cfg_scale: float = 4.0,
audio_cfg_scale: float = 7.0,
cfg_rescale: float = 0.0, cfg_rescale: float = 0.0,
seed: int = 42, seed: int = 42,
fps: int = 24, fps: int = 24,
@@ -1017,6 +1051,9 @@ def generate_video(
use_apg: bool = False, use_apg: bool = False,
apg_eta: float = 1.0, apg_eta: float = 1.0,
apg_norm_threshold: float = 0.0, apg_norm_threshold: float = 0.0,
stg_scale: float = 0.0,
stg_blocks: Optional[list] = None,
modality_scale: float = 1.0,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
lora_strength: float = 1.0, lora_strength: float = 1.0,
): ):
@@ -1086,7 +1123,10 @@ def generate_video(
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]") audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else ""
stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else ""
mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else ""
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}{audio_cfg_info}, Rescale: {cfg_rescale}{stg_info}{mod_info}[/]")
if is_i2v: if is_i2v:
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
@@ -1268,10 +1308,6 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
stage1_audio_latents = audio_latents
state2 = None state2 = None
if is_i2v and stage2_image_latent is not None: if is_i2v and stage2_image_latent is not None:
state2 = LatentState( state2 = LatentState(
@@ -1299,13 +1335,20 @@ def generate_video(
latents = noise * noise_scale + latents * one_minus_scale latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents) mx.eval(latents)
# Stage 2 refines video only (no audio re-denoising) # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
latents, _ = denoise_distilled( if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
mx.eval(audio_latents)
# Joint video + audio refinement (no CFG, positive embeddings only)
latents, audio_latents = denoise_distilled(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2, verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings if audio else None,
) )
# Restore audio latents from stage 1
audio_latents = stage1_audio_latents
elif pipeline == PipelineType.DEV: elif pipeline == PipelineType.DEV:
# ====================================================================== # ======================================================================
@@ -1371,7 +1414,7 @@ def generate_video(
latents = mx.random.normal(video_latent_shape, dtype=model_dtype) latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
# Denoise with CFG/APG # Denoise with CFG/APG/STG/modality
if audio: if audio:
latents, audio_latents = denoise_dev_av( latents, audio_latents = denoise_dev_av(
latents, audio_latents, latents, audio_latents,
@@ -1379,8 +1422,11 @@ def generate_video(
video_embeddings_pos, video_embeddings_neg, video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state, cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
) )
else: else:
# Use original denoise_dev with computed sigmas # Use original denoise_dev with computed sigmas
@@ -1469,7 +1515,7 @@ def generate_video(
latents = mx.random.normal(stage1_shape, dtype=model_dtype) latents = mx.random.normal(stage1_shape, dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
# Run stage 1 with dev-style CFG denoising # Stage 1: Joint AV denoising at half resolution (matches PyTorch)
if audio: if audio:
latents, audio_latents = denoise_dev_av( latents, audio_latents = denoise_dev_av(
latents, audio_latents, latents, audio_latents,
@@ -1477,8 +1523,11 @@ def generate_video(
video_embeddings_pos, video_embeddings_neg, video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1, cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1,
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
) )
else: else:
latents = denoise_dev( latents = denoise_dev(
@@ -1490,6 +1539,9 @@ def generate_video(
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold
) )
if audio and audio_latents is not None:
mx.eval(audio_latents)
# Upsample latents 2x # Upsample latents 2x
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
@@ -1522,14 +1574,12 @@ def generate_video(
load_and_merge_lora(transformer, lora_path, strength=lora_strength) load_and_merge_lora(transformer, lora_path, strength=lora_strength)
# Stage 2: Distilled refinement at full resolution (no CFG) # Stage 2: Distilled refinement at full resolution (no CFG)
# Matches PyTorch: re-noise audio at sigma=0.909375, then jointly refine
# both video and audio through the distilled schedule using the LoRA-merged model.
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)") console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
stage1_audio_latents = audio_latents
state2 = None state2 = None
if is_i2v and stage2_image_latent is not None: if is_i2v and stage2_image_latent is not None:
state2 = LatentState( state2 = LatentState(
@@ -1557,13 +1607,20 @@ def generate_video(
latents = noise * noise_scale + latents * one_minus_scale latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents) mx.eval(latents)
# Stage 2 refines video only (no audio re-denoising) # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
latents, _ = denoise_distilled( if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
mx.eval(audio_latents)
# Joint video + audio refinement (no CFG, positive embeddings only)
latents, audio_latents = denoise_distilled(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2, verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos if audio else None,
) )
# Restore audio latents from stage 1
audio_latents = stage1_audio_latents
del transformer del transformer
mx.clear_cache() mx.clear_cache()
@@ -1685,6 +1742,7 @@ def generate_video(
mel_spectrogram = audio_decoder(audio_latents) mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram) mx.eval(mel_spectrogram)
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
audio_waveform = vocoder(mel_spectrogram) audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform) mx.eval(audio_waveform)
@@ -1771,7 +1829,8 @@ Examples:
parser.add_argument("--width", "-W", type=int, default=512, help="Output video width") parser.add_argument("--width", "-W", type=int, default=512, help="Output video width")
parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames") parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames")
parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)") parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)")
parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale (dev pipeline only, default 3.0)") parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale for video (dev pipeline only, default 3.0)")
parser.add_argument("--audio-cfg-scale", type=float, default=7.0, help="CFG guidance scale for audio (default 7.0, PyTorch default)")
parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)") parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)")
parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed") parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed")
parser.add_argument("--fps", type=int, default=24, help="Frames per second") parser.add_argument("--fps", type=int, default=24, help="Frames per second")
@@ -1795,6 +1854,9 @@ Examples:
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)")
parser.add_argument("--stg-scale", type=float, default=0.0, help="STG (Spatiotemporal Guidance) scale (default 0.0 = disabled, PyTorch default: 1.0)")
parser.add_argument("--stg-blocks", type=int, nargs="+", default=None, help="Transformer block indices for STG perturbation (default: [29] for LTX-2, [28] for LTX-2.3)")
parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)")
parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)") parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)")
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)") parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
args = parser.parse_args() args = parser.parse_args()
@@ -1817,6 +1879,7 @@ Examples:
num_frames=args.num_frames, num_frames=args.num_frames,
num_inference_steps=args.steps, num_inference_steps=args.steps,
cfg_scale=args.cfg_scale, cfg_scale=args.cfg_scale,
audio_cfg_scale=args.audio_cfg_scale,
cfg_rescale=args.cfg_rescale, cfg_rescale=args.cfg_rescale,
seed=args.seed, seed=args.seed,
fps=args.fps, fps=args.fps,
@@ -1836,6 +1899,9 @@ Examples:
use_apg=args.apg, use_apg=args.apg,
apg_eta=args.apg_eta, apg_eta=args.apg_eta,
apg_norm_threshold=args.apg_norm_threshold, apg_norm_threshold=args.apg_norm_threshold,
stg_scale=args.stg_scale,
stg_blocks=args.stg_blocks,
modality_scale=args.modality_scale,
lora_path=args.lora_path, lora_path=args.lora_path,
lora_strength=args.lora_strength, lora_strength=args.lora_strength,
) )

View File

@@ -101,6 +101,7 @@ class Attention(nn.Module):
mask: Optional[mx.array] = None, mask: Optional[mx.array] = None,
pe: Optional[Tuple[mx.array, mx.array]] = None, pe: Optional[Tuple[mx.array, mx.array]] = None,
k_pe: Optional[Tuple[mx.array, mx.array]] = None, k_pe: Optional[Tuple[mx.array, mx.array]] = None,
skip_attention: bool = False,
) -> mx.array: ) -> mx.array:
"""Forward pass. """Forward pass.
@@ -110,6 +111,8 @@ class Attention(nn.Module):
mask: Attention mask mask: Attention mask
pe: Position embeddings for query (and key if k_pe is None) pe: Position embeddings for query (and key if k_pe is None)
k_pe: Position embeddings for key (optional, uses pe if None) k_pe: Position embeddings for key (optional, uses pe if None)
skip_attention: If True, bypass Q*K*V attention and use value projection
only (for STG perturbation). Matches PyTorch all_perturbed=True.
Returns: Returns:
Attention output of shape (B, seq_len, query_dim) Attention output of shape (B, seq_len, query_dim)
@@ -119,24 +122,26 @@ class Attention(nn.Module):
if hasattr(self, "to_gate_logits"): if hasattr(self, "to_gate_logits"):
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads) gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
# Compute Q, K, V
q = self.to_q(x)
context = x if context is None else context context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
# Apply normalization if skip_attention:
q = self.q_norm(q) # STG: bypass Q*K*V attention, use value projection only
k = self.k_norm(k) out = v
else:
# Standard attention
q = self.to_q(x)
k = self.to_k(context)
# Apply rotary position embeddings q = self.q_norm(q)
if pe is not None: k = self.k_norm(k)
q = apply_rotary_emb(q, pe, self.rope_type)
k_pe_to_use = pe if k_pe is None else k_pe
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
# Compute attention if pe is not None:
out = scaled_dot_product_attention(q, k, v, self.heads, mask) q = apply_rotary_emb(q, pe, self.rope_type)
k_pe_to_use = pe if k_pe is None else k_pe
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
# Apply per-head gating # Apply per-head gating
if gate is not None: if gate is not None:

View File

@@ -453,10 +453,26 @@ class LTXModel(nn.Module):
self, self,
video: Optional[TransformerArgs], video: Optional[TransformerArgs],
audio: Optional[TransformerArgs], audio: Optional[TransformerArgs],
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks.""" """Process through all transformer blocks.
for block in self.transformer_blocks.values():
video, audio = block(video=video, audio=audio) Args:
stg_video_blocks: Block indices where video self-attention is skipped (STG).
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
"""
stg_v_set = set(stg_video_blocks) if stg_video_blocks else set()
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
for idx, block in self.transformer_blocks.items():
video, audio = block(
video=video, audio=audio,
skip_video_self_attn=(idx in stg_v_set),
skip_audio_self_attn=(idx in stg_a_set),
skip_cross_modal=skip_cross_modal,
)
return video, audio return video, audio
def _process_output( def _process_output(
@@ -490,8 +506,19 @@ class LTXModel(nn.Module):
self, self,
video: Optional[Modality] = None, video: Optional[Modality] = None,
audio: Optional[Modality] = None, audio: Optional[Modality] = None,
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[mx.array], Optional[mx.array]]: ) -> Tuple[Optional[mx.array], Optional[mx.array]]:
"""Forward pass.
Args:
video: Video modality input.
audio: Audio modality input.
stg_video_blocks: Block indices where video self-attention is skipped (STG).
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
"""
# Validate inputs # Validate inputs
if not self.model_type.is_video_enabled() and video is not None: if not self.model_type.is_video_enabled() and video is not None:
raise ValueError("Video is not enabled for this model") raise ValueError("Video is not enabled for this model")
@@ -506,6 +533,9 @@ class LTXModel(nn.Module):
video_out, audio_out = self._process_transformer_blocks( video_out, audio_out = self._process_transformer_blocks(
video=video_args, video=video_args,
audio=audio_args, audio=audio_args,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
) )
# Process outputs # Process outputs
@@ -603,9 +633,17 @@ class X0Model(nn.Module):
self, self,
video: Optional[Modality] = None, video: Optional[Modality] = None,
audio: Optional[Modality] = None, audio: Optional[Modality] = None,
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[mx.array], Optional[mx.array]]: ) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model(video, audio) vx, ax = self.velocity_model(
video, audio,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
)
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None

View File

@@ -234,12 +234,18 @@ class BasicAVTransformerBlock(nn.Module):
self, self,
video: Optional[TransformerArgs] = None, video: Optional[TransformerArgs] = None,
audio: Optional[TransformerArgs] = None, audio: Optional[TransformerArgs] = None,
skip_video_self_attn: bool = False,
skip_audio_self_attn: bool = False,
skip_cross_modal: bool = False,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Forward pass through transformer block. """Forward pass through transformer block.
Args: Args:
video: Video modality arguments video: Video modality arguments
audio: Audio modality arguments audio: Audio modality arguments
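skip_video_self_attn: Passed as skip_attention to the video self-attention (attn1) in this block (for STG perturbation)
skip_audio_self_attn: Passed as skip_attention to the audio self-attention (audio_attn1) in this block (for STG perturbation)
skip_cross_modal: Disable both A2V and V2A cross-modal attention in this block (for modality isolation)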
Returns: Returns:
Tuple of (updated_video, updated_audio) TransformerArgs Tuple of (updated_video, updated_audio) TransformerArgs
@@ -252,8 +258,8 @@ class BasicAVTransformerBlock(nn.Module):
# Check which modalities to run # Check which modalities to run
run_vx = video is not None and video.enabled and vx.size > 0 run_vx = video is not None and video.enabled and vx.size > 0
run_ax = audio is not None and audio.enabled and ax.size > 0 run_ax = audio is not None and audio.enabled and ax.size > 0
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
# Process video self-attention and cross-attention with text # Process video self-attention and cross-attention with text
if run_vx: if run_vx:
@@ -261,9 +267,9 @@ class BasicAVTransformerBlock(nn.Module):
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
) )
# Self-attention with RoPE # Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
# Cross-attention with text context # Cross-attention with text context
if self.has_prompt_adaln: if self.has_prompt_adaln:
@@ -290,9 +296,9 @@ class BasicAVTransformerBlock(nn.Module):
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
) )
# Self-attention with RoPE # Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
# Cross-attention with text context # Cross-attention with text context
if self.has_prompt_adaln: if self.has_prompt_adaln: