format
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
"""ResNet blocks for audio VAE and vocoder."""
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
self.conv1 = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if temb_channels > 0:
|
||||
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
|
||||
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||
self.dropout_rate = dropout
|
||||
self.conv2 = make_conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
|
||||
if temb is not None and self.temb_channels > 0:
|
||||
# temb: (B, temb_channels) -> (B, out_channels)
|
||||
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
||||
h = h + mx.expand_dims(
|
||||
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
|
||||
)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
|
||||
Reference in New Issue
Block a user