Refactor LTX-2 model structure
This commit is contained in:
108
mlx_video/models/ltx_2/audio_vae/attention.py
Normal file
108
mlx_video/models/ltx_2/audio_vae/attention.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Attention blocks for audio VAE."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
|
||||
class AttentionType(Enum):
    """Enum for specifying the attention mechanism type."""

    # Full softmax self-attention.
    VANILLA = "vanilla"
    # Linear attention variant.
    LINEAR = "linear"
    # No attention.
    NONE = "none"
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Self-attention block for audio VAE.

    The input is normalized, projected to queries, keys and values with 1x1
    convolutions, attended over all H*W positions as a single head spanning
    every channel, projected back, and added to the unmodified input.
    """

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        """
        Build the normalization layer and the four 1x1 projections.

        Args:
            in_channels: Channel count of the input; every projection maps
                ``in_channels`` to ``in_channels``.
            norm_type: Normalization kind forwarded to
                ``build_normalization_layer``.
        """
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        # Using Conv2d with kernel_size=1 for Q, K, V projections
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass through attention block.

        Args:
            x: Input tensor of shape (B, H, W, C) in MLX channels-last format

        Returns:
            Output tensor with attention applied (residual connection),
            same shape as ``x``.
        """
        h_ = self.norm(x)

        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Projections keep the channels-last layout: (B, H, W, C)
        b, h, w, c = q.shape

        # Flatten the spatial grid into a sequence: (B, H*W, C)
        q = q.reshape(b, h * w, c)
        k = k.reshape(b, h * w, c)
        v = v.reshape(b, h * w, c)

        # Attention weights: softmax(Q @ K^T / sqrt(C))
        # q: (B, HW, C), k^T: (B, C, HW) -> w_: (B, HW, HW)
        scale = float(c) ** (-0.5)
        w_ = mx.matmul(q, k.transpose(0, 2, 1)) * scale
        w_ = mx.softmax(w_, axis=-1)

        # Attend to values
        # w_: (B, HW, HW), v: (B, HW, C) -> h_: (B, HW, C)
        h_ = mx.matmul(w_, v)

        # Reshape back to spatial dims
        h_ = h_.reshape(b, h, w, c)

        h_ = self.proj_out(h_)

        # Residual connection around the whole attention path
        return x + h_
|
||||
|
||||
|
||||
class Identity(nn.Module):
    """Identity module that returns input unchanged."""

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x`` as-is, without copying or modifying it."""
        return x
|
||||
|
||||
|
||||
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> nn.Module:
    """
    Create an attention module based on type.

    Args:
        in_channels: Number of input channels
        attn_type: Type of attention mechanism. An ``AttentionType`` member;
            a raw enum value such as ``"vanilla"`` is also accepted at
            runtime and converted to the matching member.
        norm_type: Type of normalization (only used for vanilla attention)

    Returns:
        ``AttnBlock`` for ``VANILLA``, ``Identity`` for ``NONE``.

    Raises:
        NotImplementedError: If ``attn_type`` is ``LINEAR``.
        ValueError: If ``attn_type`` is not a recognised attention type.
    """
    if not isinstance(attn_type, AttentionType):
        # Accept the enum's raw value (e.g. a string read from a config file)
        # while keeping the original error message for anything unrecognised.
        try:
            attn_type = AttentionType(attn_type)
        except ValueError:
            raise ValueError(f"Unknown attention type: {attn_type}") from None

    if attn_type is AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)
    if attn_type is AttentionType.NONE:
        return Identity()
    if attn_type is AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
    # Defensive fallback in case new enum members are added without a branch.
    raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
Reference in New Issue
Block a user