This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -7,8 +7,8 @@ Supports:
"""
import math
from typing import List, Tuple
from pathlib import Path
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -32,7 +32,9 @@ class Snake(nn.Module):
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
self.alpha = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, L, C) in MLX format
@@ -48,8 +50,12 @@ class SnakeBeta(nn.Module):
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
self.alpha = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
self.beta = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array:
alpha = self.alpha
@@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array:
)
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
def kaiser_sinc_filter1d(
cutoff: float, half_width: float, kernel_size: int
) -> mx.array:
"""Compute a Kaiser-windowed sinc filter."""
even = kernel_size % 2 == 0
half_size = kernel_size // 2
@@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
# Kaiser window - compute using scipy-compatible formula
import numpy as np
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
if even:
@@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
import numpy as np
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
@@ -187,10 +197,16 @@ class UpSample1d(nn.Module):
self.kernel_size = filt.shape[2]
self.filter = filt
else:
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
self.pad_left = (
self.pad * self.stride + (self.kernel_size - self.stride) // 2
)
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
self.filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
@@ -215,10 +231,12 @@ class UpSample1d(nn.Module):
filt = self.filter.astype(x.dtype) # (1, 1, K)
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
x = self.ratio * mx.conv_transpose1d(
x, filt, stride=self.stride
) # (N*C, L', 1)
# Trim padding
x = x[:, self.pad_left:-self.pad_right, :]
x = x[:, self.pad_left : -self.pad_right, :]
x = x.reshape(n, c, -1) # (N, C, L')
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
@@ -285,16 +303,24 @@ class AMPBlock1(nn.Module):
self.convs1 = {
i: nn.Conv1d(
channels, channels, kernel_size, stride=1,
dilation=d, padding=get_padding(kernel_size, d),
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=get_padding(kernel_size, d),
)
for i, d in enumerate(dilation)
}
self.convs2 = {
i: nn.Conv1d(
channels, channels, kernel_size, stride=1,
dilation=1, padding=get_padding(kernel_size, 1),
channels,
channels,
kernel_size,
stride=1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
for i in range(len(dilation))
}
@@ -348,7 +374,9 @@ class STFTFn(nn.Module):
y = mx.concatenate([first, y], axis=1)
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
basis = mx.transpose(
self.forward_basis.astype(y.dtype), (0, 2, 1)
) # (514, K, 1)
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
spec = mx.conv1d(y, basis, stride=self.hop_length)
@@ -358,8 +386,10 @@ class STFTFn(nn.Module):
real = spec[..., :n_freqs]
imag = spec[..., n_freqs:]
magnitude = mx.sqrt(real ** 2 + imag ** 2)
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
magnitude = mx.sqrt(real**2 + imag**2)
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
real.dtype
)
# Output: (B, T_frames, n_freqs) in MLX channels-last
return magnitude, phase
@@ -368,7 +398,9 @@ class STFTFn(nn.Module):
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram from precomputed STFT bases."""
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
def __init__(
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
) -> None:
super().__init__()
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1
@@ -385,7 +417,9 @@ class MelSTFT(nn.Module):
"""
magnitude, phase = self.stft_fn(y)
# magnitude: (B, T_frames, n_freqs)
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
mel = (
magnitude @ self.mel_basis.astype(magnitude.dtype).T
) # (B, T_frames, n_mels)
log_mel = mx.log(mx.clip(mel, 1e-5, None))
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
return mx.transpose(log_mel, (0, 2, 1))
@@ -415,8 +449,11 @@ class Vocoder(nn.Module):
in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(
in_channels, config.upsample_initial_channel,
kernel_size=7, stride=1, padding=3,
in_channels,
config.upsample_initial_channel,
kernel_size=7,
stride=1,
padding=3,
)
# Upsampling layers
@@ -424,11 +461,13 @@ class Vocoder(nn.Module):
for i, (stride, kernel_size) in enumerate(
zip(config.upsample_rates, config.upsample_kernel_sizes)
):
in_ch = config.upsample_initial_channel // (2 ** i)
in_ch = config.upsample_initial_channel // (2**i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch, out_ch,
kernel_size=kernel_size, stride=stride,
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
@@ -442,7 +481,9 @@ class Vocoder(nn.Module):
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = AMPBlock1(
ch, kernel_size, tuple(dilations),
ch,
kernel_size,
tuple(dilations),
activation=config.activation,
)
block_idx += 1
@@ -455,10 +496,14 @@ class Vocoder(nn.Module):
for kernel_size, dilations in zip(
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
self.resblocks[block_idx] = resblock_class(
ch, kernel_size, tuple(dilations)
)
block_idx += 1
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
final_channels = config.upsample_initial_channel // (
2 ** len(config.upsample_rates)
)
# Post-activation
if self.is_amp:
@@ -468,8 +513,11 @@ class Vocoder(nn.Module):
# Final conv
out_channels = 2 if config.stereo else 1
self.conv_post = nn.Conv1d(
final_channels, out_channels,
kernel_size=7, stride=1, padding=3,
final_channels,
out_channels,
kernel_size=7,
stride=1,
padding=3,
bias=config.use_bias_at_final,
)
@@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module):
"""
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
_, _, length_low_rate = x.shape
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
output_length = (
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
)
# Pad to hop_length multiple
remainder = length_low_rate % self.hop_length
@@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
model.load_weights(list(weights.items()), strict=False)
return model