Files
mlx-video/mlx_video/models/ltx/audio_vae/attention.py
Prince Canuma a658911f98 add audio
2026-01-16 01:15:22 +01:00

109 lines
3.1 KiB
Python

"""Attention blocks for audio VAE."""
from enum import Enum
import mlx.core as mx
import mlx.nn as nn
from .normalization import NormType, build_normalization_layer
class AttentionType(Enum):
    """Selector for the attention mechanism used inside the audio VAE."""

    # Full softmax self-attention.
    VANILLA = "vanilla"
    # Linear attention variant.
    LINEAR = "linear"
    # No attention at all.
    NONE = "none"
class AttnBlock(nn.Module):
    """Spatial self-attention block with a residual connection for the audio VAE.

    The input is normalized, projected to queries, keys and values by 1x1
    convolutions, attended over all ``H * W`` positions using the full channel
    dimension (no head split), projected once more and added back to the input.
    """

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        """
        Build the normalization layer and the four 1x1 projections.

        Args:
            in_channels: Channel count of the input; every projection maps
                ``in_channels`` to ``in_channels``.
            norm_type: Normalization variant passed to
                ``build_normalization_layer``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.norm = build_normalization_layer(in_channels, normtype=norm_type)

        def pointwise_conv() -> nn.Conv2d:
            # A 1x1 convolution acts as a per-position linear projection.
            return nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

        # Creation order is kept as q, k, v, proj_out.
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def __call__(self, x: mx.array) -> mx.array:
        """
        Apply self-attention over the spatial positions of ``x``.

        Args:
            x: Input tensor of shape (B, H, W, C) in MLX channels-last format

        Returns:
            ``x`` plus the projected attention output (same shape as ``x``).
        """
        normed = self.norm(x)
        queries = self.q(normed)
        keys = self.k(normed)
        values = self.v(normed)

        # Flatten the two spatial axes into one token axis: (B, H*W, C).
        batch, height, width, channels = queries.shape
        num_tokens = height * width
        queries = queries.reshape(batch, num_tokens, channels)
        keys = keys.reshape(batch, num_tokens, channels)
        values = values.reshape(batch, num_tokens, channels)

        # Scaled dot-product scores: (B, HW, C) @ (B, C, HW) -> (B, HW, HW).
        scale = float(channels) ** (-0.5)
        scores = (queries @ keys.transpose(0, 2, 1)) * scale
        weights = mx.softmax(scores, axis=-1)

        # Weighted sum of values: (B, HW, HW) @ (B, HW, C) -> (B, HW, C).
        attended = weights @ values

        # Restore the (B, H, W, C) layout before the output projection.
        attended = attended.reshape(batch, height, width, channels)
        return x + self.proj_out(attended)
class Identity(nn.Module):
    """No-op module: calling it hands back its argument untouched."""

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x`` as-is."""
        return x
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> nn.Module:
    """
    Build the attention module selected by ``attn_type``.

    Args:
        in_channels: Number of input channels
        attn_type: Which attention mechanism to instantiate
        norm_type: Normalization used inside the attention block

    Returns:
        An ``AttnBlock`` for ``VANILLA`` or an ``Identity`` for ``NONE``.

    Raises:
        NotImplementedError: If ``attn_type`` is ``LINEAR``.
        ValueError: If ``attn_type`` matches none of the known types.
    """
    if attn_type == AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)

    if attn_type == AttentionType.NONE:
        return Identity()

    if attn_type == AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")

    # Anything else is not a recognized attention type.
    raise ValueError(f"Unknown attention type: {attn_type}")