"""Upsampling layers for audio VAE."""

from typing import Set, Tuple

import mlx.core as mx
import mlx.nn as nn

from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .normalization import NormType
from .resnet import ResnetBlock


def nearest_neighbor_upsample(x: mx.array, scale_factor: int = 2) -> mx.array:
    """
    Nearest neighbor upsampling for 4D tensors.

    Every element is repeated ``scale_factor`` times along axis 1 and then
    along axis 2; the first and last axes are left untouched.

    Args:
        x: Input tensor of shape (N, H, W, C)
        scale_factor: Upsampling factor

    Returns:
        Upsampled tensor of shape (N, H*scale_factor, W*scale_factor, C)

    Raises:
        ValueError: If ``x`` is not a 4D tensor.
    """
    # Explicit rank check instead of unpacking the shape into unused locals;
    # the exception type is the same one a failed unpacking would raise.
    if len(x.shape) != 4:
        raise ValueError(
            f"Expected a 4D tensor of shape (N, H, W, C), got shape {x.shape}"
        )

    # Repeat along height, then along width.
    x = mx.repeat(x, scale_factor, axis=1)
    return mx.repeat(x, scale_factor, axis=2)


class Upsample(nn.Module):
    """Upsampling layer with optional convolution.

    The input is upsampled 2x with nearest-neighbor repetition. When
    ``with_conv`` is set, a 3x3 stride-1 convolution built by ``make_conv2d``
    is applied afterwards and the leading element along the causal axis is
    dropped (see ``__call__``).
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Channel count used for both the input and the output
                of the optional convolution.
            with_conv: Whether to apply a convolution after upsampling.
            causality_axis: Axis along which the convolution is causal.
        """
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        # The convolution only exists when it was requested.
        if self.with_conv:
            self.conv = make_conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=1,
                causality_axis=causality_axis,
            )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass with upsampling.

        Args:
            x: Input tensor of shape (N, H, W, C) in MLX channels-last format

        Returns:
            Upsampled tensor

        Raises:
            ValueError: If ``causality_axis`` is not a handled value (only
                checked on the convolution path).
        """
        # Nearest neighbor 2x upsampling
        x = nearest_neighbor_upsample(x, scale_factor=2)

        if not self.with_conv:
            return x

        x = self.conv(x)

        # NOTE(review): the trim below is applied only on the convolution
        # path, since it undoes the causal convolution's padding - confirm
        # against the reference implementation.
        #
        # Drop FIRST element in the causal axis to undo encoder's padding,
        # while keeping the length 1 + 2 * n.
        # For example, if the input is [0, 1, 2], after interpolation, the
        # output is [0, 0, 1, 1, 2, 2].
        # The causal convolution will pad the first element as
        # [-, -, 0, 0, 1, 1, 2, 2],
        # So the output elements rely on the following windows:
        # 0: [-,-,0]
        # 1: [-,0,0]
        # 2: [0,0,1]
        # 3: [0,1,1]
        # 4: [1,1,2]
        # 5: [1,2,2]
        # Notice that the first and second elements in the output rely only
        # on the first element in the input, while all other elements rely on
        # two elements in the input.
        # So we can drop the first element to undo the padding (rather than
        # the last element).
        # This is a no-op for non-causal convolutions.
        if self.causality_axis == CausalityAxis.HEIGHT:
            x = x[:, 1:, :, :]
        elif self.causality_axis == CausalityAxis.WIDTH:
            x = x[:, :, 1:, :]
        elif self.causality_axis not in (
            CausalityAxis.NONE,
            CausalityAxis.WIDTH_COMPATIBILITY,
        ):
            # NONE and WIDTH_COMPATIBILITY leave x unchanged; anything else
            # is rejected.
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        return x


def build_upsampling_path(
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[dict, int]:
    """Build the upsampling path with residual blocks, attention, and upsampling layers.

    Levels are visited from ``num_resolutions - 1`` down to ``0``. Each level
    gets ``num_res_blocks + 1`` residual blocks, an attention layer after each
    block while the current resolution is in ``attn_resolutions``, and (for
    every level except ``0``) an ``Upsample`` layer that doubles the tracked
    resolution.

    Args:
        ch: Base channel count; a level's width is ``ch * ch_mult[level]``.
        ch_mult: Per-level channel multipliers, indexed by level.
        num_resolutions: Number of levels to build.
        num_res_blocks: Residual blocks per level, minus one.
        resolution: Full resolution; the path starts at
            ``resolution // 2 ** (num_resolutions - 1)``.
        temb_channels: Forwarded to every ``ResnetBlock``.
        dropout: Forwarded to every ``ResnetBlock``.
        norm_type: Forwarded to residual blocks and attention layers.
        causality_axis: Forwarded to residual blocks and upsample layers.
        attn_type: Forwarded to ``make_attn``.
        attn_resolutions: Resolutions at which attention layers are added.
        resamp_with_conv: Whether upsample layers apply a convolution.
        initial_block_channels: Input channel count of the first block built.

    Returns:
        A pair ``(up_modules, block_in)``. ``up_modules`` maps each level to a
        dict with ``"block"`` and ``"attn"`` sub-dicts keyed by block index
        and, for levels other than ``0``, an ``"upsample"`` entry.
        ``block_in`` is the channel count after the last block built.
    """
    up_modules = {}
    block_in = initial_block_channels
    curr_res = resolution // (2 ** (num_resolutions - 1))

    for level in reversed(range(num_resolutions)):
        stage = {"block": {}, "attn": {}}
        block_out = ch * ch_mult[level]

        for i_block in range(num_res_blocks + 1):
            stage["block"][i_block] = ResnetBlock(
                in_channels=block_in,
                out_channels=block_out,
                temb_channels=temb_channels,
                dropout=dropout,
                norm_type=norm_type,
                causality_axis=causality_axis,
            )
            # Only the first block of a level changes the channel count.
            block_in = block_out

            if curr_res in attn_resolutions:
                stage["attn"][i_block] = make_attn(
                    block_in, attn_type=attn_type, norm_type=norm_type
                )

        # Level 0 is the last level built, so it has no upsample layer.
        if level != 0:
            stage["upsample"] = Upsample(
                block_in, resamp_with_conv, causality_axis=causality_axis
            )
            curr_res *= 2

        up_modules[level] = stage

    return up_modules, block_in