109 lines
3.1 KiB
Python
109 lines
3.1 KiB
Python
"""Attention blocks for audio VAE."""
|
|
|
|
from enum import Enum
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from .normalization import NormType, build_normalization_layer
|
|
|
|
|
|
class AttentionType(Enum):
|
|
"""Enum for specifying the attention mechanism type."""
|
|
|
|
VANILLA = "vanilla"
|
|
LINEAR = "linear"
|
|
NONE = "none"
|
|
|
|
|
|
class AttnBlock(nn.Module):
|
|
"""Self-attention block for audio VAE."""
|
|
|
|
def __init__(
|
|
self,
|
|
in_channels: int,
|
|
norm_type: NormType = NormType.GROUP,
|
|
) -> None:
|
|
super().__init__()
|
|
self.in_channels = in_channels
|
|
|
|
self.norm = build_normalization_layer(in_channels, normtype=norm_type)
|
|
# Using Conv2d with kernel_size=1 for Q, K, V projections
|
|
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
|
|
def __call__(self, x: mx.array) -> mx.array:
|
|
"""
|
|
Forward pass through attention block.
|
|
Args:
|
|
x: Input tensor of shape (B, H, W, C) in MLX channels-last format
|
|
Returns:
|
|
Output tensor with attention applied (residual connection)
|
|
"""
|
|
h_ = x
|
|
h_ = self.norm(h_)
|
|
|
|
q = self.q(h_)
|
|
k = self.k(h_)
|
|
v = self.v(h_)
|
|
|
|
# Compute attention
|
|
# x shape: (B, H, W, C)
|
|
b, h, w, c = q.shape
|
|
|
|
# Reshape for attention: (B, H*W, C)
|
|
q = q.reshape(b, h * w, c)
|
|
k = k.reshape(b, h * w, c)
|
|
v = v.reshape(b, h * w, c)
|
|
|
|
# Attention: Q @ K^T / sqrt(d)
|
|
# q: (B, HW, C), k: (B, HW, C) -> k^T: (B, C, HW)
|
|
# w_: (B, HW, HW)
|
|
scale = float(c) ** (-0.5)
|
|
w_ = mx.matmul(q, k.transpose(0, 2, 1)) * scale
|
|
w_ = mx.softmax(w_, axis=-1)
|
|
|
|
# Attend to values
|
|
# w_: (B, HW, HW), v: (B, HW, C) -> h_: (B, HW, C)
|
|
h_ = mx.matmul(w_, v)
|
|
|
|
# Reshape back to spatial dims
|
|
h_ = h_.reshape(b, h, w, c)
|
|
|
|
h_ = self.proj_out(h_)
|
|
|
|
return x + h_
|
|
|
|
|
|
class Identity(nn.Module):
|
|
"""Identity module that returns input unchanged."""
|
|
|
|
def __call__(self, x: mx.array) -> mx.array:
|
|
return x
|
|
|
|
|
|
def make_attn(
|
|
in_channels: int,
|
|
attn_type: AttentionType = AttentionType.VANILLA,
|
|
norm_type: NormType = NormType.GROUP,
|
|
) -> nn.Module:
|
|
"""
|
|
Create an attention module based on type.
|
|
Args:
|
|
in_channels: Number of input channels
|
|
attn_type: Type of attention mechanism
|
|
norm_type: Type of normalization
|
|
Returns:
|
|
Attention module
|
|
"""
|
|
if attn_type == AttentionType.VANILLA:
|
|
return AttnBlock(in_channels, norm_type=norm_type)
|
|
elif attn_type == AttentionType.NONE:
|
|
return Identity()
|
|
elif attn_type == AttentionType.LINEAR:
|
|
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
|
else:
|
|
raise ValueError(f"Unknown attention type: {attn_type}")
|