This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,12 +1,12 @@
"""ResNet blocks for audio VAE and vocoder."""
from typing import List, Tuple
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .causal_conv_2d import make_conv2d
from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
self.conv1 = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
if temb_channels > 0:
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
self.dropout_rate = dropout
self.conv2 = make_conv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
out_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=1,
stride=1,
causality_axis=causality_axis,
)
def __call__(
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
if temb is not None and self.temb_channels > 0:
# temb: (B, temb_channels) -> (B, out_channels)
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
h = h + mx.expand_dims(
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
)
h = self.norm2(h)
h = nn.silu(h)