186 lines
5.5 KiB
Python
186 lines
5.5 KiB
Python
"""ResNet blocks for audio VAE and vocoder."""
|
|
|
|
from typing import List, Tuple
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from .causal_conv_2d import make_conv2d
|
|
from ..config import CausalityAxis
|
|
from .normalization import NormType, build_normalization_layer
|
|
|
|
# Default negative slope for leaky_relu; ResBlock1 and ResBlock2 also pass it
# explicitly on every activation.
LRELU_SLOPE = 0.1
|
|
|
|
|
|
def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array:
    """Leaky ReLU activation.

    Returns ``x`` where ``x > 0`` and ``negative_slope * x`` elsewhere.

    Args:
        x: Input array (any shape).
        negative_slope: Multiplier applied to non-positive inputs.

    Returns:
        Array with the same shape as ``x``.
    """
    # Select explicitly on the sign of x instead of using
    # ``mx.maximum(x, negative_slope * x)``: the max-based form silently
    # returns ``negative_slope * x`` for positive inputs whenever
    # negative_slope > 1. For slopes in [0, 1] both forms agree.
    return mx.where(x > 0, x, x * negative_slope)
|
|
|
|
|
|
class ResBlock1(nn.Module):
    """1D ResNet block for vocoder with dilated convolutions.

    The block has one residual stage per entry of ``dilation``. Each stage
    computes ``leaky_relu -> conv(dilation=d) -> leaky_relu -> conv(dilation=1)``
    and adds the result onto the running signal.

    Args:
        channels: Input and output channel count of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Dilation of the first convolution in each stage.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
    ):
        super().__init__()

        def _conv(dil: int) -> nn.Conv1d:
            # Padding (kernel_size - 1) * dil // 2 keeps the sequence length
            # unchanged when (kernel_size - 1) * dil is even. The residual
            # sum below relies on matching lengths.
            return nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=dil,
                padding=(kernel_size - 1) * dil // 2,
            )

        # Keyed by stage index. All first-convolutions are built before all
        # second-convolutions, which keeps the random-initialisation order.
        self.convs1 = {stage: _conv(dil) for stage, dil in enumerate(dilation)}
        self.convs2 = {stage: _conv(1) for stage in range(len(dilation))}

    def __call__(self, x: mx.array) -> mx.array:
        """Run every residual stage in order and return the updated signal."""
        out = x
        for stage in range(len(self.convs1)):
            branch = self.convs1[stage](leaky_relu(out, LRELU_SLOPE))
            branch = self.convs2[stage](leaky_relu(branch, LRELU_SLOPE))
            out = branch + out
        return out
|
|
|
|
|
|
class ResBlock2(nn.Module):
    """1D ResNet block for vocoder (alternative version).

    This is a lighter variant with a single dilated convolution per residual
    stage: ``x <- conv_d(leaky_relu(x)) + x`` for each ``d`` in ``dilation``.

    Args:
        channels: Input and output channel count of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Dilation of the convolution in each stage.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int] = (1, 3),
    ):
        super().__init__()

        stages = {}
        for stage, dil in enumerate(dilation):
            # Length-preserving padding when (kernel_size - 1) * dil is even.
            stages[stage] = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=dil,
                padding=(kernel_size - 1) * dil // 2,
            )
        self.convs = stages

    def __call__(self, x: mx.array) -> mx.array:
        """Run every residual stage in order and return the updated signal."""
        for stage in range(len(self.convs)):
            x = self.convs[stage](leaky_relu(x, LRELU_SLOPE)) + x
        return x
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """2D ResNet block for audio VAE encoder/decoder.

    Computes::

        h = conv1(silu(norm1(x)))
        h = h + silu(temb_proj(temb))         # only if temb given and temb_channels > 0
        h = conv2(dropout(silu(norm2(h))))     # dropout only if dropout > 0
        return shortcut(x) + h

    ``shortcut`` is the identity when ``in_channels == out_channels``.
    Otherwise it is a 3x3 conv (``conv_shortcut=True``) or a 1x1 conv.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor; defaults to ``in_channels``.
        conv_shortcut: Use a 3x3 conv instead of a 1x1 conv for the channel-
            changing shortcut.
        dropout: Dropout probability applied before ``conv2``. A value of 0 or
            less disables it.
        temb_channels: Width of the optional time embedding. A value of 0 or
            less disables the projection.
        norm_type: Normalization used for ``norm1`` / ``norm2``.
        causality_axis: Axis along which the convolutions are causal.

    Raises:
        ValueError: If a causal axis (anything other than ``CausalityAxis.NONE``)
            is combined with ``NormType.GROUP``. The default arguments
            (``GROUP`` + ``HEIGHT``) are such a combination, so callers must
            pass a compatible pair explicitly.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")

        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.temb_channels = temb_channels

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.conv1 = make_conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
        )

        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout_rate = dropout
        # Build the dropout layer once, as a registered child module, so that
        # it follows the parent's train()/eval() mode. A Dropout created inside
        # __call__ starts in training mode and stays active at inference.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.conv2 = make_conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def __call__(
        self,
        x: mx.array,
        temb: mx.array | None = None,
    ) -> mx.array:
        """
        Forward pass through ResNet block.

        Args:
            x: Input tensor of shape (N, H, W, C) in MLX channels-last format
            temb: Optional time embedding tensor; ignored when the block was
                built with temb_channels <= 0.

        Returns:
            Output tensor with ``out_channels`` channels.
        """
        h = self.conv1(nn.silu(self.norm1(x)))

        if temb is not None and self.temb_channels > 0:
            # temb: (B, temb_channels) -> (B, out_channels) -> (B, 1, 1, out_channels)
            # so it broadcasts over the spatial dims of the channels-last h.
            # NOTE(review): LDM-style blocks usually apply SiLU to temb *before*
            # temb_proj; here it is applied after. Kept as-is — confirm against
            # the reference implementation/weights.
            t = nn.silu(self.temb_proj(temb))
            h = h + mx.expand_dims(mx.expand_dims(t, axis=1), axis=1)

        h = nn.silu(self.norm2(h))
        if self.dropout is not None:
            h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            x = shortcut(x)

        return x + h
|