360 lines
13 KiB
Python
360 lines
13 KiB
Python
from dataclasses import dataclass, replace
|
|
from typing import Optional, Tuple
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
|
|
from mlx_video.models.ltx.attention import Attention
|
|
from mlx_video.models.ltx.feed_forward import FeedForward
|
|
from mlx_video.utils import rms_norm
|
|
|
|
|
|
@dataclass(frozen=True)
class Modality:
    """Immutable bundle of inputs describing one modality (e.g. video or audio).

    Being a frozen dataclass, instances cannot be mutated after construction;
    use ``dataclasses.replace`` to derive a modified copy. The first four fields
    are required (in this positional order); ``enabled`` and ``context_mask``
    are optional.
    """

    # Latent array for this modality (shape not fixed here -- see callers).
    latent: mx.array
    # Timestep values paired with ``latent``.
    timesteps: mx.array
    # Positional information for the latent entries.
    positions: mx.array
    # Conditioning context array (presumably text embeddings -- confirm with callers).
    context: mx.array
    # Flag marking whether this modality is active; True unless stated otherwise.
    enabled: bool = True
    # Optional mask over ``context``; defaults to None.
    context_mask: Optional[mx.array] = None
|
|
|
|
|
|
@dataclass(frozen=True)
class TransformerArgs:
    """Immutable per-modality state threaded through ``BasicAVTransformerBlock``.

    The block returns a copy of this record (via ``dataclasses.replace``) with
    only ``x`` updated; every other field is passed through unchanged. All
    fields are required and positional, in the order declared below.
    """

    # Hidden states that the transformer block reads and updates.
    x: mx.array
    # Context used as keys/values for the block's text cross-attention.
    context: mx.array
    # Optional mask passed to that cross-attention as ``mask``.
    context_mask: Optional[mx.array]
    # AdaLN timestep embedding reshaped against a 6-row scale-shift table
    # (shift/scale/gate for attention, then for the feed-forward).
    timesteps: mx.array
    # Embedded timestep; not read by the block in this module.
    embedded_timestep: mx.array
    # Rotary embeddings used by self-attention (``pe=``).
    positional_embeddings: Tuple[mx.array, mx.array]
    # Rotary embeddings used by audio<->video cross-attention (``pe=`` / ``k_pe=``).
    cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
    # Timestep embedding for the cross-modal scale/shift values.
    cross_scale_shift_timestep: Optional[mx.array]
    # Timestep embedding for the cross-modal output gate.
    cross_gate_timestep: Optional[mx.array]
    # Whether the block should process this modality at all.
    enabled: bool
|
|
|
|
|
|
class BasicAVTransformerBlock(nn.Module):
    """Audio-Video transformer block with cross-modal attention.

    Supports video-only, audio-only, or combined audio-video processing
    with bidirectional cross-attention between modalities. For each active
    modality the forward pass applies, in order:

    1. AdaLN-modulated self-attention (with RoPE),
    2. cross-attention to the modality's own context,
    3. audio<->video cross-attention (only when both modalities are active),
    4. AdaLN-modulated feed-forward.
    """

    def __init__(
        self,
        idx: int,
        video: Optional[TransformerConfig] = None,
        audio: Optional[TransformerConfig] = None,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        norm_eps: float = 1e-6,
    ):
        """Initialize transformer block.

        Args:
            idx: Block index (stored as ``self.idx``).
            video: Video modality configuration; ``None`` creates no video layers.
            audio: Audio modality configuration; ``None`` creates no audio layers.
            rope_type: Type of rotary position embedding used by every attention layer.
            norm_eps: Epsilon for RMS normalization.
        """
        super().__init__()

        self.idx = idx
        self.norm_eps = norm_eps

        # Video components
        if video is not None:
            self.attn1 = Attention(
                query_dim=video.dim,
                heads=video.heads,
                dim_head=video.d_head,
                context_dim=None,  # Self-attention
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.attn2 = Attention(
                query_dim=video.dim,
                context_dim=video.context_dim,
                heads=video.heads,
                dim_head=video.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.ff = FeedForward(video.dim, dim_out=video.dim)
            # 6 scale-shift parameters: (shift, scale, gate) for attention,
            # then (shift, scale, gate) for the MLP.
            self.scale_shift_table = mx.zeros((6, video.dim))

        # Audio components
        if audio is not None:
            self.audio_attn1 = Attention(
                query_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                context_dim=None,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.audio_attn2 = Attention(
                query_dim=audio.dim,
                context_dim=audio.context_dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
            self.audio_scale_shift_table = mx.zeros((6, audio.dim))

        # Cross-modal attention (when both video and audio are enabled)
        if audio is not None and video is not None:
            # Audio-to-Video: Q from video, K/V from audio
            self.audio_to_video_attn = Attention(
                query_dim=video.dim,
                context_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            # Video-to-Audio: Q from audio, K/V from video
            self.video_to_audio_attn = Attention(
                query_dim=audio.dim,
                context_dim=video.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
            )
            # Scale-shift tables for cross-attention:
            # 4 scale/shift rows + 1 gate row per modality.
            self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
            self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))

    def get_ada_values(
        self,
        scale_shift_table: mx.array,
        batch_size: int,
        timestep: mx.array,
        indices: slice,
    ) -> Tuple[mx.array, ...]:
        """Get adaptive normalization values from a scale-shift table.

        Args:
            scale_shift_table: Table of shape (num_params, dim).
            batch_size: Batch size B of ``timestep``.
            timestep: Timestep embedding of shape (B, seq, num_params * dim).
                ``seq`` may be 1, in which case the values broadcast over
                every token of the hidden states.
            indices: Slice selecting which of the ``num_params`` rows to extract.

        Returns:
            Tuple with one array of shape (B, seq, dim) per selected row,
            equal to ``scale_shift_table[row] + timestep[:, :, row, :]``.
        """
        num_ada_params = scale_shift_table.shape[0]

        # (num_selected, dim) -> (1, 1, num_selected, dim) to broadcast over batch and sequence.
        table_slice = scale_shift_table[indices]
        table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)

        # (B, seq, num_params * dim) -> (B, seq, num_params, dim), keep only the selected rows.
        timestep_reshaped = mx.reshape(
            timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
        )
        ada_values = table_expanded + timestep_reshaped[:, :, indices, :]

        # Unbind along the parameter axis -> each item is (B, seq, dim).
        return tuple(ada_values[:, :, i, :] for i in range(ada_values.shape[2]))

    def get_av_ca_ada_values(
        self,
        scale_shift_table: mx.array,
        batch_size: int,
        scale_shift_timestep: mx.array,
        gate_timestep: mx.array,
        num_scale_shift_values: int = 4,
    ) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
        """Get adaptive values for cross-modal attention.

        Args:
            scale_shift_table: Table whose first ``num_scale_shift_values`` rows
                are scale/shift parameters and whose remaining rows are gates
                (5 rows = 4 scale-shift + 1 gate for the tables built in ``__init__``).
            batch_size: Batch size.
            scale_shift_timestep: Timestep embedding for the scale/shift rows.
            gate_timestep: Timestep embedding for the gate rows.
            num_scale_shift_values: Number of scale-shift rows (default 4).

        Returns:
            For a 5-row table: ``(scale1, shift1, scale2, shift2, gate)``, each of
            shape (B, seq, dim) so it broadcasts against (B, T, dim) hidden states.
        """
        scale_shift_ada = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :],
            batch_size,
            scale_shift_timestep,
            slice(None, None),
        )
        gate_ada = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :],
            batch_size,
            gate_timestep,
            slice(None, None),
        )

        # Deliberately keep the (B, seq, dim) layout. Squeezing a length-1 seq
        # axis would yield (B, dim), which broadcasts against (B, T, dim) as
        # (1, B, dim) and misaligns the batch axis with the token axis when B > 1.
        return (*scale_shift_ada, *gate_ada)

    def __call__(
        self,
        video: Optional[TransformerArgs] = None,
        audio: Optional[TransformerArgs] = None,
    ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
        """Forward pass through transformer block.

        Args:
            video: Video modality arguments, or ``None``.
            audio: Audio modality arguments, or ``None``.

        Returns:
            Tuple of (updated_video, updated_audio). Each is a copy of the input
            ``TransformerArgs`` with ``x`` replaced (unchanged if that modality
            was disabled or empty), or ``None`` if the input was ``None``.
        """
        vx = video.x if video is not None else None
        ax = audio.x if audio is not None else None

        # A modality runs only if it was passed, is enabled, and has elements.
        run_vx = video is not None and video.enabled and vx.size > 0
        run_ax = audio is not None and audio.enabled and ax.size > 0
        # Cross-modal attention needs both modalities active (the two flags
        # currently coincide but are kept separate for readability).
        run_a2v = run_vx and run_ax
        run_v2a = run_ax and run_vx

        # Process video self-attention and cross-attention with text
        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
            )

            # Self-attention with RoPE
            norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
            vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa

            # Cross-attention with text context
            vx = vx + self.attn2(
                rms_norm(vx, eps=self.norm_eps),
                context=video.context,
                mask=video.context_mask,
            )

        # Process audio self-attention and cross-attention with text
        if run_ax:
            ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
            )

            # Self-attention with RoPE
            norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
            ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa

            # Cross-attention with text context
            ax = ax + self.audio_attn2(
                rms_norm(ax, eps=self.norm_eps),
                context=audio.context,
                mask=audio.context_mask,
            )

        # Audio-Video cross-modal attention
        if run_a2v or run_v2a:
            # Both directions use norms of the states *before* either cross update.
            vx_norm3 = rms_norm(vx, eps=self.norm_eps)
            ax_norm3 = rms_norm(ax, eps=self.norm_eps)

            # Adaptive values for the audio side of both cross-attentions
            (
                scale_ca_audio_a2v,
                shift_ca_audio_a2v,
                scale_ca_audio_v2a,
                shift_ca_audio_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                audio.cross_scale_shift_timestep,
                audio.cross_gate_timestep,
            )

            # Adaptive values for the video side of both cross-attentions
            (
                scale_ca_video_a2v,
                shift_ca_video_a2v,
                scale_ca_video_v2a,
                shift_ca_video_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                video.cross_scale_shift_timestep,
                video.cross_gate_timestep,
            )

            # Audio-to-Video cross-attention: video queries attend to audio
            if run_a2v:
                vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
                ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
                vx = vx + (
                    self.audio_to_video_attn(
                        vx_scaled,
                        context=ax_scaled,
                        pe=video.cross_positional_embeddings,
                        k_pe=audio.cross_positional_embeddings,
                    )
                    * gate_out_a2v
                )

            # Video-to-Audio cross-attention: audio queries attend to video
            if run_v2a:
                ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
                vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
                ax = ax + (
                    self.video_to_audio_attn(
                        ax_scaled,
                        context=vx_scaled,
                        pe=audio.cross_positional_embeddings,
                        k_pe=video.cross_positional_embeddings,
                    )
                    * gate_out_v2a
                )

        # Process video feed-forward
        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
            )
            vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
            vx = vx + self.ff(vx_scaled) * vgate_mlp

        # Process audio feed-forward
        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
            )
            ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
            ax = ax + self.audio_ff(ax_scaled) * agate_mlp

        # Return updated TransformerArgs (only ``x`` changes)
        video_out = replace(video, x=vx) if video is not None else None
        audio_out = replace(audio, x=ax) if audio is not None else None

        return video_out, audio_out
|