Enhance README.md with new usage examples for STG and modality scale parameters in video generation. Update generate.py to support STG and modality guidance in the denoising process, allowing for improved audio-visual integration. Refactor attention mechanisms in the transformer to include options for skipping self-attention, facilitating STG perturbation and modality isolation. Update LTXModel and transformer block processing to accommodate new parameters for enhanced flexibility in model configurations.
This commit is contained in:
@@ -234,12 +234,18 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self,
|
||||
video: Optional[TransformerArgs] = None,
|
||||
audio: Optional[TransformerArgs] = None,
|
||||
skip_video_self_attn: bool = False,
|
||||
skip_audio_self_attn: bool = False,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Forward pass through transformer block.
|
||||
|
||||
Args:
|
||||
video: Video modality arguments
|
||||
audio: Audio modality arguments
|
||||
skip_video_self_attn: Skip video self-attention (for STG perturbation)
|
||||
skip_audio_self_attn: Skip audio self-attention (for STG perturbation)
|
||||
skip_cross_modal: Skip all cross-modal attention (for modality isolation)
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_video, updated_audio) TransformerArgs
|
||||
@@ -252,8 +258,8 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
# Check which modalities to run
|
||||
run_vx = video is not None and video.enabled and vx.size > 0
|
||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
|
||||
|
||||
# Process video self-attention and cross-attention with text
|
||||
if run_vx:
|
||||
@@ -261,9 +267,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
@@ -290,9 +296,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
|
||||
Reference in New Issue
Block a user