Enhance README.md with new usage examples for STG and modality scale parameters in video generation. Update generate.py to support STG and modality guidance in the denoising process, allowing for improved audio-visual integration. Refactor attention mechanisms in the transformer to include options for skipping self-attention, facilitating STG perturbation and modality isolation. Update LTXModel and transformer block processing to accommodate new parameters for enhanced flexibility in model configurations.

This commit is contained in:
Prince Canuma
2026-03-14 10:26:12 +01:00
parent f346e09de4
commit 9cba2ea7cd
5 changed files with 200 additions and 78 deletions

View File

@@ -234,12 +234,18 @@ class BasicAVTransformerBlock(nn.Module):
self,
video: Optional[TransformerArgs] = None,
audio: Optional[TransformerArgs] = None,
skip_video_self_attn: bool = False,
skip_audio_self_attn: bool = False,
skip_cross_modal: bool = False,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Forward pass through transformer block.
Args:
video: Video modality arguments
audio: Audio modality arguments
skip_video_self_attn: Skip video self-attention (for STG perturbation)
skip_audio_self_attn: Skip audio self-attention (for STG perturbation)
skip_cross_modal: Skip all cross-modal attention (for modality isolation)
Returns:
Tuple of (updated_video, updated_audio) TransformerArgs
@@ -252,8 +258,8 @@ class BasicAVTransformerBlock(nn.Module):
# Check which modalities to run
run_vx = video is not None and video.enabled and vx.size > 0
run_ax = audio is not None and audio.enabled and ax.size > 0
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
# Process video self-attention and cross-attention with text
if run_vx:
@@ -261,9 +267,9 @@ class BasicAVTransformerBlock(nn.Module):
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
)
# Self-attention with RoPE
# Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
# Cross-attention with text context
if self.has_prompt_adaln:
@@ -290,9 +296,9 @@ class BasicAVTransformerBlock(nn.Module):
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
)
# Self-attention with RoPE
# Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
# Cross-attention with text context
if self.has_prompt_adaln: