Enhance README.md with new usage examples for STG and modality scale parameters in video generation. Update generate.py to support STG and modality guidance in the denoising process, allowing for improved audio-visual integration. Refactor attention mechanisms in the transformer to include options for skipping self-attention, facilitating STG perturbation and modality isolation. Update LTXModel and transformer block processing to accommodate new parameters for enhanced flexibility in model configurations.
This commit is contained in:
@@ -101,6 +101,7 @@ class Attention(nn.Module):
|
||||
mask: Optional[mx.array] = None,
|
||||
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
skip_attention: bool = False,
|
||||
) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
@@ -110,6 +111,8 @@ class Attention(nn.Module):
|
||||
mask: Attention mask
|
||||
pe: Position embeddings for query (and key if k_pe is None)
|
||||
k_pe: Position embeddings for key (optional, uses pe if None)
|
||||
skip_attention: If True, bypass Q*K*V attention and use value projection
|
||||
only (for STG perturbation). Matches PyTorch all_perturbed=True.
|
||||
|
||||
Returns:
|
||||
Attention output of shape (B, seq_len, query_dim)
|
||||
@@ -119,24 +122,26 @@ class Attention(nn.Module):
|
||||
if hasattr(self, "to_gate_logits"):
|
||||
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
|
||||
|
||||
# Compute Q, K, V
|
||||
q = self.to_q(x)
|
||||
context = x if context is None else context
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
# Apply normalization
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if skip_attention:
|
||||
# STG: bypass Q*K*V attention, use value projection only
|
||||
out = v
|
||||
else:
|
||||
# Standard attention
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(context)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe, self.rope_type)
|
||||
k_pe_to_use = pe if k_pe is None else k_pe
|
||||
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Compute attention
|
||||
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe, self.rope_type)
|
||||
k_pe_to_use = pe if k_pe is None else k_pe
|
||||
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
|
||||
|
||||
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||
|
||||
# Apply per-head gating
|
||||
if gate is not None:
|
||||
|
||||
@@ -453,10 +453,26 @@ class LTXModel(nn.Module):
|
||||
self,
|
||||
video: Optional[TransformerArgs],
|
||||
audio: Optional[TransformerArgs],
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Process through all transformer blocks."""
|
||||
for block in self.transformer_blocks.values():
|
||||
video, audio = block(video=video, audio=audio)
|
||||
"""Process through all transformer blocks.
|
||||
|
||||
Args:
|
||||
stg_video_blocks: Block indices where video self-attention is skipped (STG).
|
||||
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
|
||||
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
|
||||
"""
|
||||
stg_v_set = set(stg_video_blocks) if stg_video_blocks else set()
|
||||
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
|
||||
for idx, block in self.transformer_blocks.items():
|
||||
video, audio = block(
|
||||
video=video, audio=audio,
|
||||
skip_video_self_attn=(idx in stg_v_set),
|
||||
skip_audio_self_attn=(idx in stg_a_set),
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
return video, audio
|
||||
|
||||
def _process_output(
|
||||
@@ -490,8 +506,19 @@ class LTXModel(nn.Module):
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
video: Video modality input.
|
||||
audio: Audio modality input.
|
||||
stg_video_blocks: Block indices where video self-attention is skipped (STG).
|
||||
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
|
||||
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
|
||||
"""
|
||||
# Validate inputs
|
||||
if not self.model_type.is_video_enabled() and video is not None:
|
||||
raise ValueError("Video is not enabled for this model")
|
||||
@@ -506,6 +533,9 @@ class LTXModel(nn.Module):
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
video=video_args,
|
||||
audio=audio_args,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
# Process outputs
|
||||
@@ -603,9 +633,17 @@ class X0Model(nn.Module):
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
vx, ax = self.velocity_model(video, audio)
|
||||
|
||||
vx, ax = self.velocity_model(
|
||||
video, audio,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
|
||||
@@ -234,12 +234,18 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self,
|
||||
video: Optional[TransformerArgs] = None,
|
||||
audio: Optional[TransformerArgs] = None,
|
||||
skip_video_self_attn: bool = False,
|
||||
skip_audio_self_attn: bool = False,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Forward pass through transformer block.
|
||||
|
||||
Args:
|
||||
video: Video modality arguments
|
||||
audio: Audio modality arguments
|
||||
skip_video_self_attn: Skip video self-attention (for STG perturbation)
|
||||
skip_audio_self_attn: Skip audio self-attention (for STG perturbation)
|
||||
skip_cross_modal: Skip all cross-modal attention (for modality isolation)
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_video, updated_audio) TransformerArgs
|
||||
@@ -252,8 +258,8 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
# Check which modalities to run
|
||||
run_vx = video is not None and video.enabled and vx.size > 0
|
||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
|
||||
|
||||
# Process video self-attention and cross-attention with text
|
||||
if run_vx:
|
||||
@@ -261,9 +267,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
@@ -290,9 +296,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
|
||||
Reference in New Issue
Block a user