"""ResNet blocks for audio VAE and vocoder.""" from typing import List, Tuple import mlx.core as mx import mlx.nn as nn from .causal_conv_2d import make_conv2d from .causality_axis import CausalityAxis from .normalization import NormType, build_normalization_layer LRELU_SLOPE = 0.1 def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array: """Leaky ReLU activation.""" return mx.maximum(x, x * negative_slope) class ResBlock1(nn.Module): """1D ResNet block for vocoder with dilated convolutions.""" def __init__( self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5), ): super().__init__() # First set of convolutions with different dilations self.convs1 = { i: nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=d, padding=(kernel_size - 1) * d // 2, ) for i, d in enumerate(dilation) } # Second set of convolutions with dilation=1 self.convs2 = { i: nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=1, padding=(kernel_size - 1) // 2, ) for i in range(len(dilation)) } def __call__(self, x: mx.array) -> mx.array: """Forward pass through residual blocks.""" for i in range(len(self.convs1)): xt = leaky_relu(x, LRELU_SLOPE) xt = self.convs1[i](xt) xt = leaky_relu(xt, LRELU_SLOPE) xt = self.convs2[i](xt) x = xt + x return x class ResBlock2(nn.Module): """1D ResNet block for vocoder (alternative version).""" def __init__( self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3), ): super().__init__() self.convs = { i: nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=d, padding=(kernel_size - 1) * d // 2, ) for i, d in enumerate(dilation) } def __call__(self, x: mx.array) -> mx.array: """Forward pass through residual blocks.""" for i in range(len(self.convs)): xt = leaky_relu(x, LRELU_SLOPE) xt = self.convs[i](xt) x = xt + x return x class ResnetBlock(nn.Module): """2D ResNet block for audio VAE encoder/decoder.""" def __init__( self, *, in_channels: int, out_channels: int | None = None, conv_shortcut: bool = False, dropout: float = 0.0, temb_channels: int = 512, norm_type: NormType = NormType.GROUP, causality_axis: CausalityAxis = CausalityAxis.HEIGHT, ) -> None: super().__init__() self.causality_axis = causality_axis if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.temb_channels = temb_channels self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) self.conv1 = make_conv2d( in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis ) if temb_channels > 0: self.temb_proj = nn.Linear(temb_channels, out_channels) self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) self.dropout_rate = dropout self.conv2 = make_conv2d( out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = make_conv2d( in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis ) else: self.nin_shortcut = make_conv2d( in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis ) def __call__( self, x: mx.array, temb: mx.array | None = None, ) -> mx.array: """ Forward pass through ResNet block. Args: x: Input tensor of shape (N, H, W, C) in MLX channels-last format temb: Optional time embedding tensor Returns: Output tensor """ h = x h = self.norm1(h) h = nn.silu(h) h = self.conv1(h) if temb is not None and self.temb_channels > 0: # temb: (B, temb_channels) -> (B, out_channels) # Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1) h = self.norm2(h) h = nn.silu(h) if self.dropout_rate > 0: h = nn.Dropout(self.dropout_rate)(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h