Files
mlx-video/mlx_video/models/ltx/transformer.py
2026-01-11 23:48:33 +01:00

360 lines
13 KiB
Python

from dataclasses import dataclass, replace
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx.attention import Attention
from mlx_video.models.ltx.feed_forward import FeedForward
from mlx_video.utils import rms_norm
@dataclass(frozen=True)
class Modality:
    """Immutable bundle of the per-modality inputs.

    NOTE(review): nothing in this part of the file reads a ``Modality``, so
    the field notes below are inferred from names only -- confirm against
    the code that constructs and consumes it.
    """

    # Presumably the latent tokens of this modality -- TODO confirm shape.
    latent: mx.array
    # Presumably the diffusion timesteps for ``latent`` -- TODO confirm.
    timesteps: mx.array
    # Presumably the token positions used to build RoPE -- TODO confirm.
    positions: mx.array
    # Presumably the conditioning (text) context -- TODO confirm.
    context: mx.array
    # Defaults to True.
    enabled: bool = True
    # Optional mask accompanying ``context``; defaults to None.
    context_mask: Optional[mx.array] = None
@dataclass(frozen=True)
class TransformerArgs:
    """Immutable per-modality state passed through the transformer blocks.

    ``BasicAVTransformerBlock.__call__`` reads these fields and returns a
    copy (via ``dataclasses.replace``) in which only ``x`` is updated.
    """

    # Hidden states; the block adds its residual updates to this tensor.
    x: mx.array
    # Context passed to the text cross-attention (``attn2``/``audio_attn2``).
    context: mx.array
    # Mask passed together with ``context`` to the text cross-attention.
    context_mask: Optional[mx.array]
    # Timestep modulation; the block reshapes it to (B, seq, 6, dim) and adds
    # it to its 6-row scale-shift table.
    timesteps: mx.array
    # Not read by BasicAVTransformerBlock; carried through unchanged.
    embedded_timestep: mx.array
    # Passed as ``pe`` to the self-attention layers.
    positional_embeddings: Tuple[mx.array, mx.array]
    # Passed as ``pe``/``k_pe`` to the audio<->video cross-attention layers.
    cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
    # Timestep modulation for the cross-modal scale/shift values.
    cross_scale_shift_timestep: Optional[mx.array]
    # Timestep modulation for the cross-modal gate value.
    cross_gate_timestep: Optional[mx.array]
    # When False the block leaves this modality's ``x`` untouched.
    enabled: bool
class BasicAVTransformerBlock(nn.Module):
    """Audio-Video transformer block with cross-modal attention.

    Supports video-only, audio-only, or combined audio-video processing
    with bidirectional cross-attention between modalities.

    Per enabled modality the block applies, in order: gated self-attention
    with adaptive (timestep-modulated) RMS norm, text cross-attention,
    optional audio<->video cross-attention, and a gated feed-forward layer.
    Every sub-layer is added to the hidden states as a residual.
    """

    def __init__(
        self,
        idx: int,
        video: Optional[TransformerConfig] = None,
        audio: Optional[TransformerConfig] = None,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        norm_eps: float = 1e-6,
    ):
        """Initialize transformer block.

        Sub-modules are only created for the modalities whose configuration
        is given; the cross-modal layers require both.

        Args:
            idx: Block index
            video: Video modality configuration
            audio: Audio modality configuration
            rope_type: Type of rotary position embedding
            norm_eps: Epsilon for normalization
        """
        super().__init__()
        self.idx = idx
        self.norm_eps = norm_eps

        # Video components
        if video is not None:
            self.attn1 = Attention(
                query_dim=video.dim,
                heads=video.heads,
                dim_head=video.d_head,
                context_dim=None,  # Self-attention
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.attn2 = Attention(
                query_dim=video.dim,
                context_dim=video.context_dim,
                heads=video.heads,
                dim_head=video.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.ff = FeedForward(video.dim, dim_out=video.dim)
            # 6 scale-shift parameters: 3 for attention, 3 for MLP
            self.scale_shift_table = mx.zeros((6, video.dim))

        # Audio components
        if audio is not None:
            self.audio_attn1 = Attention(
                query_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                context_dim=None,  # Self-attention
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.audio_attn2 = Attention(
                query_dim=audio.dim,
                context_dim=audio.context_dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
            self.audio_scale_shift_table = mx.zeros((6, audio.dim))

        # Cross-modal attention (when both video and audio are enabled)
        if audio is not None and video is not None:
            # Audio-to-Video: Q from video, K/V from audio
            self.audio_to_video_attn = Attention(
                query_dim=video.dim,
                context_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            # Video-to-Audio: Q from audio, K/V from video
            self.video_to_audio_attn = Attention(
                query_dim=audio.dim,
                context_dim=video.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            # Scale-shift tables for cross-attention (4 scale/shift + 1 gate)
            self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
            self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))

    def get_ada_values(
        self,
        scale_shift_table: mx.array,
        batch_size: int,
        timestep: mx.array,
        indices: slice,
    ) -> Tuple[mx.array, ...]:
        """Get adaptive normalization values from scale-shift table.

        Args:
            scale_shift_table: Table of shape (num_params, dim)
            batch_size: Batch size
            timestep: Timestep embeddings; reshaped here to
                (batch_size, timestep.shape[1], num_params, -1)
            indices: Slice selecting which of the num_params rows to return

        Returns:
            Tuple with one tensor per selected row, each of shape
            (batch_size, timestep.shape[1], dim)
        """
        num_ada_params = scale_shift_table.shape[0]

        # (num_selected, dim) -> (1, 1, num_selected, dim) so it broadcasts
        # over the batch and sequence axes of the timestep embedding.
        table_slice = scale_shift_table[indices]
        table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)

        # Split the flat last axis into one chunk per table row.
        timestep_reshaped = mx.reshape(
            timestep,
            (batch_size, timestep.shape[1], num_ada_params, -1),
        )
        timestep_slice = timestep_reshaped[:, :, indices, :]

        ada_values = table_expanded + timestep_slice

        # Unbind along the parameter axis.
        num_sliced = ada_values.shape[2]
        return tuple(ada_values[:, :, i, :] for i in range(num_sliced))

    def get_av_ca_ada_values(
        self,
        scale_shift_table: mx.array,
        batch_size: int,
        scale_shift_timestep: mx.array,
        gate_timestep: mx.array,
        num_scale_shift_values: int = 4,
    ) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
        """Get adaptive values for cross-modal attention.

        The first ``num_scale_shift_values`` rows of the table are combined
        with ``scale_shift_timestep``; the remaining rows are combined with
        ``gate_timestep``.

        Args:
            scale_shift_table: Table whose leading rows hold the scale-shift
                parameters and whose trailing rows hold the gate parameters
            batch_size: Batch size
            scale_shift_timestep: Timestep for scale-shift
            gate_timestep: Timestep for gating
            num_scale_shift_values: Number of scale-shift values (default 4)

        Returns:
            The scale-shift tensors followed by the gate tensors. With the
            default arguments and a 5-row table this is the 5-tuple that
            ``__call__`` unpacks as (scale_a2v, shift_a2v, scale_v2a,
            shift_v2a, gate). Each tensor has shape (B, seq, dim).
        """
        scale_shift_ada = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :],
            batch_size,
            scale_shift_timestep,
            slice(None, None),
        )
        gate_ada = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :],
            batch_size,
            gate_timestep,
            slice(None, None),
        )

        # Keep the sequence axis even when it has length 1. A (B, 1, dim)
        # value broadcasts correctly against (B, T, dim) activations for any
        # batch size, whereas a squeezed (B, dim) value would be right-aligned
        # as (1, B, dim) and pair the batch axis with the token axis -- which
        # only works when B == 1.
        return (*scale_shift_ada, *gate_ada)

    def __call__(
        self,
        video: Optional[TransformerArgs] = None,
        audio: Optional[TransformerArgs] = None,
    ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
        """Forward pass through transformer block.

        A modality is processed only if its arguments are given, it is
        enabled, and its hidden states are non-empty. Cross-modal attention
        runs only when both modalities are processed.

        Args:
            video: Video modality arguments
            audio: Audio modality arguments

        Returns:
            Tuple of (updated_video, updated_audio) TransformerArgs; an entry
            is None when the corresponding input was None.
        """
        vx = video.x if video is not None else None
        ax = audio.x if audio is not None else None

        # Check which modalities to run
        run_vx = video is not None and video.enabled and vx.size > 0
        run_ax = audio is not None and audio.enabled and ax.size > 0
        # Both cross-modal directions need both modalities to be active.
        run_cross = run_vx and run_ax

        # Process video self-attention and cross-attention with text
        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
            )
            # Self-attention with RoPE
            norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
            vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
            # Cross-attention with text context
            vx = vx + self.attn2(
                rms_norm(vx, eps=self.norm_eps),
                context=video.context,
                mask=video.context_mask,
            )

        # Process audio self-attention and cross-attention with text
        if run_ax:
            ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
            )
            # Self-attention with RoPE
            norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
            ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
            # Cross-attention with text context
            ax = ax + self.audio_attn2(
                rms_norm(ax, eps=self.norm_eps),
                context=audio.context,
                mask=audio.context_mask,
            )

        # Audio-Video cross-modal attention
        if run_cross:
            # Both directions use these norms, taken BEFORE either direction
            # updates its target, so video-to-audio below does not see the
            # audio-to-video update.
            vx_norm3 = rms_norm(vx, eps=self.norm_eps)
            ax_norm3 = rms_norm(ax, eps=self.norm_eps)

            # Get adaptive values for audio cross-attention
            (
                scale_ca_audio_a2v,
                shift_ca_audio_a2v,
                scale_ca_audio_v2a,
                shift_ca_audio_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                audio.cross_scale_shift_timestep,
                audio.cross_gate_timestep,
            )
            # Get adaptive values for video cross-attention
            (
                scale_ca_video_a2v,
                shift_ca_video_a2v,
                scale_ca_video_v2a,
                shift_ca_video_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                video.cross_scale_shift_timestep,
                video.cross_gate_timestep,
            )

            # Audio-to-Video cross-attention: Q from video, K/V from audio
            vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
            ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
            vx = vx + (
                self.audio_to_video_attn(
                    vx_scaled,
                    context=ax_scaled,
                    pe=video.cross_positional_embeddings,
                    k_pe=audio.cross_positional_embeddings,
                )
                * gate_out_a2v
            )

            # Video-to-Audio cross-attention: Q from audio, K/V from video
            ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
            vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
            ax = ax + (
                self.video_to_audio_attn(
                    ax_scaled,
                    context=vx_scaled,
                    pe=audio.cross_positional_embeddings,
                    k_pe=video.cross_positional_embeddings,
                )
                * gate_out_v2a
            )

        # Process video feed-forward
        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
            )
            vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
            vx = vx + self.ff(vx_scaled) * vgate_mlp

        # Process audio feed-forward
        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
            )
            ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
            ax = ax + self.audio_ff(ax_scaled) * agate_mlp

        # Return updated TransformerArgs (frozen dataclasses, so copy with new x)
        video_out = replace(video, x=vx) if video is not None else None
        audio_out = replace(audio, x=ax) if audio is not None else None
        return video_out, audio_out