fix LTX-2.3 audio

This commit is contained in:
Prince Canuma
2026-03-15 02:06:35 +01:00
parent eb0d1355e4
commit 53bae534e7
4 changed files with 649 additions and 130 deletions

View File

@@ -1012,13 +1012,14 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType):
return decoder return decoder
def load_vocoder(model_path: Path, pipeline: PipelineType): def load_vocoder_model(model_path: Path, pipeline: PipelineType):
"""Load vocoder for mel to waveform conversion.""" """Load vocoder for mel to waveform conversion.
from mlx_video.models.ltx.audio_vae import Vocoder
vocoder = Vocoder.from_pretrained(model_path / "vocoder") Automatically detects HiFi-GAN (LTX-2) or BigVGAN+BWE (LTX-2.3).
"""
from mlx_video.models.ltx.audio_vae.vocoder import load_vocoder as _load_vocoder
return vocoder return _load_vocoder(model_path / "vocoder")
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE): def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
@@ -1795,7 +1796,7 @@ def generate_video(
if audio and audio_latents is not None: if audio and audio_latents is not None:
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"): with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path, pipeline) audio_decoder = load_audio_decoder(model_path, pipeline)
vocoder = load_vocoder(model_path, pipeline) vocoder = load_vocoder_model(model_path, pipeline)
mx.eval(audio_decoder.parameters(), vocoder.parameters()) mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents) mel_spectrogram = audio_decoder(audio_latents)
@@ -1809,12 +1810,15 @@ def generate_video(
if audio_np.ndim == 3: if audio_np.ndim == 3:
audio_np = audio_np[0] audio_np = audio_np[0]
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
del audio_decoder, vocoder del audio_decoder, vocoder
mx.clear_cache() mx.clear_cache()
console.print("[green]✓[/] Audio decoded") console.print("[green]✓[/] Audio decoded")
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) save_audio(audio_np, audio_path, vocoder_sample_rate)
console.print(f"[green]✅ Saved audio to[/] {audio_path}") console.print(f"[green]✅ Saved audio to[/] {audio_path}")
with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"): with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"):

View File

@@ -9,12 +9,13 @@ from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock
from .upsample import Upsample, build_upsampling_path from .upsample import Upsample, build_upsampling_path
from .vocoder import Vocoder from .vocoder import Vocoder, load_vocoder
__all__ = [ __all__ = [
# Main components # Main components
"AudioDecoder", "AudioDecoder",
"Vocoder", "Vocoder",
"load_vocoder",
"decode_audio", "decode_audio",
# Ops # Ops
"AudioLatentShape", "AudioLatentShape",

View File

@@ -1,179 +1,689 @@
"""Vocoder for converting mel spectrograms to audio waveforms.""" """Vocoder for converting mel spectrograms to audio waveforms.
Supports:
- HiFi-GAN (LTX-2): ResBlock1 with LeakyReLU
- BigVGAN v2 (LTX-2.3): AMPBlock1 with Snake/SnakeBeta + anti-aliased resampling
- VocoderWithBWE (LTX-2.3): Base vocoder + bandwidth extension (16kHz -> 48kHz)
"""
import math import math
from typing import Dict from typing import List, Tuple
from pathlib import Path from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import VocoderModelConfig from ..config import VocoderModelConfig
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
class Vocoder(nn.Module): def get_padding(kernel_size: int, dilation: int = 1) -> int:
""" return int((kernel_size * dilation - dilation) / 2)
Vocoder model for synthesizing audio from Mel spectrograms.
Based on HiFi-GAN architecture.
Args:
resblock_kernel_sizes: List of kernel sizes for the residual blocks # ---------------------------------------------------------------------------
upsample_rates: List of upsampling rates # Snake / SnakeBeta activations (BigVGAN v2)
upsample_kernel_sizes: List of kernel sizes for the upsampling layers # ---------------------------------------------------------------------------
resblock_dilation_sizes: List of dilation sizes for the residual blocks
upsample_initial_channel: Initial number of channels for upsampling
stereo: Whether to use stereo output class Snake(nn.Module):
resblock: Type of residual block to use ("1" or "2") """Snake activation: x + (1/alpha) * sin^2(alpha * x)."""
output_sample_rate: Waveform sample rate
""" def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
def __call__(self, x: mx.array) -> mx.array:
# x: (N, L, C) in MLX format
alpha = self.alpha # (C,)
if self.alpha_logscale:
alpha = mx.exp(alpha)
return x + (1.0 / (alpha + 1e-9)) * mx.power(mx.sin(x * alpha), 2)
class SnakeBeta(nn.Module):
    """SnakeBeta activation: ``x + (1 / beta) * sin^2(alpha * x)``.

    ``alpha`` scales the argument of the sine and ``beta`` scales the
    periodic term. Both are arrays of shape ``(in_features,)``. When
    ``alpha_logscale`` is true they are stored as logarithms (initialised
    to zero) and exponentiated on every call; otherwise they are stored
    directly (initialised to one).
    """

    def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        # exp(0) == 1, so either storage mode starts at alpha = beta = 1.
        initial = mx.zeros if alpha_logscale else mx.ones
        self.alpha = initial((in_features,))
        self.beta = initial((in_features,))

    def __call__(self, x: mx.array) -> mx.array:
        frequency, magnitude = self.alpha, self.beta
        if self.alpha_logscale:
            frequency = mx.exp(frequency)
            magnitude = mx.exp(magnitude)
        periodic = mx.power(mx.sin(x * frequency), 2)
        # The small epsilon keeps the reciprocal finite.
        return x + (1.0 / (magnitude + 1e-9)) * periodic
# ---------------------------------------------------------------------------
# Anti-aliased resampling (Kaiser-sinc filters)
# ---------------------------------------------------------------------------
def _sinc(x: mx.array) -> mx.array:
    """Return the normalised sinc ``sin(pi * x) / (pi * x)``, with value 1 at ``x == 0``."""
    pi_x = mx.array(math.pi) * x
    quotient = mx.sin(pi_x) / pi_x
    # The quotient is 0/0 where x == 0; mx.where substitutes 1 at those positions.
    return mx.where(x == 0, mx.ones_like(x), quotient)
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
    """Build a Kaiser-windowed sinc kernel.

    Args:
        cutoff: Cutoff value scaling the sinc argument; ``0`` yields an
            all-zero kernel.
        half_width: Transition half-width, used to derive the Kaiser ``beta``.
        kernel_size: Number of taps.

    Returns:
        Array of shape ``(1, 1, kernel_size)``. For a non-zero ``cutoff`` the
        taps are divided by their sum.
    """
    import numpy as np

    half_size = kernel_size // 2

    # Derive the Kaiser beta from the kernel size and transition width.
    delta_f = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0

    window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))

    # Tap positions centred on zero; even kernels sit on half-sample offsets.
    if kernel_size % 2 == 0:
        positions = mx.arange(-half_size, half_size).astype(mx.float32) + 0.5
    else:
        positions = mx.arange(kernel_size).astype(mx.float32) - half_size

    if cutoff == 0:
        return mx.zeros_like(positions).reshape(1, 1, kernel_size)

    taps = 2 * cutoff * window * _sinc(2 * cutoff * positions)
    taps = taps / mx.sum(taps)
    return taps.reshape(1, 1, kernel_size)
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
    """Build a Hann-windowed sinc kernel for upsampling by ``ratio``.

    Used by the BWE resampler.

    Returns:
        ``(filter, pad, pad_left, pad_right)`` where ``filter`` has shape
        ``(1, 1, kernel_size)`` with ``kernel_size = 2 * width * ratio + 1``,
        ``pad`` is the filter half-width in input samples, and
        ``pad_left`` / ``pad_right`` are the trim amounts the caller applies
        after the transposed convolution.
    """
    import numpy as np

    rolloff = 0.99
    lowpass_filter_width = 6
    width = math.ceil(lowpass_filter_width / rolloff)
    kernel_size = 2 * width * ratio + 1

    tap_time = (np.arange(kernel_size) / ratio - width) * rolloff
    # Clamp before windowing so taps outside the window width evaluate to cos^2(pi/2).
    clamped = np.clip(tap_time, -lowpass_filter_width, lowpass_filter_width)
    window = np.cos(clamped * math.pi / lowpass_filter_width / 2) ** 2
    taps = np.sinc(tap_time) * window * rolloff / ratio

    kernel = mx.array(taps.astype(np.float32)).reshape(1, 1, kernel_size)
    return kernel, width, 2 * width * ratio, kernel_size - ratio
class LowPassFilter1d(nn.Module):
"""Low-pass filter using depthwise convolution with Kaiser-sinc kernel."""
def __init__( def __init__(
self, self,
config: VocoderModelConfig cutoff: float = 0.5,
): half_width: float = 0.6,
stride: int = 1,
kernel_size: int = 12,
) -> None:
super().__init__()
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
# Filter buffer - shape (1, 1, K) in PyTorch format, loaded from weights
self.filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
def __call__(self, x: mx.array) -> mx.array:
    """Filter ``x`` along axis 1, applying the same kernel to every channel.

    Args:
        x: Array of shape ``(N, L, C)``.

    Returns:
        Array of shape ``(N, L', C)``, where ``L'`` is the length produced
        by a ``self.stride``-strided convolution over the padded input.
    """
    n, _, c = x.shape

    # Replicate-pad: repeat the first/last sample along the length axis.
    first = mx.repeat(x[:, :1, :], self.pad_left, axis=1)
    last = mx.repeat(x[:, -1:, :], self.pad_right, axis=1)
    x = mx.concatenate([first, x, last], axis=1)

    # Stored as (1, 1, K); mx.conv1d wants (out, K, in) == (1, K, 1).
    # A single copy is enough because channels are folded into the batch
    # axis below, so no per-channel replication of the kernel is needed.
    filt = mx.transpose(self.filter.astype(x.dtype), (0, 2, 1))

    # (N, L, C) -> (N*C, L, 1): every channel becomes its own batch item.
    x = mx.transpose(x, (0, 2, 1)).reshape(n * c, -1, 1)
    x = mx.conv1d(x, filt, stride=self.stride, groups=1)  # (N*C, L', 1)

    # Unfold back to (N, C, L') and restore channels-last.
    x = x.reshape(n, c, -1)
    return mx.transpose(x, (0, 2, 1))
class UpSample1d(nn.Module):
    """Anti-aliased upsampling using a transposed convolution with a sinc filter.

    Args:
        ratio: Integer upsampling factor; also the transposed-conv stride.
        kernel_size: Number of filter taps for the Kaiser filter. ``None``
            selects ``int(6 * ratio // 2) * 2``. Ignored when
            ``window_type == "hann"``, where the size comes from
            ``hann_sinc_filter1d``.
        window_type: ``"hann"`` for the Hann-windowed sinc filter; any other
            value selects the Kaiser-windowed sinc filter.
    """

    def __init__(
        self,
        ratio: int = 2,
        kernel_size: int = None,
        window_type: str = "kaiser",
    ) -> None:
        super().__init__()
        self.ratio = ratio
        self.stride = ratio

        if window_type == "hann":
            filt, self.pad, self.pad_left, self.pad_right = hann_sinc_filter1d(ratio)
            self.kernel_size = filt.shape[2]
            self.filter = filt
        else:
            self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
            self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
            self.filter = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio,
                half_width=0.6 / ratio,
                kernel_size=self.kernel_size,
            )

    def __call__(self, x: mx.array) -> mx.array:
        """Upsample ``x`` of shape ``(N, L, C)`` along axis 1; returns ``(N, L', C)``."""
        n, _, c = x.shape

        # Replicate-pad both ends with ``self.pad`` copies of the edge sample.
        first = mx.repeat(x[:, :1, :], self.pad, axis=1)
        last = mx.repeat(x[:, -1:, :], self.pad, axis=1)
        x = mx.concatenate([first, x, last], axis=1)

        # Fold channels into the batch axis: (N, L, C) -> (N*C, L, 1).
        x = mx.transpose(x, (0, 2, 1)).reshape(n * c, -1, 1)

        # Stored as (1, 1, K); conv_transpose1d is given it as (1, K, 1).
        filt = mx.transpose(self.filter.astype(x.dtype), (0, 2, 1))
        x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride)  # (N*C, L', 1)

        # Trim the padding. The stop index is written explicitly because
        # ``-self.pad_right`` would be slice stop 0 (an empty result) when
        # pad_right == 0, e.g. for ratio=1, kernel_size=1.
        x = x[:, self.pad_left : x.shape[1] - self.pad_right, :]

        x = x.reshape(n, c, -1)  # (N, C, L')
        return mx.transpose(x, (0, 2, 1))  # (N, L', C)
class DownSample1d(nn.Module):
    """Anti-aliased downsampling: a strided ``LowPassFilter1d``.

    Args:
        ratio: Downsampling factor, used as the filter stride.
        kernel_size: Number of filter taps; ``None`` selects
            ``int(6 * ratio // 2) * 2``.
    """

    def __init__(self, ratio: int = 2, kernel_size: int = None) -> None:
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2
        # Cutoff and transition width are both scaled by 1 / ratio.
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=kernel_size,
        )

    def __call__(self, x: mx.array) -> mx.array:
        filtered = self.lowpass(x)
        return filtered
class Activation1d(nn.Module):
    """Anti-aliased activation wrapper: upsample, activate, then downsample.

    Args:
        activation: Module applied between the two resampling stages.
        up_ratio: Upsampling factor applied before the activation.
        down_ratio: Downsampling factor applied after the activation.
        up_kernel_size: Tap count of the upsampling filter.
        down_kernel_size: Tap count of the downsampling filter.
    """

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def __call__(self, x: mx.array) -> mx.array:
        return self.downsample(self.act(self.upsample(x)))
# ---------------------------------------------------------------------------
# AMPBlock1 (BigVGAN v2 residual block)
# ---------------------------------------------------------------------------
class AMPBlock1(nn.Module):
    """BigVGAN v2 residual block with anti-aliased Snake activations.

    For every entry of ``dilation`` the block holds a dilated conv, an
    undilated conv and two ``Activation1d`` wrappers. The layers are kept in
    dicts keyed by layer index.

    Args:
        channels: Input and output channel count of every conv.
        kernel_size: Kernel size of every conv.
        dilation: Dilation of the first conv in each layer.
        activation: ``"snakebeta"`` selects ``SnakeBeta``; anything else
            selects ``Snake``.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        activation: str = "snakebeta",
    ) -> None:
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake
        n_layers = len(dilation)

        # All dilated convs are created before the undilated ones.
        dilated = {}
        for idx, rate in enumerate(dilation):
            dilated[idx] = nn.Conv1d(
                channels, channels, kernel_size, stride=1,
                dilation=rate, padding=get_padding(kernel_size, rate),
            )
        self.convs1 = dilated

        plain = {}
        for idx in range(n_layers):
            plain[idx] = nn.Conv1d(
                channels, channels, kernel_size, stride=1,
                dilation=1, padding=get_padding(kernel_size, 1),
            )
        self.convs2 = plain

        self.acts1 = {idx: Activation1d(act_cls(channels)) for idx in range(n_layers)}
        self.acts2 = {idx: Activation1d(act_cls(channels)) for idx in range(n_layers)}

    def __call__(self, x: mx.array) -> mx.array:
        for idx in range(len(self.convs1)):
            branch = self.convs1[idx](self.acts1[idx](x))
            branch = self.convs2[idx](self.acts2[idx](branch))
            # Residual connection around each activation/conv pair.
            x = x + branch
        return x
# ---------------------------------------------------------------------------
# STFT and MelSTFT (for BWE)
# ---------------------------------------------------------------------------
class STFTFn(nn.Module):
    """STFT computed as a strided conv1d against precomputed bases.

    ``forward_basis`` and ``inverse_basis`` are zero-initialised placeholders
    of shape ``(2 * n_freqs, 1, filter_length)``; the real values are
    expected to come from the checkpoint. ``inverse_basis`` is not used by
    ``__call__``.

    Args:
        filter_length: Basis length; ``n_freqs = filter_length // 2 + 1``.
        hop_length: Stride between successive frames.
        win_length: Window length, used only to size the left padding.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        n_freqs = filter_length // 2 + 1
        # Buffers loaded from checkpoint - PyTorch layout (n_freqs*2, 1, filter_length)
        self.forward_basis = mx.zeros((n_freqs * 2, 1, filter_length))
        self.inverse_basis = mx.zeros((n_freqs * 2, 1, filter_length))

    def __call__(self, y: mx.array) -> Tuple[mx.array, mx.array]:
        """Compute magnitude and phase from a waveform.

        Args:
            y: Waveform of shape ``(B, T)``; a trailing singleton channel
                axis is added when the input is 2-D.

        Returns:
            ``(magnitude, phase)``, each of shape ``(B, T_frames, n_freqs)``
            (channels-last).
        """
        if y.ndim == 2:
            y = mx.expand_dims(y, -1)  # (B, T, 1)

        # Pad on the left only, repeating the first sample.
        left_pad = max(0, self.win_length - self.hop_length)
        if left_pad > 0:
            first = mx.repeat(y[:, :1, :], left_pad, axis=1)
            y = mx.concatenate([first, y], axis=1)

        # (2*n_freqs, 1, K) -> (2*n_freqs, K, 1) for mx.conv1d.
        basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1))
        # (B, T, 1) convolved with (2*n_freqs, K, 1) -> (B, T_frames, 2*n_freqs)
        spec = mx.conv1d(y, basis, stride=self.hop_length)

        # The first half of the output channels is treated as the real part,
        # the second half as the imaginary part.
        n_freqs = spec.shape[-1] // 2
        real = spec[..., :n_freqs]
        imag = spec[..., n_freqs:]

        magnitude = mx.sqrt(real ** 2 + imag ** 2)
        # arctan2 is evaluated in float32 and cast back to the input dtype.
        phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
        return magnitude, phase
class MelSTFT(nn.Module):
    """Log-mel spectrogram built on ``STFTFn`` and a mel projection matrix.

    ``mel_basis`` is a zero-initialised placeholder of shape
    ``(n_mel_channels, filter_length // 2 + 1)``.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
        super().__init__()
        self.stft_fn = STFTFn(filter_length, hop_length, win_length)
        self.mel_basis = mx.zeros((n_mel_channels, filter_length // 2 + 1))

    def mel_spectrogram(self, y: mx.array) -> mx.array:
        """Compute the log-mel spectrogram of a waveform.

        Args:
            y: Waveform of shape ``(B, T)``.

        Returns:
            Array of shape ``(B, n_mels, T_frames)`` (channels-first).
        """
        # Only the magnitude is needed; the phase is discarded.
        magnitude, _ = self.stft_fn(y)  # (B, T_frames, n_freqs)
        projection = self.mel_basis.astype(magnitude.dtype).T
        mel = magnitude @ projection  # (B, T_frames, n_mels)
        # Floor at 1e-5 before taking the log.
        log_mel = mx.log(mx.clip(mel, 1e-5, None))
        return mx.transpose(log_mel, (0, 2, 1))
# ---------------------------------------------------------------------------
# Vocoder (supports both HiFi-GAN and BigVGAN v2)
# ---------------------------------------------------------------------------
class Vocoder(nn.Module):
"""Vocoder for mel-to-waveform synthesis.
Supports resblock="1" (HiFi-GAN / LTX-2) and resblock="AMP1" (BigVGAN v2 / LTX-2.3).
"""
def __init__(self, config: VocoderModelConfig) -> None:
super().__init__() super().__init__()
self.output_sampling_rate = config.output_sample_rate
self.output_sample_rate = config.output_sample_rate
self.num_kernels = len(config.resblock_kernel_sizes) self.num_kernels = len(config.resblock_kernel_sizes)
self.num_upsamples = len(config.upsample_rates) self.num_upsamples = len(config.upsample_rates)
self.upsample_rates = config.upsample_rates self.upsample_rates = config.upsample_rates
self.upsample_kernel_sizes = config.upsample_kernel_sizes self.is_amp = config.resblock == "AMP1"
self.upsample_initial_channel = config.upsample_initial_channel self.use_tanh_at_final = config.use_tanh_at_final
self.apply_final_activation = config.apply_final_activation
in_channels = 128 if config.stereo else 64 in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3) self.conv_pre = nn.Conv1d(
in_channels, config.upsample_initial_channel,
kernel_size=7, stride=1, padding=3,
)
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2 # Upsampling layers
# Upsampling layers using ConvTranspose1d
self.ups = {} self.ups = {}
for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): for i, (stride, kernel_size) in enumerate(
in_ch = config.upsample_initial_channel // (2**i) zip(config.upsample_rates, config.upsample_kernel_sizes)
):
in_ch = config.upsample_initial_channel // (2 ** i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1)) out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d( self.ups[i] = nn.ConvTranspose1d(
in_ch, in_ch, out_ch,
out_ch, kernel_size=kernel_size, stride=stride,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2, padding=(kernel_size - stride) // 2,
) )
# Residual blocks # Residual blocks
if self.is_amp:
self.resblocks = {} self.resblocks = {}
block_idx = 0 block_idx = 0
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = config.upsample_initial_channel // (2 ** (i + 1)) ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): for kernel_size, dilations in zip(
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = AMPBlock1(
ch, kernel_size, tuple(dilations),
activation=config.activation,
)
block_idx += 1
else:
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
block_idx += 1 block_idx += 1
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
# Post-activation
if self.is_amp:
act_cls = SnakeBeta if config.activation == "snakebeta" else Snake
self.act_post = Activation1d(act_cls(final_channels))
# Final conv
out_channels = 2 if config.stereo else 1 out_channels = 2 if config.stereo else 1
final_channels = config.upsample_initial_channel // (2**self.num_upsamples) self.conv_post = nn.Conv1d(
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3) final_channels, out_channels,
kernel_size=7, stride=1, padding=3,
bias=config.use_bias_at_final,
)
self.upsample_factor = math.prod(config.upsample_rates) self.upsample_factor = math.prod(config.upsample_rates)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
if "vocoder." not in weights:
return weights
for key, value in weights.items():
new_key = key
# Handle vocoder weights
if key.startswith("vocoder."):
new_key = key.replace("vocoder.", "")
# Handle ModuleList indices -> dict keys
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
# Handle Conv1d weight shape conversion
# PyTorch: (out_channels, in_channels, kernel)
# MLX: (out_channels, kernel, in_channels)
if "weight" in new_key and value.ndim == 3:
if "ups" in new_key:
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0))
else:
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
"""Load vocoder from pretrained model."""
from mlx_video.models.ltx.config import VocoderModelConfig
import json
config_dict = {}
with open(model_path / "config.json", "r") as f:
config_dict = json.load(f)
config = VocoderModelConfig.from_dict(config_dict)
model = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
# Use strict=False to skip extra keys (e.g., bwe_generator in LTX-2.3)
model.load_weights(list(weights.items()), strict=False)
return model
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
""" """Forward pass.
Forward pass of the vocoder.
Args: Args:
x: Input Mel spectrogram tensor. Can be either: x: Mel spectrogram (B, C, T, mel_bins) for stereo or (B, T, mel_bins) mono.
- 3D: (batch_size, time, mel_bins) for mono - MLX format (N, L, C)
- 4D: (batch_size, 2, time, mel_bins) for stereo - PyTorch format (N, C, H, W)
Returns: Returns:
Audio waveform tensor of shape (batch_size, out_channels, audio_length) Waveform (B, out_channels, T_audio) in channels-first format.
""" """
# Input: (batch, channels, time, mel_bins) from audio decoder # (B, C, T, mel) -> (B, C, mel, T)
# Transpose to (batch, channels, mel_bins, time)
x = mx.transpose(x, (0, 1, 3, 2)) x = mx.transpose(x, (0, 1, 3, 2))
if x.ndim == 4: # stereo if x.ndim == 4: # stereo: (B, 2, mel, T) -> (B, 2*mel, T)
# x shape: (batch, 2, mel_bins, time)
# Rearrange to (batch, 2*mel_bins, time)
b, s, c, t = x.shape b, s, c, t = x.shape
x = x.reshape(b, s * c, t) x = x.reshape(b, s * c, t)
# MLX Conv1d expects (N, L, C), so transpose # Channels-first (B, C, T) -> channels-last (B, T, C) for MLX conv
# Current: (batch, channels, time) -> (batch, time, channels)
x = mx.transpose(x, (0, 2, 1)) x = mx.transpose(x, (0, 2, 1))
x = self.conv_pre(x) x = self.conv_pre(x)
for i in range(self.num_upsamples): for i in range(self.num_upsamples):
if not self.is_amp:
x = leaky_relu(x, LRELU_SLOPE) x = leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x) x = self.ups[i](x)
start = i * self.num_kernels start = i * self.num_kernels
end = start + self.num_kernels end = start + self.num_kernels
# Apply residual blocks and average their outputs block_outputs = mx.stack(
block_outputs = [] [self.resblocks[idx](x) for idx in range(start, end)],
for idx in range(start, end): axis=0,
block_outputs.append(self.resblocks[idx](x)) )
x = mx.mean(block_outputs, axis=0)
# Stack and mean if self.is_amp:
x = mx.stack(block_outputs, axis=0) x = self.act_post(x)
x = mx.mean(x, axis=0) else:
x = nn.leaky_relu(x)
# IMPORTANT: Use default leaky_relu slope (0.01), NOT LRELU_SLOPE (0.1)
# PyTorch uses F.leaky_relu(x) which defaults to 0.01
x = nn.leaky_relu(x) # Default negative_slope=0.01
x = self.conv_post(x) x = self.conv_post(x)
x = mx.tanh(x)
# Transpose back to (batch, channels, time) if self.apply_final_activation:
x = mx.tanh(x) if self.use_tanh_at_final else mx.clip(x, -1, 1)
# Back to channels-first (B, T, C) -> (B, C, T)
x = mx.transpose(x, (0, 2, 1)) x = mx.transpose(x, (0, 2, 1))
return x return x
# ---------------------------------------------------------------------------
# VocoderWithBWE (Bandwidth Extension)
# ---------------------------------------------------------------------------
class VocoderWithBWE(nn.Module):
    """Base vocoder followed by bandwidth extension (16kHz -> 48kHz by default).

    The base vocoder output is turned back into a log-mel spectrogram and fed
    to a second generator. That generator's output is added to a
    sinc-upsampled copy of the base output, and the sum is clipped to
    ``[-1, 1]``.

    Args:
        vocoder: Base mel-to-waveform model.
        bwe_generator: Generator producing the residual that is added to the
            upsampled skip signal.
        mel_stft: Module computing the log-mel spectrogram of the base output.
        input_sampling_rate: Sampling rate of the base vocoder output.
        output_sampling_rate: Sampling rate of the final waveform.
        hop_length: The base output is zero-padded to a multiple of this
            before the mel spectrogram is computed.
    """

    def __init__(
        self,
        vocoder: Vocoder,
        bwe_generator: Vocoder,
        mel_stft: MelSTFT,
        input_sampling_rate: int = 16000,
        output_sampling_rate: int = 48000,
        hop_length: int = 80,
    ) -> None:
        super().__init__()
        self.vocoder = vocoder
        self.bwe_generator = bwe_generator
        self.mel_stft = mel_stft
        self.input_sampling_rate = input_sampling_rate
        self.output_sampling_rate = output_sampling_rate
        self.hop_length = hop_length
        # Hann-windowed sinc resampler (not stored in checkpoint)
        upsample_ratio = output_sampling_rate // input_sampling_rate
        self.resampler = UpSample1d(ratio=upsample_ratio, window_type="hann")

    @property
    def output_sample_rate(self) -> int:
        """Alias of ``output_sampling_rate``."""
        return self.output_sampling_rate

    def _compute_mel(self, audio: mx.array) -> mx.array:
        """Compute a log-mel spectrogram per channel.

        Args:
            audio: Waveform of shape ``(B, C, T)``.

        Returns:
            Array of shape ``(B, C, n_mels, T_frames)``.
        """
        batch, n_channels, _ = audio.shape
        # Channels are folded into the batch axis for the STFT.
        per_channel = audio.reshape(batch * n_channels, -1)  # (B*C, T)
        mel = self.mel_stft.mel_spectrogram(per_channel)  # (B*C, n_mels, T_frames)
        n_mels, n_frames = mel.shape[1], mel.shape[2]
        return mel.reshape(batch, n_channels, n_mels, n_frames)

    def __call__(self, mel_spec: mx.array) -> mx.array:
        """Run the base vocoder followed by bandwidth extension.

        Args:
            mel_spec: Mel spectrogram in the format the base vocoder accepts.

        Returns:
            Waveform of shape ``(B, out_channels, T_audio)`` at
            ``output_sampling_rate``.
        """
        x = self.vocoder(mel_spec)  # (B, C, T) at input_sampling_rate
        _, _, length_low_rate = x.shape
        # The final length is derived from the unpadded low-rate length.
        output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate

        # Zero-pad the time axis up to the next multiple of hop_length.
        remainder = length_low_rate % self.hop_length
        if remainder != 0:
            x = mx.pad(x, [(0, 0), (0, 0), (0, self.hop_length - remainder)])

        # (B, C, n_mels, T_frames) -> (B, C, T_frames, n_mels) for the generator.
        mel_for_bwe = mx.transpose(self._compute_mel(x), (0, 1, 3, 2))
        residual = self.bwe_generator(mel_for_bwe)  # (B, C, T_high)

        # Skip path: the resampler works on (N, L, C), so transpose in and out.
        skip = self.resampler(mx.transpose(x, (0, 2, 1)))
        skip = mx.transpose(skip, (0, 2, 1))  # back to (B, C, T)

        combined = mx.clip(residual + skip, -1, 1)
        return combined[..., :output_length]
# ---------------------------------------------------------------------------
# Factory / from_pretrained
# ---------------------------------------------------------------------------
def load_vocoder(model_path: Path) -> nn.Module:
    """Load a vocoder from a pretrained model directory.

    Reads ``config.json`` and ``model.safetensors`` from ``model_path``. If
    the config sets ``has_bwe_generator`` a ``VocoderWithBWE`` is built;
    otherwise a plain ``Vocoder`` is built and its weights are loaded
    strictly.

    Raises:
        FileNotFoundError: If ``config.json`` is missing from ``model_path``.
    """
    import json

    config_path = model_path / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"No config.json found in {model_path}")

    with open(config_path) as handle:
        config_dict = json.load(handle)

    weights = mx.load(str(model_path / "model.safetensors"))

    if config_dict.get("has_bwe_generator", False):
        return _load_vocoder_with_bwe(config_dict, weights)

    model = Vocoder(VocoderModelConfig.from_dict(config_dict))
    model.load_weights(list(weights.items()), strict=True)
    return model
def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
    """Assemble a ``VocoderWithBWE`` from a config dict and a weight dict.

    The base vocoder is configured from ``config_dict["vocoder"]`` and the
    BWE generator from ``config_dict["bwe"]``, with its final activation
    disabled. STFT and mel sizes are read from the weight shapes, falling
    back to 512 and 64 when the corresponding keys are absent. Weights are
    loaded with ``strict=False``.
    """
    # Base vocoder.
    vocoder = Vocoder(VocoderModelConfig.from_dict(config_dict.get("vocoder", {})))

    # BWE generator: same architecture class, no final activation.
    bwe_cfg = config_dict.get("bwe", {})
    bwe_config = VocoderModelConfig.from_dict(bwe_cfg)
    bwe_config.apply_final_activation = False
    bwe_generator = Vocoder(bwe_config)

    # Infer the MelSTFT dimensions from the stored buffers.
    stft_basis = weights.get("mel_stft.stft_fn.forward_basis")
    filter_length = 512 if stft_basis is None else stft_basis.shape[2]
    mel_basis = weights.get("mel_stft.mel_basis")
    n_mel_channels = 64 if mel_basis is None else mel_basis.shape[0]

    hop_length = bwe_cfg.get("hop_length", 80)
    input_sr = bwe_cfg.get("input_sampling_rate", 16000)
    output_sr = bwe_cfg.get("output_sampling_rate", 48000)

    mel_stft = MelSTFT(
        filter_length=filter_length,
        hop_length=hop_length,
        win_length=filter_length,
        n_mel_channels=n_mel_channels,
    )

    model = VocoderWithBWE(
        vocoder=vocoder,
        bwe_generator=bwe_generator,
        mel_stft=mel_stft,
        input_sampling_rate=input_sr,
        output_sampling_rate=output_sr,
        hop_length=hop_length,
    )
    model.load_weights(list(weights.items()), strict=False)
    return model

View File

@@ -260,6 +260,10 @@ class VocoderModelConfig(BaseModelConfig):
stereo: bool = True stereo: bool = True
resblock: str = "1" resblock: str = "1"
output_sample_rate: int = 24000 output_sample_rate: int = 24000
activation: str = "snake"
use_tanh_at_final: bool = True
apply_final_activation: bool = True
use_bias_at_final: bool = True
def __post_init__(self): def __post_init__(self):