"""Downsampling layers for audio VAE.""" from typing import Set, Tuple import mlx.core as mx import mlx.nn as nn from ..config import CausalityAxis from .attention import AttentionType, make_attn from .normalization import NormType from .resnet import ResnetBlock class Downsample(nn.Module): """ A downsampling layer that can use either a strided convolution or average pooling. Supports standard and causal padding for the convolutional mode. """ def __init__( self, in_channels: int, with_conv: bool, causality_axis: CausalityAxis = CausalityAxis.WIDTH, ) -> None: super().__init__() self.with_conv = with_conv self.causality_axis = causality_axis if self.causality_axis != CausalityAxis.NONE and not self.with_conv: raise ValueError("causality is only supported when `with_conv=True`.") if self.with_conv: # Do time downsampling here # no asymmetric padding in MLX conv, must do it ourselves self.conv = nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=2, padding=0 ) def __call__(self, x: mx.array) -> mx.array: """ Forward pass with downsampling. Args: x: Input tensor of shape (N, H, W, C) in MLX channels-last format Returns: Downsampled tensor """ if self.with_conv: # Padding tuple is in the order: (left, right, top, bottom) for PyTorch # For MLX pad: [(before_axis0, after_axis0), ...] # x shape: (N, H, W, C) -> pad on H and W axes if self.causality_axis == CausalityAxis.NONE: # pad: (left=0, right=1, top=0, bottom=1) pad = [(0, 0), (0, 1), (0, 1), (0, 0)] elif self.causality_axis == CausalityAxis.WIDTH: # pad: (left=2, right=0, top=0, bottom=1) pad = [(0, 0), (0, 1), (2, 0), (0, 0)] elif self.causality_axis == CausalityAxis.HEIGHT: # pad: (left=0, right=1, top=2, bottom=0) pad = [(0, 0), (2, 0), (0, 1), (0, 0)] elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY: # pad: (left=1, right=0, top=0, bottom=1) pad = [(0, 0), (0, 1), (1, 0), (0, 0)] else: raise ValueError(f"Invalid causality_axis: {self.causality_axis}") x = mx.pad(x, pad, constant_values=0) x = self.conv(x) else: # Average pooling with 2x2 kernel and stride 2 # MLX doesn't have built-in avg_pool2d, implement manually # x shape: (N, H, W, C) n, h, w, c = x.shape # Reshape to (N, H//2, 2, W//2, 2, C) and mean over pooling dims x = x.reshape(n, h // 2, 2, w // 2, 2, c) x = mx.mean(x, axis=(2, 4)) return x def build_downsampling_path( *, ch: int, ch_mult: Tuple[int, ...], num_resolutions: int, num_res_blocks: int, resolution: int, temb_channels: int, dropout: float, norm_type: NormType, causality_axis: CausalityAxis, attn_type: AttentionType, attn_resolutions: Set[int], resamp_with_conv: bool, ) -> tuple[dict, int]: """Build the downsampling path with residual blocks, attention, and downsampling layers.""" down_modules = {} curr_res = resolution in_ch_mult = (1, *tuple(ch_mult)) block_in = ch for i_level in range(num_resolutions): stage = {} stage["block"] = {} stage["attn"] = {} block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(num_res_blocks): stage["block"][i_block] = ResnetBlock( in_channels=block_in, out_channels=block_out, temb_channels=temb_channels, dropout=dropout, norm_type=norm_type, causality_axis=causality_axis, ) block_in = block_out if curr_res in attn_resolutions: stage["attn"][i_block] = make_attn( block_in, attn_type=attn_type, norm_type=norm_type ) if i_level != num_resolutions - 1: stage["downsample"] = Downsample( block_in, resamp_with_conv, causality_axis=causality_axis ) curr_res = curr_res // 2 down_modules[i_level] = stage return down_modules, block_in