add audio

This commit is contained in:
Prince Canuma
2026-01-16 01:15:22 +01:00
parent 81daf3f67d
commit a658911f98
19 changed files with 2335 additions and 54 deletions

View File

@@ -0,0 +1,41 @@
"""Audio VAE module for LTX-2 audio generation."""
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, decode_audio
from .causal_conv_2d import CausalConv2d, make_conv2d
from .causality_axis import CausalityAxis
from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock
from .upsample import Upsample, build_upsampling_path
from .vocoder import Vocoder
__all__ = [
# Main components
"AudioDecoder",
"Vocoder",
"decode_audio",
# Ops
"AudioLatentShape",
"AudioPatchifier",
"PerChannelStatistics",
# Building blocks
"AttentionType",
"AttnBlock",
"make_attn",
"CausalConv2d",
"make_conv2d",
"CausalityAxis",
"Downsample",
"build_downsampling_path",
"NormType",
"PixelNorm",
"build_normalization_layer",
"ResBlock1",
"ResBlock2",
"ResnetBlock",
"LRELU_SLOPE",
"Upsample",
"build_upsampling_path",
]

View File

@@ -0,0 +1,108 @@
"""Attention blocks for audio VAE."""
from enum import Enum
import mlx.core as mx
import mlx.nn as nn
from .normalization import NormType, build_normalization_layer
class AttentionType(Enum):
"""Enum for specifying the attention mechanism type."""
VANILLA = "vanilla"
LINEAR = "linear"
NONE = "none"
class AttnBlock(nn.Module):
"""Self-attention block for audio VAE."""
def __init__(
self,
in_channels: int,
norm_type: NormType = NormType.GROUP,
) -> None:
super().__init__()
self.in_channels = in_channels
self.norm = build_normalization_layer(in_channels, normtype=norm_type)
# Using Conv2d with kernel_size=1 for Q, K, V projections
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass through attention block.
Args:
x: Input tensor of shape (B, H, W, C) in MLX channels-last format
Returns:
Output tensor with attention applied (residual connection)
"""
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# Compute attention
# x shape: (B, H, W, C)
b, h, w, c = q.shape
# Reshape for attention: (B, H*W, C)
q = q.reshape(b, h * w, c)
k = k.reshape(b, h * w, c)
v = v.reshape(b, h * w, c)
# Attention: Q @ K^T / sqrt(d)
# q: (B, HW, C), k: (B, HW, C) -> k^T: (B, C, HW)
# w_: (B, HW, HW)
scale = float(c) ** (-0.5)
w_ = mx.matmul(q, k.transpose(0, 2, 1)) * scale
w_ = mx.softmax(w_, axis=-1)
# Attend to values
# w_: (B, HW, HW), v: (B, HW, C) -> h_: (B, HW, C)
h_ = mx.matmul(w_, v)
# Reshape back to spatial dims
h_ = h_.reshape(b, h, w, c)
h_ = self.proj_out(h_)
return x + h_
class Identity(nn.Module):
"""Identity module that returns input unchanged."""
def __call__(self, x: mx.array) -> mx.array:
return x
def make_attn(
in_channels: int,
attn_type: AttentionType = AttentionType.VANILLA,
norm_type: NormType = NormType.GROUP,
) -> nn.Module:
"""
Create an attention module based on type.
Args:
in_channels: Number of input channels
attn_type: Type of attention mechanism
norm_type: Type of normalization
Returns:
Attention module
"""
if attn_type == AttentionType.VANILLA:
return AttnBlock(in_channels, norm_type=norm_type)
elif attn_type == AttentionType.NONE:
return Identity()
elif attn_type == AttentionType.LINEAR:
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
else:
raise ValueError(f"Unknown attention type: {attn_type}")

View File

@@ -0,0 +1,326 @@
"""Audio VAE encoder and decoder for LTX-2."""
from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
from .upsample import build_upsampling_path
LATENT_DOWNSAMPLE_FACTOR = 4
def build_mid_block(
channels: int,
temb_channels: int,
dropout: float,
norm_type: NormType,
causality_axis: CausalityAxis,
attn_type: AttentionType,
add_attention: bool,
) -> dict:
"""Build the middle block with two ResNet blocks and optional attention."""
mid = {}
mid["block_1"] = ResnetBlock(
in_channels=channels,
out_channels=channels,
temb_channels=temb_channels,
dropout=dropout,
norm_type=norm_type,
causality_axis=causality_axis,
)
mid["attn_1"] = (
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
)
mid["block_2"] = ResnetBlock(
in_channels=channels,
out_channels=channels,
temb_channels=temb_channels,
dropout=dropout,
norm_type=norm_type,
causality_axis=causality_axis,
)
return mid
def run_mid_block(mid: dict, features: mx.array) -> mx.array:
"""Run features through the middle block."""
features = mid["block_1"](features, temb=None)
if mid["attn_1"] is not None:
features = mid["attn_1"](features)
return mid["block_2"](features, temb=None)
class AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.
The decoder mirrors the encoder structure with configurable channel multipliers,
attention resolutions, and causal convolutions.
"""
def __init__(
self,
*,
ch: int = 128,
out_ch: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Set[int] = None,
resolution: int = 256,
z_channels: int = 8,
norm_type: NormType = NormType.PIXEL,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
dropout: float = 0.0,
mid_block_add_attention: bool = True,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: int | None = None,
) -> None:
"""
Initialize the AudioDecoder.
Args:
ch: Base number of feature channels
out_ch: Number of output channels (2 for stereo)
ch_mult: Multiplicative factors for channels at each resolution
num_res_blocks: Number of residual blocks per resolution
attn_resolutions: Resolutions at which to apply attention
resolution: Input spatial resolution
z_channels: Number of latent channels
norm_type: Normalization type
causality_axis: Axis for causal convolutions
dropout: Dropout probability
mid_block_add_attention: Whether to add attention in middle block
sample_rate: Audio sample rate
mel_hop_length: Hop length for mel spectrogram
is_causal: Whether to use causal convolutions
mel_bins: Number of mel frequency bins
"""
super().__init__()
if attn_resolutions is None:
attn_resolutions = {8, 16, 32}
# Internal behavioral defaults
resamp_with_conv = True
attn_type = AttentionType.VANILLA
# Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
# ch=128 matches this dimension, so use ch for per_channel_statistics
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.patchifier = AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
)
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.out_ch = out_ch
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.z_channels = z_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
self.attn_type = attn_type
base_block_channels = ch * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d(
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
self.mid = build_mid_block(
channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
add_attention=mid_block_add_attention,
)
self.up, final_block_channels = build_upsampling_path(
ch=ch,
ch_mult=ch_mult,
num_resolutions=self.num_resolutions,
num_res_blocks=num_res_blocks,
resolution=resolution,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
initial_block_channels=base_block_channels,
)
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
self.conv_out = make_conv2d(
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
def __call__(self, sample: mx.array) -> mx.array:
"""
Decode latent features back to audio spectrograms.
Args:
sample: Encoded latent representation of shape (B, H, W, C) in MLX format
or (B, C, H, W) in PyTorch format (will be transposed)
Returns:
Reconstructed audio spectrogram
"""
# Handle input format - if channels are in dim 1, transpose to channels-last
if sample.shape[1] == self.z_channels and sample.ndim == 4:
# PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
sample = mx.transpose(sample, (0, 2, 3, 1))
sample, target_shape = self._denormalize_latents(sample)
h = self.conv_in(sample)
h = run_mid_block(self.mid, h)
h = self._run_upsampling_path(h)
h = self._finalize_output(h)
return self._adjust_output_shape(h, target_shape)
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
"""Denormalize latents using per-channel statistics."""
# sample shape: (B, H, W, C) in MLX format
latent_shape = AudioLatentShape(
batch=sample.shape[0],
channels=sample.shape[3], # channels last
frames=sample.shape[1], # height = frames
mel_bins=sample.shape[2], # width = mel_bins
)
sample_patched = self.patchifier.patchify(sample)
sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis != CausalityAxis.NONE:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_shape = AudioLatentShape(
batch=latent_shape.batch,
channels=self.out_ch,
frames=target_frames,
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
)
return sample, target_shape
def _adjust_output_shape(
self,
decoded_output: mx.array,
target_shape: AudioLatentShape,
) -> mx.array:
"""
Adjust output shape to match target dimensions for variable-length audio.
Args:
decoded_output: Tensor of shape (B, H, W, C) in MLX format
target_shape: AudioLatentShape describing target dimensions
Returns:
Tensor adjusted to match target_shape exactly
"""
# Current output shape: (batch, frames, mel_bins, channels) in MLX format
_, current_time, current_freq, _ = decoded_output.shape
target_channels = target_shape.channels
target_time = target_shape.frames
target_freq = target_shape.mel_bins
# Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
]
# Step 2: Calculate padding needed for time and frequency dimensions
time_padding_needed = target_time - decoded_output.shape[1]
freq_padding_needed = target_freq - decoded_output.shape[2]
# Step 3: Apply padding if needed
if time_padding_needed > 0 or freq_padding_needed > 0:
# MLX pad: [(before_0, after_0), ...]
# For (B, H, W, C): H=time, W=freq
padding = [
(0, 0), # batch
(0, max(time_padding_needed, 0)), # time
(0, max(freq_padding_needed, 0)), # freq
(0, 0), # channels
]
decoded_output = mx.pad(decoded_output, padding)
# Step 4: Final safety crop to ensure exact target shape
decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]
# Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2))
return decoded_output
def _run_upsampling_path(self, h: mx.array) -> mx.array:
"""Run through upsampling path."""
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx in range(len(stage["block"])):
h = stage["block"][block_idx](h, temb=None)
if block_idx in stage["attn"]:
h = stage["attn"][block_idx](h)
if level != 0 and "upsample" in stage:
h = stage["upsample"](h)
return h
def _finalize_output(self, h: mx.array) -> mx.array:
"""Apply final normalization and convolution."""
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return mx.tanh(h) if self.tanh_out else h
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
"""
Decode an audio latent representation using the provided audio decoder and vocoder.
Args:
latent: Input audio latent tensor
audio_decoder: Model to decode the latent to spectrogram
vocoder: Model to convert spectrogram to audio waveform
Returns:
Decoded audio as a float tensor
"""
decoded_audio = audio_decoder(latent)
decoded_audio = vocoder(decoded_audio)
# Remove batch dimension if present
if decoded_audio.shape[0] == 1:
decoded_audio = decoded_audio[0]
return decoded_audio.astype(mx.float32)

View File

@@ -0,0 +1,146 @@
"""Causal 2D convolutions for audio VAE."""
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .causality_axis import CausalityAxis
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
"""Convert int or tuple to tuple pair."""
if isinstance(x, int):
return (x, x)
return x
class CausalConv2d(nn.Module):
"""
A causal 2D convolution.
This layer ensures that the output at time `t` only depends on inputs
at time `t` and earlier. It achieves this by applying asymmetric padding
to the time dimension before the convolution.
Note: MLX Conv2d expects input shape (N, H, W, C) - channels last.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
) -> None:
super().__init__()
self.causality_axis = causality_axis
# Ensure kernel_size and dilation are tuples
kernel_size = _pair(kernel_size)
dilation = _pair(dilation)
# Calculate padding dimensions
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
# Store padding for manual application
# MLX pad order: [(before_axis0, after_axis0), (before_axis1, after_axis1), ...]
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
if self.causality_axis == CausalityAxis.NONE:
# Non-causal: symmetric padding
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
# Causal on width: pad left (before width axis)
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
elif self.causality_axis == CausalityAxis.HEIGHT:
# Causal on height: pad top (before height axis)
self.padding = (pad_h, 0, pad_w // 2, pad_w - pad_w // 2)
else:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
# The internal convolution layer uses no padding, as we handle it manually
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass with causal padding.
Args:
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
Returns:
Output tensor after causal convolution
"""
# Apply causal padding before convolution
# padding format: (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
pad_h_top, pad_h_bottom, pad_w_left, pad_w_right = self.padding
if any(p > 0 for p in self.padding):
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
return self.conv(x)
def make_conv2d(
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
padding: Union[int, Tuple[int, int], None] = None,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
causality_axis: CausalityAxis | None = None,
) -> nn.Module:
"""
Create a 2D convolution layer that can be either causal or non-causal.
Args:
in_channels: Number of input channels
out_channels: Number of output channels
kernel_size: Size of the convolution kernel
stride: Convolution stride
padding: Padding (if None, will be calculated based on causal flag)
dilation: Dilation rate
groups: Number of groups for grouped convolution
bias: Whether to use bias
causality_axis: Dimension along which to apply causality.
Returns:
Either a regular Conv2d or CausalConv2d layer
"""
if causality_axis is not None:
# For causal convolution, padding is handled internally by CausalConv2d
return CausalConv2d(
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
)
else:
# For non-causal convolution, use symmetric padding if not specified
if padding is None:
if isinstance(kernel_size, int):
padding = kernel_size // 2
else:
padding = tuple(k // 2 for k in kernel_size)
return nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)

View File

@@ -0,0 +1,12 @@
"""Causality axis enum for specifying causal convolution dimensions."""
from enum import Enum
class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"

View File

@@ -0,0 +1,127 @@
"""Downsampling layers for audio VAE."""
from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causality_axis import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
class Downsample(nn.Module):
"""
A downsampling layer that can use either a strided convolution
or average pooling. Supports standard and causal padding for the
convolutional mode.
"""
def __init__(
self,
in_channels: int,
with_conv: bool,
causality_axis: CausalityAxis = CausalityAxis.WIDTH,
) -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
raise ValueError("causality is only supported when `with_conv=True`.")
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in MLX conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass with downsampling.
Args:
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
Returns:
Downsampled tensor
"""
if self.with_conv:
# Padding tuple is in the order: (left, right, top, bottom) for PyTorch
# For MLX pad: [(before_axis0, after_axis0), ...]
# x shape: (N, H, W, C) -> pad on H and W axes
if self.causality_axis == CausalityAxis.NONE:
# pad: (left=0, right=1, top=0, bottom=1)
pad = [(0, 0), (0, 1), (0, 1), (0, 0)]
elif self.causality_axis == CausalityAxis.WIDTH:
# pad: (left=2, right=0, top=0, bottom=1)
pad = [(0, 0), (0, 1), (2, 0), (0, 0)]
elif self.causality_axis == CausalityAxis.HEIGHT:
# pad: (left=0, right=1, top=2, bottom=0)
pad = [(0, 0), (2, 0), (0, 1), (0, 0)]
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
# pad: (left=1, right=0, top=0, bottom=1)
pad = [(0, 0), (0, 1), (1, 0), (0, 0)]
else:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
x = mx.pad(x, pad, constant_values=0)
x = self.conv(x)
else:
# Average pooling with 2x2 kernel and stride 2
# MLX doesn't have built-in avg_pool2d, implement manually
# x shape: (N, H, W, C)
n, h, w, c = x.shape
# Reshape to (N, H//2, 2, W//2, 2, C) and mean over pooling dims
x = x.reshape(n, h // 2, 2, w // 2, 2, c)
x = mx.mean(x, axis=(2, 4))
return x
def build_downsampling_path(
*,
ch: int,
ch_mult: Tuple[int, ...],
num_resolutions: int,
num_res_blocks: int,
resolution: int,
temb_channels: int,
dropout: float,
norm_type: NormType,
causality_axis: CausalityAxis,
attn_type: AttentionType,
attn_resolutions: Set[int],
resamp_with_conv: bool,
) -> tuple[dict, int]:
"""Build the downsampling path with residual blocks, attention, and downsampling layers."""
down_modules = {}
curr_res = resolution
in_ch_mult = (1, *tuple(ch_mult))
block_in = ch
for i_level in range(num_resolutions):
stage = {}
stage["block"] = {}
stage["attn"] = {}
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(num_res_blocks):
stage["block"][i_block] = ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=temb_channels,
dropout=dropout,
norm_type=norm_type,
causality_axis=causality_axis,
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
if i_level != num_resolutions - 1:
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res = curr_res // 2
down_modules[i_level] = stage
return down_modules, block_in

View File

@@ -0,0 +1,59 @@
"""Normalization layers for audio VAE."""
from enum import Enum
import mlx.core as mx
import mlx.nn as nn
class NormType(Enum):
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
GROUP = "group"
PIXEL = "pixel"
class PixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
For each element along the chosen dimension, this layer normalizes the tensor
by the root-mean-square of its values across that dimension:
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
"""
Args:
dim: Dimension along which to compute the RMS (typically channels).
eps: Small constant added for numerical stability.
"""
super().__init__()
self.dim = dim
self.eps = eps
def __call__(self, x: mx.array) -> mx.array:
"""Apply RMS normalization along the configured dimension."""
mean_sq = mx.mean(x**2, axis=self.dim, keepdims=True)
rms = mx.sqrt(mean_sq + self.eps)
return x / rms
def build_normalization_layer(
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
"""
Create a normalization layer based on the normalization type.
Args:
in_channels: Number of input channels
num_groups: Number of groups for group normalization
normtype: Type of normalization: "group" or "pixel"
Returns:
A normalization layer
"""
if normtype == NormType.GROUP:
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
if normtype == NormType.PIXEL:
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
return PixelNorm(dim=-1, eps=1e-6)
raise ValueError(f"Invalid normalization type: {normtype}")

View File

@@ -0,0 +1,98 @@
"""Audio processing utilities for audio VAE."""
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
@dataclass
class AudioLatentShape:
"""Shape descriptor for audio latent representations."""
batch: int
channels: int
frames: int
mel_bins: int
class PerChannelStatistics(nn.Module):
"""
Per-channel statistics for normalizing and denormalizing the latent representation.
This statistics is computed over the entire dataset and stored in model's checkpoint.
"""
def __init__(self, latent_channels: int = 128) -> None:
super().__init__()
self.latent_channels = latent_channels
# Initialize buffers - will be loaded from weights
# Using underscores for MLX compatibility with weight loading
self._std_of_means = mx.ones((latent_channels,))
self._mean_of_means = mx.zeros((latent_channels,))
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latent representation."""
# Broadcast statistics to match x shape
# x shape: (B, C, ...) or (B, ..., C)
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
return (x * std) + mean
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latent representation."""
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
return (x - mean) / std
class AudioPatchifier:
"""
Audio patchifier for converting between audio latents and patches.
Combines channels and mel_bins dimensions for per-channel statistics.
"""
def __init__(
self,
patch_size: int = 1,
audio_latent_downsample_factor: int = 4,
sample_rate: int = 16000,
hop_length: int = 160,
is_causal: bool = True,
):
self.patch_size = patch_size
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.sample_rate = sample_rate
self.hop_length = hop_length
self.is_causal = is_causal
def patchify(self, x: mx.array) -> mx.array:
"""Convert audio latents to patches.
Input shape: (B, T, F, C) in MLX format (channels last)
Output shape: (B, T, C*F) - flattened for per-channel statistics
The output order is (c f) to match PyTorch's "b c t f -> b t (c f)".
"""
# x shape: (B, T, F, C) e.g., (1, 68, 16, 8)
b, t, f, c = x.shape
# Transpose to (B, T, C, F) for correct (c f) ordering
x = mx.transpose(x, (0, 1, 3, 2))
# Reshape to (B, T, C*F) e.g., (1, 68, 128)
return x.reshape(b, t, c * f)
def unpatchify(self, x: mx.array, latent_shape: AudioLatentShape) -> mx.array:
"""Convert patches back to audio latents.
Input shape: (B, T, C*F)
Output shape: (B, T, F, C) in MLX format
Reverses patchify's "b t (c f) -> b c t f" then transposes to MLX format.
"""
# x shape: (B, T, C*F) e.g., (1, 68, 128)
b, t, cf = x.shape
c = latent_shape.channels
f = latent_shape.mel_bins
# Reshape to (B, T, C, F)
x = x.reshape(b, t, c, f)
# Transpose to MLX format (B, T, F, C)
return mx.transpose(x, (0, 1, 3, 2))

View File

@@ -0,0 +1,185 @@
"""ResNet blocks for audio VAE and vocoder."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1
def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array:
"""Leaky ReLU activation."""
return mx.maximum(x, x * negative_slope)
class ResBlock1(nn.Module):
"""1D ResNet block for vocoder with dilated convolutions."""
def __init__(
self,
channels: int,
kernel_size: int = 3,
dilation: Tuple[int, int, int] = (1, 3, 5),
):
super().__init__()
# First set of convolutions with different dilations
self.convs1 = {
i: nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=(kernel_size - 1) * d // 2,
)
for i, d in enumerate(dilation)
}
# Second set of convolutions with dilation=1
self.convs2 = {
i: nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=1,
padding=(kernel_size - 1) // 2,
)
for i in range(len(dilation))
}
def __call__(self, x: mx.array) -> mx.array:
"""Forward pass through residual blocks."""
for i in range(len(self.convs1)):
xt = leaky_relu(x, LRELU_SLOPE)
xt = self.convs1[i](xt)
xt = leaky_relu(xt, LRELU_SLOPE)
xt = self.convs2[i](xt)
x = xt + x
return x
class ResBlock2(nn.Module):
"""1D ResNet block for vocoder (alternative version)."""
def __init__(
self,
channels: int,
kernel_size: int = 3,
dilation: Tuple[int, int] = (1, 3),
):
super().__init__()
self.convs = {
i: nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=(kernel_size - 1) * d // 2,
)
for i, d in enumerate(dilation)
}
def __call__(self, x: mx.array) -> mx.array:
"""Forward pass through residual blocks."""
for i in range(len(self.convs)):
xt = leaky_relu(x, LRELU_SLOPE)
xt = self.convs[i](xt)
x = xt + x
return x
class ResnetBlock(nn.Module):
"""2D ResNet block for audio VAE encoder/decoder."""
def __init__(
self,
*,
in_channels: int,
out_channels: int | None = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
norm_type: NormType = NormType.GROUP,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
) -> None:
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.temb_channels = temb_channels
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
self.conv1 = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
self.dropout_rate = dropout
self.conv2 = make_conv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
def __call__(
self,
x: mx.array,
temb: mx.array | None = None,
) -> mx.array:
"""
Forward pass through ResNet block.
Args:
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
temb: Optional time embedding tensor
Returns:
Output tensor
"""
h = x
h = self.norm1(h)
h = nn.silu(h)
h = self.conv1(h)
if temb is not None and self.temb_channels > 0:
# temb: (B, temb_channels) -> (B, out_channels)
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
h = self.norm2(h)
h = nn.silu(h)
if self.dropout_rate > 0:
h = nn.Dropout(self.dropout_rate)(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h

View File

@@ -0,0 +1,135 @@
"""Upsampling layers for audio VAE."""
from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
def nearest_neighbor_upsample(x: mx.array, scale_factor: int = 2) -> mx.array:
"""
Nearest neighbor upsampling for 4D tensors.
Args:
x: Input tensor of shape (N, H, W, C)
scale_factor: Upsampling factor
Returns:
Upsampled tensor of shape (N, H*scale_factor, W*scale_factor, C)
"""
n, h, w, c = x.shape
# Repeat along height and width
x = mx.repeat(x, scale_factor, axis=1)
x = mx.repeat(x, scale_factor, axis=2)
return x
class Upsample(nn.Module):
"""Upsampling layer with optional convolution."""
def __init__(
self,
in_channels: int,
with_conv: bool,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
) -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass with upsampling.
Args:
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
Returns:
Upsampled tensor
"""
# Nearest neighbor 2x upsampling
x = nearest_neighbor_upsample(x, scale_factor=2)
if self.with_conv:
x = self.conv(x)
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
# So the output elements rely on the following windows:
# 0: [-,-,0]
# 1: [-,0,0]
# 2: [0,0,1]
# 3: [0,1,1]
# 4: [1,1,2]
# 5: [1,2,2]
# Notice that the first and second elements in the output rely only on the first element in the input,
# while all other elements rely on two elements in the input.
# So we can drop the first element to undo the padding (rather than the last element).
# This is a no-op for non-causal convolutions.
if self.causality_axis == CausalityAxis.NONE:
pass # x remains unchanged
elif self.causality_axis == CausalityAxis.HEIGHT:
x = x[:, 1:, :, :]
elif self.causality_axis == CausalityAxis.WIDTH:
x = x[:, :, 1:, :]
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
pass # x remains unchanged
else:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
def build_upsampling_path(
*,
ch: int,
ch_mult: Tuple[int, ...],
num_resolutions: int,
num_res_blocks: int,
resolution: int,
temb_channels: int,
dropout: float,
norm_type: NormType,
causality_axis: CausalityAxis,
attn_type: AttentionType,
attn_resolutions: Set[int],
resamp_with_conv: bool,
initial_block_channels: int,
) -> tuple[dict, int]:
"""Build the upsampling path with residual blocks, attention, and upsampling layers."""
up_modules = {}
block_in = initial_block_channels
curr_res = resolution // (2 ** (num_resolutions - 1))
for level in reversed(range(num_resolutions)):
stage = {}
stage["block"] = {}
stage["attn"] = {}
block_out = ch * ch_mult[level]
for i_block in range(num_res_blocks + 1):
stage["block"][i_block] = ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=temb_channels,
dropout=dropout,
norm_type=norm_type,
causality_axis=causality_axis,
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
if level != 0:
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res *= 2
up_modules[level] = stage
return up_modules, block_in

View File

@@ -0,0 +1,142 @@
"""Vocoder for converting mel spectrograms to audio waveforms."""
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
class Vocoder(nn.Module):
"""
Vocoder model for synthesizing audio from Mel spectrograms.
Based on HiFi-GAN architecture.
Args:
resblock_kernel_sizes: List of kernel sizes for the residual blocks
upsample_rates: List of upsampling rates
upsample_kernel_sizes: List of kernel sizes for the upsampling layers
resblock_dilation_sizes: List of dilation sizes for the residual blocks
upsample_initial_channel: Initial number of channels for upsampling
stereo: Whether to use stereo output
resblock: Type of residual block to use ("1" or "2")
output_sample_rate: Waveform sample rate
"""
def __init__(
self,
resblock_kernel_sizes: List[int] | None = None,
upsample_rates: List[int] | None = None,
upsample_kernel_sizes: List[int] | None = None,
resblock_dilation_sizes: List[List[int]] | None = None,
upsample_initial_channel: int = 1024,
stereo: bool = True,
resblock: str = "1",
output_sample_rate: int = 24000,
):
super().__init__()
# Initialize default values if not provided
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
upsample_rates = [6, 5, 2, 2, 2]
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [16, 15, 8, 4, 4]
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sample_rate = output_sample_rate
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.upsample_rates = upsample_rates
self.upsample_kernel_sizes = upsample_kernel_sizes
self.upsample_initial_channel = upsample_initial_channel
in_channels = 128 if stereo else 64
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
# Upsampling layers using ConvTranspose1d
self.ups = {}
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
in_ch = upsample_initial_channel // (2**i)
out_ch = upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
# Residual blocks
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
block_idx += 1
out_channels = 2 if stereo else 1
final_channels = upsample_initial_channel // (2**self.num_upsamples)
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
self.upsample_factor = math.prod(upsample_rates)
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass of the vocoder.
Args:
x: Input Mel spectrogram tensor. Can be either:
- 3D: (batch_size, time, mel_bins) for mono - MLX format (N, L, C)
- 4D: (batch_size, 2, time, mel_bins) for stereo - PyTorch format (N, C, H, W)
Returns:
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
"""
# Input: (batch, channels, time, mel_bins) from audio decoder
# Transpose to (batch, channels, mel_bins, time)
x = mx.transpose(x, (0, 1, 3, 2))
if x.ndim == 4: # stereo
# x shape: (batch, 2, mel_bins, time)
# Rearrange to (batch, 2*mel_bins, time)
b, s, c, t = x.shape
x = x.reshape(b, s * c, t)
# MLX Conv1d expects (N, L, C), so transpose
# Current: (batch, channels, time) -> (batch, time, channels)
x = mx.transpose(x, (0, 2, 1))
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
start = i * self.num_kernels
end = start + self.num_kernels
# Apply residual blocks and average their outputs
block_outputs = []
for idx in range(start, end):
block_outputs.append(self.resblocks[idx](x))
# Stack and mean
x = mx.stack(block_outputs, axis=0)
x = mx.mean(x, axis=0)
# IMPORTANT: Use default leaky_relu slope (0.01), NOT LRELU_SLOPE (0.1)
# PyTorch uses F.leaky_relu(x) which defaults to 0.01
x = nn.leaky_relu(x) # Default negative_slope=0.01
x = self.conv_post(x)
x = mx.tanh(x)
# Transpose back to (batch, channels, time)
x = mx.transpose(x, (0, 2, 1))
return x