diff --git a/README.md b/README.md index 99b3f62..da7c7aa 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,10 @@ uv run mlx_video.generate --pipeline dev --prompt "Waves crashing" --image beach ```bash uv run mlx_video.generate --prompt "Ocean waves crashing" --audio uv run mlx_video.generate --pipeline dev --prompt "A jazz band playing" --audio --enhance-prompt + +# With full guidance (STG + modality_scale, matches PyTorch defaults) +uv run mlx_video.generate --pipeline dev --prompt "Ocean waves crashing" --audio \ + --stg-scale 1.0 --stg-blocks 29 --modality-scale 3.0 ``` ### LoRA @@ -146,6 +150,9 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom | `--cfg-rescale` | 0.7 | CFG rescale factor (reduces over-saturation) | | `--negative-prompt` | (default) | Negative prompt for CFG | | `--apg` | false | Use Adaptive Projected Guidance (more stable for I2V) | +| `--stg-scale` | 0.0 | STG scale (PyTorch default: 1.0, requires `--audio`) | +| `--stg-blocks` | None | Transformer blocks for STG ([29] for LTX-2, [28] for LTX-2.3) | +| `--modality-scale` | 1.0 | Cross-modal guidance scale (PyTorch default: 3.0, requires `--audio`) | **Dev-Two-Stage LoRA options:** diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 1f0d2e1..daa7ed0 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -715,22 +715,31 @@ def denoise_dev_av( transformer: LTXModel, sigmas: mx.array, cfg_scale: float = 4.0, + audio_cfg_scale: float = 7.0, cfg_rescale: float = 0.0, verbose: bool = True, video_state: Optional[LatentState] = None, use_apg: bool = False, apg_eta: float = 1.0, apg_norm_threshold: float = 0.0, + stg_scale: float = 0.0, + stg_video_blocks: Optional[list] = None, + stg_audio_blocks: Optional[list] = None, + modality_scale: float = 1.0, ) -> tuple[mx.array, mx.array]: - """Run denoising loop for dev pipeline with CFG/APG and audio. + """Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio. Args: - cfg_rescale: Rescale factor for CFG (0.0-1.0). Higher values blend the CFG result - towards the positive-only prediction, helping reduce artifacts. - Default 0.0 means no rescaling (standard CFG). + audio_cfg_scale: Separate CFG scale for audio (PyTorch default: 7.0). + cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction + variance to reduce artifacts. Default 0.0 means no rescaling. use_apg: Use Adaptive Projected Guidance instead of standard CFG for video. apg_eta: APG parallel component weight (1.0 = keep full parallel) apg_norm_threshold: APG guidance norm clamp (0 = no clamping) + stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled. + stg_video_blocks: Transformer block indices for video STG perturbation. + stg_audio_blocks: Transformer block indices for audio STG perturbation. + modality_scale: Cross-modal guidance scale. 1.0 = disabled. """ from mlx_video.models.ltx.rope import precompute_freqs_cis @@ -738,14 +747,14 @@ def denoise_dev_av( if video_state is not None: video_latents = video_state.latent - # Keep latents in float32 throughout the denoising loop to avoid - # bfloat16 quantization noise accumulation over many steps. - # PyTorch keeps latents in float32; model input is cast to model dtype. + # Keep latents in float32 throughout the denoising loop for precision. video_latents = video_latents.astype(mx.float32) audio_latents = audio_latents.astype(mx.float32) sigmas_list = sigmas.tolist() use_cfg = cfg_scale != 1.0 + use_stg = stg_scale != 0.0 and stg_video_blocks is not None + use_modality = modality_scale != 1.0 num_steps = len(sigmas_list) - 1 # Precompute video RoPE @@ -782,7 +791,11 @@ def denoise_dev_av( console=console, disable=not verbose, ) as progress: - task = progress.add_task("[cyan]Denoising A/V (CFG)[/]", total=num_steps) + passes = ["CFG"] if use_cfg else [] + if use_stg: passes.append("STG") + if use_modality: passes.append("Mod") + label = "+".join(passes) if passes else "uncond" + task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=num_steps) for i in range(num_steps): sigma = sigmas_list[i] @@ -827,7 +840,6 @@ def denoise_dev_av( # This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity # For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant) # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity - # Use the float32 latents (not the bfloat16 model input) for precision video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af)) video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1) @@ -836,8 +848,12 @@ def denoise_dev_av( video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32) audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32) + # Start with positive prediction + video_x0_guided_f32 = video_x0_pos_f32 + audio_x0_guided_f32 = audio_x0_pos_f32 + + # Pass 2: CFG (negative conditioning) if use_cfg: - # Negative conditioning pass video_modality_neg = Modality( latent=video_flat, timesteps=video_timesteps, positions=video_positions, context=video_embeddings_neg, context_mask=None, enabled=True, @@ -851,36 +867,54 @@ def denoise_dev_av( video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) mx.eval(video_vel_neg, audio_vel_neg) - # Convert negative velocity to x0 using per-token timesteps video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32) audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32) - # Apply guidance to x0 (denoised) predictions - # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 (no effect) if use_apg: - # APG for video (more stable for I2V), standard CFG for audio video_x0_guided_f32 = video_x0_pos_f32 + apg_delta( video_x0_pos_f32, video_x0_neg_f32, cfg_scale, eta=apg_eta, norm_threshold=apg_norm_threshold ) else: video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32) - # Always use standard CFG for audio - audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) + audio_x0_guided_f32 = audio_x0_pos_f32 + (audio_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) - # Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) - # factor = rescale * (cond_std / pred_std) + (1 - rescale) - # pred = pred * factor - if cfg_rescale > 0.0: - v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8) - v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) - video_x0_guided_f32 = video_x0_guided_f32 * v_factor - a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8) - a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale) - audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor - else: - video_x0_guided_f32 = video_x0_pos_f32 - audio_x0_guided_f32 = audio_x0_pos_f32 + # Pass 3: STG (self-attention perturbation at specified blocks) + if use_stg: + video_vel_ptb, audio_vel_ptb = transformer( + video=video_modality_pos, audio=audio_modality_pos, + stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks, + ) + mx.eval(video_vel_ptb, audio_vel_ptb) + + video_x0_ptb_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_ptb.astype(mx.float32) + audio_x0_ptb_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_ptb.astype(mx.float32) + + video_x0_guided_f32 = video_x0_guided_f32 + stg_scale * (video_x0_pos_f32 - video_x0_ptb_f32) + audio_x0_guided_f32 = audio_x0_guided_f32 + stg_scale * (audio_x0_pos_f32 - audio_x0_ptb_f32) + + # Pass 4: Modality isolation (skip all cross-modal attention) + if use_modality: + video_vel_iso, audio_vel_iso = transformer( + video=video_modality_pos, audio=audio_modality_pos, + skip_cross_modal=True, + ) + mx.eval(video_vel_iso, audio_vel_iso) + + video_x0_iso_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_iso.astype(mx.float32) + audio_x0_iso_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_iso.astype(mx.float32) + + video_x0_guided_f32 = video_x0_guided_f32 + (modality_scale - 1.0) * (video_x0_pos_f32 - video_x0_iso_f32) + audio_x0_guided_f32 = audio_x0_guided_f32 + (modality_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_iso_f32) + + # Apply CFG rescale (std-ratio rescaling to reduce over-saturation) + if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality): + v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8) + v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) + video_x0_guided_f32 = video_x0_guided_f32 * v_factor + a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8) + a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale) + audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) @@ -898,8 +932,7 @@ def denoise_dev_av( mx.eval(video_denoised_f32, audio_denoised_f32) - # Euler step matching PyTorch: sample + velocity * dt - # Latents stay in float32 throughout (matching PyTorch behavior) + # Euler step: sample + velocity * dt (float32) if sigma_next > 0: sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) dt_f32 = sigma_next_f32 - sigma_f32 @@ -998,6 +1031,7 @@ def generate_video( num_frames: int = 33, num_inference_steps: int = 40, cfg_scale: float = 4.0, + audio_cfg_scale: float = 7.0, cfg_rescale: float = 0.0, seed: int = 42, fps: int = 24, @@ -1017,6 +1051,9 @@ def generate_video( use_apg: bool = False, apg_eta: float = 1.0, apg_norm_threshold: float = 0.0, + stg_scale: float = 0.0, + stg_blocks: Optional[list] = None, + modality_scale: float = 1.0, lora_path: Optional[str] = None, lora_strength: float = 1.0, ): @@ -1086,7 +1123,10 @@ def generate_video( console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): - console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]") + audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else "" + stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else "" + mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else "" + console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}{audio_cfg_info}, Rescale: {cfg_rescale}{stg_info}{mod_info}[/]") if is_i2v: console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") @@ -1268,10 +1308,6 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) - # Save stage 1 audio latents — stage 2 only refines video (spatial upsampling). - # Audio is already fully denoised from stage 1; re-noising would destroy the signal. - stage1_audio_latents = audio_latents - state2 = None if is_i2v and stage2_image_latent is not None: state2 = LatentState( @@ -1299,13 +1335,20 @@ def generate_video( latents = noise * noise_scale + latents * one_minus_scale mx.eval(latents) - # Stage 2 refines video only (no audio re-denoising) - latents, _ = denoise_distilled( + # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) + audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) + mx.eval(audio_latents) + + # Joint video + audio refinement (no CFG, positive embeddings only) + latents, audio_latents = denoise_distilled( latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2, + audio_latents=audio_latents, audio_positions=audio_positions, + audio_embeddings=audio_embeddings if audio else None, ) - # Restore audio latents from stage 1 - audio_latents = stage1_audio_latents elif pipeline == PipelineType.DEV: # ====================================================================== @@ -1371,7 +1414,7 @@ def generate_video( latents = mx.random.normal(video_latent_shape, dtype=model_dtype) mx.eval(latents) - # Denoise with CFG/APG + # Denoise with CFG/APG/STG/modality if audio: latents, audio_latents = denoise_dev_av( latents, audio_latents, @@ -1379,8 +1422,11 @@ def generate_video( video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, + audio_cfg_scale=audio_cfg_scale, cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, modality_scale=modality_scale, ) else: # Use original denoise_dev with computed sigmas @@ -1469,7 +1515,7 @@ def generate_video( latents = mx.random.normal(stage1_shape, dtype=model_dtype) mx.eval(latents) - # Run stage 1 with dev-style CFG denoising + # Stage 1: Joint AV denoising at half resolution (matches PyTorch) if audio: latents, audio_latents = denoise_dev_av( latents, audio_latents, @@ -1477,8 +1523,11 @@ def generate_video( video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, + audio_cfg_scale=audio_cfg_scale, cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1, - use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, + stg_scale=stg_scale, stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, modality_scale=modality_scale, ) else: latents = denoise_dev( @@ -1490,6 +1539,9 @@ def generate_video( use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold ) + if audio and audio_latents is not None: + mx.eval(audio_latents) + # Upsample latents 2x with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) @@ -1522,14 +1574,12 @@ def generate_video( load_and_merge_lora(transformer, lora_path, strength=lora_strength) # Stage 2: Distilled refinement at full resolution (no CFG) + # Matches PyTorch: re-noise audio at sigma=0.909375, then jointly refine + # both video and audio through the distilled schedule using the LoRA-merged model. console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)") positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) - # Save stage 1 audio latents — stage 2 only refines video (spatial upsampling). - # Audio is already fully denoised from stage 1; re-noising would destroy the signal. - stage1_audio_latents = audio_latents - state2 = None if is_i2v and stage2_image_latent is not None: state2 = LatentState( @@ -1557,13 +1607,20 @@ def generate_video( latents = noise * noise_scale + latents * one_minus_scale mx.eval(latents) - # Stage 2 refines video only (no audio re-denoising) - latents, _ = denoise_distilled( + # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) + audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) + mx.eval(audio_latents) + + # Joint video + audio refinement (no CFG, positive embeddings only) + latents, audio_latents = denoise_distilled( latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2, + audio_latents=audio_latents, audio_positions=audio_positions, + audio_embeddings=audio_embeddings_pos if audio else None, ) - # Restore audio latents from stage 1 - audio_latents = stage1_audio_latents del transformer mx.clear_cache() @@ -1685,6 +1742,7 @@ def generate_video( mel_spectrogram = audio_decoder(audio_latents) mx.eval(mel_spectrogram) + console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]") audio_waveform = vocoder(mel_spectrogram) mx.eval(audio_waveform) @@ -1771,7 +1829,8 @@ Examples: parser.add_argument("--width", "-W", type=int, default=512, help="Output video width") parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames") parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)") - parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale (dev pipeline only, default 3.0)") + parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale for video (dev pipeline only, default 3.0)") + parser.add_argument("--audio-cfg-scale", type=float, default=7.0, help="CFG guidance scale for audio (default 7.0, PyTorch default)") parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)") parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed") parser.add_argument("--fps", type=int, default=24, help="Frames per second") @@ -1795,6 +1854,9 @@ Examples: parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") + parser.add_argument("--stg-scale", type=float, default=0.0, help="STG (Spatiotemporal Guidance) scale (default 0.0 = disabled, PyTorch default: 1.0)") + parser.add_argument("--stg-blocks", type=int, nargs="+", default=None, help="Transformer block indices for STG perturbation (default: [29] for LTX-2, [28] for LTX-2.3)") + parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)") parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)") parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)") args = parser.parse_args() @@ -1817,6 +1879,7 @@ Examples: num_frames=args.num_frames, num_inference_steps=args.steps, cfg_scale=args.cfg_scale, + audio_cfg_scale=args.audio_cfg_scale, cfg_rescale=args.cfg_rescale, seed=args.seed, fps=args.fps, @@ -1836,6 +1899,9 @@ Examples: use_apg=args.apg, apg_eta=args.apg_eta, apg_norm_threshold=args.apg_norm_threshold, + stg_scale=args.stg_scale, + stg_blocks=args.stg_blocks, + modality_scale=args.modality_scale, lora_path=args.lora_path, lora_strength=args.lora_strength, ) diff --git a/mlx_video/models/ltx/attention.py b/mlx_video/models/ltx/attention.py index ebc0a24..99e249c 100644 --- a/mlx_video/models/ltx/attention.py +++ b/mlx_video/models/ltx/attention.py @@ -101,6 +101,7 @@ class Attention(nn.Module): mask: Optional[mx.array] = None, pe: Optional[Tuple[mx.array, mx.array]] = None, k_pe: Optional[Tuple[mx.array, mx.array]] = None, + skip_attention: bool = False, ) -> mx.array: """Forward pass. @@ -110,6 +111,8 @@ class Attention(nn.Module): mask: Attention mask pe: Position embeddings for query (and key if k_pe is None) k_pe: Position embeddings for key (optional, uses pe if None) + skip_attention: If True, bypass Q*K*V attention and use value projection + only (for STG perturbation). Matches PyTorch all_perturbed=True. Returns: Attention output of shape (B, seq_len, query_dim) @@ -119,24 +122,26 @@ class Attention(nn.Module): if hasattr(self, "to_gate_logits"): gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads) - # Compute Q, K, V - q = self.to_q(x) context = x if context is None else context - k = self.to_k(context) v = self.to_v(context) - # Apply normalization - q = self.q_norm(q) - k = self.k_norm(k) + if skip_attention: + # STG: bypass Q*K*V attention, use value projection only + out = v + else: + # Standard attention + q = self.to_q(x) + k = self.to_k(context) - # Apply rotary position embeddings - if pe is not None: - q = apply_rotary_emb(q, pe, self.rope_type) - k_pe_to_use = pe if k_pe is None else k_pe - k = apply_rotary_emb(k, k_pe_to_use, self.rope_type) + q = self.q_norm(q) + k = self.k_norm(k) - # Compute attention - out = scaled_dot_product_attention(q, k, v, self.heads, mask) + if pe is not None: + q = apply_rotary_emb(q, pe, self.rope_type) + k_pe_to_use = pe if k_pe is None else k_pe + k = apply_rotary_emb(k, k_pe_to_use, self.rope_type) + + out = scaled_dot_product_attention(q, k, v, self.heads, mask) # Apply per-head gating if gate is not None: diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index 6a63d7b..527e523 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -453,10 +453,26 @@ class LTXModel(nn.Module): self, video: Optional[TransformerArgs], audio: Optional[TransformerArgs], + stg_video_blocks: Optional[List[int]] = None, + stg_audio_blocks: Optional[List[int]] = None, + skip_cross_modal: bool = False, ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: - """Process through all transformer blocks.""" - for block in self.transformer_blocks.values(): - video, audio = block(video=video, audio=audio) + """Process through all transformer blocks. + + Args: + stg_video_blocks: Block indices where video self-attention is skipped (STG). + stg_audio_blocks: Block indices where audio self-attention is skipped (STG). + skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation). + """ + stg_v_set = set(stg_video_blocks) if stg_video_blocks else set() + stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set() + for idx, block in self.transformer_blocks.items(): + video, audio = block( + video=video, audio=audio, + skip_video_self_attn=(idx in stg_v_set), + skip_audio_self_attn=(idx in stg_a_set), + skip_cross_modal=skip_cross_modal, + ) return video, audio def _process_output( @@ -490,8 +506,19 @@ class LTXModel(nn.Module): self, video: Optional[Modality] = None, audio: Optional[Modality] = None, + stg_video_blocks: Optional[List[int]] = None, + stg_audio_blocks: Optional[List[int]] = None, + skip_cross_modal: bool = False, ) -> Tuple[Optional[mx.array], Optional[mx.array]]: - + """Forward pass. + + Args: + video: Video modality input. + audio: Audio modality input. + stg_video_blocks: Block indices where video self-attention is skipped (STG). + stg_audio_blocks: Block indices where audio self-attention is skipped (STG). + skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation). + """ # Validate inputs if not self.model_type.is_video_enabled() and video is not None: raise ValueError("Video is not enabled for this model") @@ -506,6 +533,9 @@ class LTXModel(nn.Module): video_out, audio_out = self._process_transformer_blocks( video=video_args, audio=audio_args, + stg_video_blocks=stg_video_blocks, + stg_audio_blocks=stg_audio_blocks, + skip_cross_modal=skip_cross_modal, ) # Process outputs @@ -603,9 +633,17 @@ class X0Model(nn.Module): self, video: Optional[Modality] = None, audio: Optional[Modality] = None, + stg_video_blocks: Optional[List[int]] = None, + stg_audio_blocks: Optional[List[int]] = None, + skip_cross_modal: bool = False, ) -> Tuple[Optional[mx.array], Optional[mx.array]]: - - vx, ax = self.velocity_model(video, audio) + + vx, ax = self.velocity_model( + video, audio, + stg_video_blocks=stg_video_blocks, + stg_audio_blocks=stg_audio_blocks, + skip_cross_modal=skip_cross_modal, + ) denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None diff --git a/mlx_video/models/ltx/transformer.py b/mlx_video/models/ltx/transformer.py index 4b311e6..e4355b0 100644 --- a/mlx_video/models/ltx/transformer.py +++ b/mlx_video/models/ltx/transformer.py @@ -234,12 +234,18 @@ class BasicAVTransformerBlock(nn.Module): self, video: Optional[TransformerArgs] = None, audio: Optional[TransformerArgs] = None, + skip_video_self_attn: bool = False, + skip_audio_self_attn: bool = False, + skip_cross_modal: bool = False, ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: """Forward pass through transformer block. Args: video: Video modality arguments audio: Audio modality arguments + skip_video_self_attn: Skip video self-attention (for STG perturbation) + skip_audio_self_attn: Skip audio self-attention (for STG perturbation) + skip_cross_modal: Skip all cross-modal attention (for modality isolation) Returns: Tuple of (updated_video, updated_audio) TransformerArgs @@ -252,8 +258,8 @@ class BasicAVTransformerBlock(nn.Module): # Check which modalities to run run_vx = video is not None and video.enabled and vx.size > 0 run_ax = audio is not None and audio.enabled and ax.size > 0 - run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) - run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) + run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal + run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal # Process video self-attention and cross-attention with text if run_vx: @@ -261,9 +267,9 @@ class BasicAVTransformerBlock(nn.Module): self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) ) - # Self-attention with RoPE + # Self-attention with RoPE (skip_attention=True for STG perturbation) norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa - vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa + vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa # Cross-attention with text context if self.has_prompt_adaln: @@ -290,9 +296,9 @@ class BasicAVTransformerBlock(nn.Module): self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) ) - # Self-attention with RoPE + # Self-attention with RoPE (skip_attention=True for STG perturbation) norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa - ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa + ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa # Cross-attention with text context if self.has_prompt_adaln: