Enhance README.md with new usage examples for STG and modality scale parameters in video generation. Update generate.py to support STG and modality guidance in the denoising process, allowing for improved audio-visual integration. Refactor attention mechanisms in the transformer to include options for skipping self-attention, facilitating STG perturbation and modality isolation. Update LTXModel and transformer block processing to accommodate new parameters for enhanced flexibility in model configurations.

This commit is contained in:
Prince Canuma
2026-03-14 10:26:12 +01:00
parent f346e09de4
commit 9cba2ea7cd
5 changed files with 200 additions and 78 deletions

View File

@@ -453,10 +453,26 @@ class LTXModel(nn.Module):
self,
video: Optional[TransformerArgs],
audio: Optional[TransformerArgs],
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks."""
for block in self.transformer_blocks.values():
video, audio = block(video=video, audio=audio)
"""Process through all transformer blocks.
Args:
stg_video_blocks: Block indices where video self-attention is skipped (STG).
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
"""
stg_v_set = set(stg_video_blocks) if stg_video_blocks else set()
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
for idx, block in self.transformer_blocks.items():
video, audio = block(
video=video, audio=audio,
skip_video_self_attn=(idx in stg_v_set),
skip_audio_self_attn=(idx in stg_a_set),
skip_cross_modal=skip_cross_modal,
)
return video, audio
def _process_output(
@@ -490,8 +506,19 @@ class LTXModel(nn.Module):
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
"""Forward pass.
Args:
video: Video modality input.
audio: Audio modality input.
stg_video_blocks: Block indices where video self-attention is skipped (STG).
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
"""
# Validate inputs
if not self.model_type.is_video_enabled() and video is not None:
raise ValueError("Video is not enabled for this model")
@@ -506,6 +533,9 @@ class LTXModel(nn.Module):
video_out, audio_out = self._process_transformer_blocks(
video=video_args,
audio=audio_args,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
)
# Process outputs
@@ -603,9 +633,17 @@ class X0Model(nn.Module):
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model(video, audio)
vx, ax = self.velocity_model(
video, audio,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
)
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None