Files
Prince Canuma 17397da70c format
2026-03-18 17:40:05 +01:00

134 lines
4.5 KiB
Python

"""Downsampling layers for audio VAE."""
from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .normalization import NormType
from .resnet import ResnetBlock
class Downsample(nn.Module):
    """
    Halve the spatial resolution of a channels-last feature map.

    Two modes are available:

    * ``with_conv=True``: a 3x3 convolution with stride 2. The convolution is
      created with ``padding=0``; the asymmetric zero padding is applied
      manually in ``__call__`` and is selected by ``causality_axis``.
    * ``with_conv=False``: 2x2 average pooling with stride 2. This mode is only
      accepted together with ``causality_axis=CausalityAxis.NONE``.

    Args:
        in_channels: Number of input channels; the convolution keeps the same
            number of output channels.
        with_conv: Use the strided convolution (True) or average pooling (False).
        causality_axis: Which padding scheme to apply before the convolution.
            Defaults to ``CausalityAxis.WIDTH``, so ``with_conv=False`` requires
            passing ``CausalityAxis.NONE`` explicitly.

    Raises:
        ValueError: If a causality axis other than ``CausalityAxis.NONE`` is
            requested while ``with_conv`` is False.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Downsampling happens in this strided conv. Padding is asymmetric
            # and therefore applied by hand in __call__, hence padding=0 here.
            self.conv = nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass with downsampling.

        Args:
            x: Input tensor of shape (N, H, W, C) in MLX channels-last format.

        Returns:
            Downsampled tensor in the same (N, H, W, C) layout.

        Raises:
            ValueError: If ``self.causality_axis`` is not one of the handled
                ``CausalityAxis`` members (convolution mode only).
        """
        if self.with_conv:
            # The per-branch comments give the equivalent PyTorch-style padding
            # tuple (left, right, top, bottom), where left/right pad W and
            # top/bottom pad H. mx.pad takes one (before, after) pair per axis,
            # so for (N, H, W, C) only the two middle pairs are non-zero.
            if self.causality_axis == CausalityAxis.NONE:
                # pad: (left=0, right=1, top=0, bottom=1)
                pad = [(0, 0), (0, 1), (0, 1), (0, 0)]
            elif self.causality_axis == CausalityAxis.WIDTH:
                # pad: (left=2, right=0, top=0, bottom=1)
                pad = [(0, 0), (0, 1), (2, 0), (0, 0)]
            elif self.causality_axis == CausalityAxis.HEIGHT:
                # pad: (left=0, right=1, top=2, bottom=0)
                pad = [(0, 0), (2, 0), (0, 1), (0, 0)]
            elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
                # pad: (left=1, right=0, top=0, bottom=1)
                pad = [(0, 0), (0, 1), (1, 0), (0, 0)]
            else:
                raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

            x = mx.pad(x, pad, constant_values=0)
            x = self.conv(x)
        else:
            # 2x2 average pooling with stride 2, implemented as reshape + mean.
            n, h, w, c = x.shape
            out_h, out_w = h // 2, w // 2
            if h % 2 or w % 2:
                # An odd extent cannot be split into 2x2 windows and would make
                # the reshape below fail. Drop the trailing row/column so the
                # output size is floor(dim / 2), as in floor-mode pooling.
                x = x[:, : out_h * 2, : out_w * 2, :]
            # (N, H//2, 2, W//2, 2, C): axes 2 and 4 index inside each window.
            x = x.reshape(n, out_h, 2, out_w, 2, c)
            x = mx.mean(x, axis=(2, 4))
        return x
def build_downsampling_path(
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[dict, int]:
    """Assemble the downsampling stages: residual blocks, attention, downsamplers.

    For every resolution level a stage dict is created with:

    * ``"block"``: dict mapping block index -> ``ResnetBlock``
    * ``"attn"``: dict mapping block index -> attention module, filled only
      while the current resolution is in ``attn_resolutions``
    * ``"downsample"``: a ``Downsample`` layer, present on every level except
      the last one

    The tracked resolution starts at ``resolution`` and is halved after each
    level that gets a downsampler.

    Returns:
        A tuple ``(stages, channels)`` where ``stages`` maps the level index to
        its stage dict and ``channels`` is the channel count at the end of the
        path (``ch`` when ``num_resolutions`` is 0).
    """
    # The input multiplier of a level is the output multiplier of the previous
    # level; the first level starts from the base width (multiplier 1).
    input_mults = (1, *tuple(ch_mult))

    stages = {}
    channels = ch
    res = resolution

    for level in range(num_resolutions):
        channels = ch * input_mults[level]
        level_out = ch * ch_mult[level]

        resnets = {}
        attentions = {}
        for idx in range(num_res_blocks):
            resnets[idx] = ResnetBlock(
                in_channels=channels,
                out_channels=level_out,
                temb_channels=temb_channels,
                dropout=dropout,
                norm_type=norm_type,
                causality_axis=causality_axis,
            )
            # Only the first block of a level changes the channel count.
            channels = level_out
            if res in attn_resolutions:
                attentions[idx] = make_attn(
                    channels, attn_type=attn_type, norm_type=norm_type
                )

        entry = {"block": resnets, "attn": attentions}

        # Every level but the final one ends with a 2x spatial reduction.
        if level != num_resolutions - 1:
            entry["downsample"] = Downsample(
                channels, resamp_with_conv, causality_axis=causality_axis
            )
            res = res // 2

        stages[level] = entry

    return stages, channels