From 53bae534e72c34a9732b6b8138f65a0c0c857361 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 15 Mar 2026 02:06:35 +0100 Subject: [PATCH] fix LTX-2.3 audio --- mlx_video/generate.py | 18 +- mlx_video/models/ltx/audio_vae/__init__.py | 3 +- mlx_video/models/ltx/audio_vae/vocoder.py | 752 +++++++++++++++++---- mlx_video/models/ltx/config.py | 6 +- 4 files changed, 649 insertions(+), 130 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 5a7e2fe..fe3cbe9 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1012,13 +1012,14 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType): return decoder -def load_vocoder(model_path: Path, pipeline: PipelineType): - """Load vocoder for mel to waveform conversion.""" - from mlx_video.models.ltx.audio_vae import Vocoder +def load_vocoder_model(model_path: Path, pipeline: PipelineType): + """Load vocoder for mel to waveform conversion. - vocoder = Vocoder.from_pretrained(model_path / "vocoder") + Automatically detects HiFi-GAN (LTX-2) or BigVGAN+BWE (LTX-2.3). + """ + from mlx_video.models.ltx.audio_vae.vocoder import load_vocoder as _load_vocoder - return vocoder + return _load_vocoder(model_path / "vocoder") def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE): @@ -1795,7 +1796,7 @@ def generate_video( if audio and audio_latents is not None: with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"): audio_decoder = load_audio_decoder(model_path, pipeline) - vocoder = load_vocoder(model_path, pipeline) + vocoder = load_vocoder_model(model_path, pipeline) mx.eval(audio_decoder.parameters(), vocoder.parameters()) mel_spectrogram = audio_decoder(audio_latents) @@ -1809,12 +1810,15 @@ def generate_video( if audio_np.ndim == 3: audio_np = audio_np[0] + # Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE) + vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE) + del audio_decoder, vocoder mx.clear_cache() console.print("[green]✓[/] Audio decoded") audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') - save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) + save_audio(audio_np, audio_path, vocoder_sample_rate) console.print(f"[green]✅ Saved audio to[/] {audio_path}") with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"): diff --git a/mlx_video/models/ltx/audio_vae/__init__.py b/mlx_video/models/ltx/audio_vae/__init__.py index 8786118..3a9e262 100644 --- a/mlx_video/models/ltx/audio_vae/__init__.py +++ b/mlx_video/models/ltx/audio_vae/__init__.py @@ -9,12 +9,13 @@ from .normalization import NormType, PixelNorm, build_normalization_layer from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock from .upsample import Upsample, build_upsampling_path -from .vocoder import Vocoder +from .vocoder import Vocoder, load_vocoder __all__ = [ # Main components "AudioDecoder", "Vocoder", + "load_vocoder", "decode_audio", # Ops "AudioLatentShape", diff --git a/mlx_video/models/ltx/audio_vae/vocoder.py b/mlx_video/models/ltx/audio_vae/vocoder.py index ea06f63..71b548c 100644 --- a/mlx_video/models/ltx/audio_vae/vocoder.py +++ b/mlx_video/models/ltx/audio_vae/vocoder.py @@ -1,179 +1,689 @@ -"""Vocoder for converting mel spectrograms to audio waveforms.""" +"""Vocoder for converting mel spectrograms to audio waveforms. + +Supports: +- HiFi-GAN (LTX-2): ResBlock1 with LeakyReLU +- BigVGAN v2 (LTX-2.3): AMPBlock1 with Snake/SnakeBeta + anti-aliased resampling +- VocoderWithBWE (LTX-2.3): Base vocoder + bandwidth extension (16kHz -> 48kHz) +""" import math -from typing import Dict +from typing import List, Tuple from pathlib import Path + import mlx.core as mx import mlx.nn as nn -from mlx_vlm.models.base import check_array_shape + from ..config import VocoderModelConfig from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu -class Vocoder(nn.Module): - """ - Vocoder model for synthesizing audio from Mel spectrograms. - Based on HiFi-GAN architecture. +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) - Args: - resblock_kernel_sizes: List of kernel sizes for the residual blocks - upsample_rates: List of upsampling rates - upsample_kernel_sizes: List of kernel sizes for the upsampling layers - resblock_dilation_sizes: List of dilation sizes for the residual blocks - upsample_initial_channel: Initial number of channels for upsampling - stereo: Whether to use stereo output - resblock: Type of residual block to use ("1" or "2") - output_sample_rate: Waveform sample rate - """ + +# --------------------------------------------------------------------------- +# Snake / SnakeBeta activations (BigVGAN v2) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + """Snake activation: x + (1/alpha) * sin^2(alpha * x).""" + + def __init__(self, in_features: int, alpha_logscale: bool = True) -> None: + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + + def __call__(self, x: mx.array) -> mx.array: + # x: (N, L, C) in MLX format + alpha = self.alpha # (C,) + if self.alpha_logscale: + alpha = mx.exp(alpha) + return x + (1.0 / (alpha + 1e-9)) * mx.power(mx.sin(x * alpha), 2) + + +class SnakeBeta(nn.Module): + """SnakeBeta activation: x + (1/beta) * sin^2(alpha * x).""" + + def __init__(self, in_features: int, alpha_logscale: bool = True) -> None: + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,)) + + def __call__(self, x: mx.array) -> mx.array: + alpha = self.alpha + beta = self.beta + if self.alpha_logscale: + alpha = mx.exp(alpha) + beta = mx.exp(beta) + return x + (1.0 / (beta + 1e-9)) * mx.power(mx.sin(x * alpha), 2) + + +# --------------------------------------------------------------------------- +# Anti-aliased resampling (Kaiser-sinc filters) +# --------------------------------------------------------------------------- + + +def _sinc(x: mx.array) -> mx.array: + return mx.where( + x == 0, + mx.ones_like(x), + mx.sin(mx.array(math.pi) * x) / (mx.array(math.pi) * x), + ) + + +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array: + """Compute a Kaiser-windowed sinc filter.""" + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if amplitude > 50.0: + beta = 0.1102 * (amplitude - 8.7) + elif amplitude >= 21.0: + beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) + else: + beta = 0.0 + + # Kaiser window - compute using scipy-compatible formula + import numpy as np + window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32)) + + if even: + time = mx.arange(-half_size, half_size).astype(mx.float32) + 0.5 + else: + time = mx.arange(kernel_size).astype(mx.float32) - half_size + + if cutoff == 0: + filter_ = mx.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ = filter_ / mx.sum(filter_) + + return filter_.reshape(1, 1, kernel_size) + + +def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]: + """Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler).""" + import numpy as np + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + kernel_size = 2 * width * ratio + 1 + pad = width + pad_left = 2 * width * ratio + pad_right = kernel_size - ratio + + time = (np.arange(kernel_size) / ratio - width) * rolloff + time_clamped = np.clip(time, -lowpass_filter_width, lowpass_filter_width) + window = np.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 + sinc_filter = np.sinc(time) * window * rolloff / ratio + + filter_ = mx.array(sinc_filter.astype(np.float32)).reshape(1, 1, kernel_size) + return filter_, pad, pad_left, pad_right + + +class LowPassFilter1d(nn.Module): + """Low-pass filter using depthwise convolution with Kaiser-sinc kernel.""" def __init__( self, - config: VocoderModelConfig - ): + cutoff: float = 0.5, + half_width: float = 0.6, + stride: int = 1, + kernel_size: int = 12, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + # Filter buffer - shape (1, 1, K) in PyTorch format, loaded from weights + self.filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + + def __call__(self, x: mx.array) -> mx.array: + # x: (N, L, C) in MLX format + n, l, c = x.shape + + # Pad with edge values: replicate first/last value + first = mx.repeat(x[:, :1, :], self.pad_left, axis=1) + last = mx.repeat(x[:, -1:, :], self.pad_right, axis=1) + x = mx.concatenate([first, x, last], axis=1) + + # Expand filter for depthwise conv: (1, 1, K) -> (C, K, 1) for groups=C + # Filter is stored in PyTorch format (1, 1, K), need (C, K, 1) MLX format + filt = self.filter.astype(x.dtype) # (1, 1, K) + filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1) + filt = mx.repeat(filt, c, axis=0) # (C, K, 1) + + # Transpose x for depthwise conv: (N, L, C) -> (N*C, L, 1) then conv + x = mx.transpose(x, (0, 2, 1)) # (N, C, L) + x = x.reshape(n * c, -1, 1) # (N*C, L, 1) + + x = mx.conv1d(x, filt[:1], stride=self.stride, groups=1) # (N*C, L', 1) + + x = x.reshape(n, c, -1) # (N, C, L') + x = mx.transpose(x, (0, 2, 1)) # (N, L', C) + return x + + +class UpSample1d(nn.Module): + """Anti-aliased upsampling using transposed convolution with sinc filter.""" + + def __init__( + self, + ratio: int = 2, + kernel_size: int = None, + window_type: str = "kaiser", + ) -> None: + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + filt, self.pad, self.pad_left, self.pad_right = hann_sinc_filter1d(ratio) + self.kernel_size = filt.shape[2] + self.filter = filt + else: + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + self.filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size, + ) + + def __call__(self, x: mx.array) -> mx.array: + # x: (N, L, C) in MLX format + n, l, c = x.shape + + # Pad with edge values + first = mx.repeat(x[:, :1, :], self.pad, axis=1) + last = mx.repeat(x[:, -1:, :], self.pad, axis=1) + x = mx.concatenate([first, x, last], axis=1) + + # Process per-channel via reshape: (N, L, C) -> (N*C, L, 1) + x = mx.transpose(x, (0, 2, 1)) # (N, C, L) + x = x.reshape(n * c, -1, 1) # (N*C, L, 1) + + # Transposed conv for upsampling + # Filter: (1, 1, K) PyTorch -> (1, K, 1) MLX format for conv_transpose1d + filt = self.filter.astype(x.dtype) # (1, 1, K) + filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1) + + x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1) + + # Trim padding + x = x[:, self.pad_left:-self.pad_right, :] + + x = x.reshape(n, c, -1) # (N, C, L') + x = mx.transpose(x, (0, 2, 1)) # (N, L', C) + return x + + +class DownSample1d(nn.Module): + """Anti-aliased downsampling using low-pass filter.""" + + def __init__(self, ratio: int = 2, kernel_size: int = None) -> None: + super().__init__() + self.ratio = ratio + kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=kernel_size, + ) + + def __call__(self, x: mx.array) -> mx.array: + return self.lowpass(x) + + +class Activation1d(nn.Module): + """Anti-aliased activation: upsample -> activate -> downsample.""" + + def __init__( + self, + activation: nn.Module, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ) -> None: + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def __call__(self, x: mx.array) -> mx.array: + x = self.upsample(x) + x = self.act(x) + return self.downsample(x) + + +# --------------------------------------------------------------------------- +# AMPBlock1 (BigVGAN v2 residual block) +# --------------------------------------------------------------------------- + + +class AMPBlock1(nn.Module): + """BigVGAN v2 residual block with anti-aliased Snake activations.""" + + def __init__( + self, + channels: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + activation: str = "snakebeta", + ) -> None: + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + + self.convs1 = { + i: nn.Conv1d( + channels, channels, kernel_size, stride=1, + dilation=d, padding=get_padding(kernel_size, d), + ) + for i, d in enumerate(dilation) + } + + self.convs2 = { + i: nn.Conv1d( + channels, channels, kernel_size, stride=1, + dilation=1, padding=get_padding(kernel_size, 1), + ) + for i in range(len(dilation)) + } + + self.acts1 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))} + self.acts2 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))} + + def __call__(self, x: mx.array) -> mx.array: + for i in range(len(self.convs1)): + xt = self.acts1[i](x) + xt = self.convs1[i](xt) + xt = self.acts2[i](xt) + xt = self.convs2[i](xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# STFT and MelSTFT (for BWE) +# --------------------------------------------------------------------------- + + +class STFTFn(nn.Module): + """STFT via conv1d with precomputed DFT x window bases (loaded from checkpoint).""" + + def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None: + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + # Buffers loaded from checkpoint - PyTorch format (n_freqs*2, 1, filter_length) + self.forward_basis = mx.zeros((n_freqs * 2, 1, filter_length)) + self.inverse_basis = mx.zeros((n_freqs * 2, 1, filter_length)) + + def __call__(self, y: mx.array) -> Tuple[mx.array, mx.array]: + """Compute magnitude and phase from waveform. + + Args: + y: (B, T) waveform + + Returns: + magnitude: (B, n_freqs, T_frames) + phase: (B, n_freqs, T_frames) + """ + if y.ndim == 2: + y = mx.expand_dims(y, -1) # (B, T, 1) + + left_pad = max(0, self.win_length - self.hop_length) + if left_pad > 0: + first = mx.repeat(y[:, :1, :], left_pad, axis=1) + y = mx.concatenate([first, y], axis=1) + + # forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX + basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1) + + # Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514) + spec = mx.conv1d(y, basis, stride=self.hop_length) + + # Split real and imaginary + n_freqs = spec.shape[-1] // 2 + real = spec[..., :n_freqs] + imag = spec[..., n_freqs:] + + magnitude = mx.sqrt(real ** 2 + imag ** 2) + phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype) + + # Output: (B, T_frames, n_freqs) in MLX channels-last + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram from precomputed STFT bases.""" + + def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None: + super().__init__() + self.stft_fn = STFTFn(filter_length, hop_length, win_length) + n_freqs = filter_length // 2 + 1 + self.mel_basis = mx.zeros((n_mel_channels, n_freqs)) + + def mel_spectrogram(self, y: mx.array) -> mx.array: + """Compute log-mel spectrogram. + + Args: + y: (B, T) waveform + + Returns: + log_mel: (B, n_mels, T_frames) in channels-first for compatibility + """ + magnitude, phase = self.stft_fn(y) + # magnitude: (B, T_frames, n_freqs) + mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels) + log_mel = mx.log(mx.clip(mel, 1e-5, None)) + # Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format + return mx.transpose(log_mel, (0, 2, 1)) + + +# --------------------------------------------------------------------------- +# Vocoder (supports both HiFi-GAN and BigVGAN v2) +# --------------------------------------------------------------------------- + + +class Vocoder(nn.Module): + """Vocoder for mel-to-waveform synthesis. + + Supports resblock="1" (HiFi-GAN / LTX-2) and resblock="AMP1" (BigVGAN v2 / LTX-2.3). + """ + + def __init__(self, config: VocoderModelConfig) -> None: super().__init__() - - - self.output_sample_rate = config.output_sample_rate + self.output_sampling_rate = config.output_sample_rate self.num_kernels = len(config.resblock_kernel_sizes) self.num_upsamples = len(config.upsample_rates) self.upsample_rates = config.upsample_rates - self.upsample_kernel_sizes = config.upsample_kernel_sizes - self.upsample_initial_channel = config.upsample_initial_channel + self.is_amp = config.resblock == "AMP1" + self.use_tanh_at_final = config.use_tanh_at_final + self.apply_final_activation = config.apply_final_activation in_channels = 128 if config.stereo else 64 - self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3) + self.conv_pre = nn.Conv1d( + in_channels, config.upsample_initial_channel, + kernel_size=7, stride=1, padding=3, + ) - resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2 - - # Upsampling layers using ConvTranspose1d + # Upsampling layers self.ups = {} - for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): - in_ch = config.upsample_initial_channel // (2**i) + for i, (stride, kernel_size) in enumerate( + zip(config.upsample_rates, config.upsample_kernel_sizes) + ): + in_ch = config.upsample_initial_channel // (2 ** i) out_ch = config.upsample_initial_channel // (2 ** (i + 1)) self.ups[i] = nn.ConvTranspose1d( - in_ch, - out_ch, - kernel_size=kernel_size, - stride=stride, + in_ch, out_ch, + kernel_size=kernel_size, stride=stride, padding=(kernel_size - stride) // 2, ) # Residual blocks - self.resblocks = {} - block_idx = 0 - for i in range(len(self.ups)): - ch = config.upsample_initial_channel // (2 ** (i + 1)) - for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): - self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) - block_idx += 1 + if self.is_amp: + self.resblocks = {} + block_idx = 0 + for i in range(len(self.ups)): + ch = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip( + config.resblock_kernel_sizes, config.resblock_dilation_sizes + ): + self.resblocks[block_idx] = AMPBlock1( + ch, kernel_size, tuple(dilations), + activation=config.activation, + ) + block_idx += 1 + else: + resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2 + self.resblocks = {} + block_idx = 0 + for i in range(len(self.ups)): + ch = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip( + config.resblock_kernel_sizes, config.resblock_dilation_sizes + ): + self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) + block_idx += 1 + final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates)) + + # Post-activation + if self.is_amp: + act_cls = SnakeBeta if config.activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(final_channels)) + + # Final conv out_channels = 2 if config.stereo else 1 - final_channels = config.upsample_initial_channel // (2**self.num_upsamples) - self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3) + self.conv_post = nn.Conv1d( + final_channels, out_channels, + kernel_size=7, stride=1, padding=3, + bias=config.use_bias_at_final, + ) self.upsample_factor = math.prod(config.upsample_rates) - def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - sanitized = {} - - if "vocoder." not in weights: - return weights - - for key, value in weights.items(): - new_key = key - - # Handle vocoder weights - if key.startswith("vocoder."): - new_key = key.replace("vocoder.", "") - - # Handle ModuleList indices -> dict keys - # PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ... - # PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ... - - # Handle Conv1d weight shape conversion - # PyTorch: (out_channels, in_channels, kernel) - # MLX: (out_channels, kernel, in_channels) - if "weight" in new_key and value.ndim == 3: - if "ups" in new_key: - # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch) - value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0)) - else: - # Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch) - value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1)) - - sanitized[new_key] = value - - return sanitized - - @classmethod - def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder": - """Load vocoder from pretrained model.""" - from mlx_video.models.ltx.config import VocoderModelConfig - import json - - config_dict = {} - with open(model_path / "config.json", "r") as f: - config_dict = json.load(f) - - config = VocoderModelConfig.from_dict(config_dict) - model = cls(config) - weights = mx.load(str(model_path / "model.safetensors")) - - # Use strict=False to skip extra keys (e.g., bwe_generator in LTX-2.3) - model.load_weights(list(weights.items()), strict=False) - return model - - def __call__(self, x: mx.array) -> mx.array: - """ - Forward pass of the vocoder. + """Forward pass. + Args: - x: Input Mel spectrogram tensor. Can be either: - - 3D: (batch_size, time, mel_bins) for mono - MLX format (N, L, C) - - 4D: (batch_size, 2, time, mel_bins) for stereo - PyTorch format (N, C, H, W) + x: Mel spectrogram (B, C, T, mel_bins) for stereo or (B, T, mel_bins) mono. + Returns: - Audio waveform tensor of shape (batch_size, out_channels, audio_length) + Waveform (B, out_channels, T_audio) in channels-first format. """ - # Input: (batch, channels, time, mel_bins) from audio decoder - # Transpose to (batch, channels, mel_bins, time) + # (B, C, T, mel) -> (B, C, mel, T) x = mx.transpose(x, (0, 1, 3, 2)) - if x.ndim == 4: # stereo - # x shape: (batch, 2, mel_bins, time) - # Rearrange to (batch, 2*mel_bins, time) + if x.ndim == 4: # stereo: (B, 2, mel, T) -> (B, 2*mel, T) b, s, c, t = x.shape x = x.reshape(b, s * c, t) - # MLX Conv1d expects (N, L, C), so transpose - # Current: (batch, channels, time) -> (batch, time, channels) + # Channels-first (B, C, T) -> channels-last (B, T, C) for MLX conv x = mx.transpose(x, (0, 2, 1)) x = self.conv_pre(x) for i in range(self.num_upsamples): - x = leaky_relu(x, LRELU_SLOPE) + if not self.is_amp: + x = leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) start = i * self.num_kernels end = start + self.num_kernels - # Apply residual blocks and average their outputs - block_outputs = [] - for idx in range(start, end): - block_outputs.append(self.resblocks[idx](x)) + block_outputs = mx.stack( + [self.resblocks[idx](x) for idx in range(start, end)], + axis=0, + ) + x = mx.mean(block_outputs, axis=0) - # Stack and mean - x = mx.stack(block_outputs, axis=0) - x = mx.mean(x, axis=0) + if self.is_amp: + x = self.act_post(x) + else: + x = nn.leaky_relu(x) - # IMPORTANT: Use default leaky_relu slope (0.01), NOT LRELU_SLOPE (0.1) - # PyTorch uses F.leaky_relu(x) which defaults to 0.01 - x = nn.leaky_relu(x) # Default negative_slope=0.01 x = self.conv_post(x) - x = mx.tanh(x) - # Transpose back to (batch, channels, time) + if self.apply_final_activation: + x = mx.tanh(x) if self.use_tanh_at_final else mx.clip(x, -1, 1) + + # Back to channels-first (B, T, C) -> (B, C, T) x = mx.transpose(x, (0, 2, 1)) - return x + + +# --------------------------------------------------------------------------- +# VocoderWithBWE (Bandwidth Extension) +# --------------------------------------------------------------------------- + + +class VocoderWithBWE(nn.Module): + """Vocoder + bandwidth extension upsampling (16kHz -> 48kHz). + + Chains a base vocoder with a BWE generator that predicts a residual + added to a sinc-resampled skip connection. + """ + + def __init__( + self, + vocoder: Vocoder, + bwe_generator: Vocoder, + mel_stft: MelSTFT, + input_sampling_rate: int = 16000, + output_sampling_rate: int = 48000, + hop_length: int = 80, + ) -> None: + super().__init__() + self.vocoder = vocoder + self.bwe_generator = bwe_generator + self.mel_stft = mel_stft + self.input_sampling_rate = input_sampling_rate + self.output_sampling_rate = output_sampling_rate + self.hop_length = hop_length + # Hann-windowed sinc resampler (not stored in checkpoint) + self.resampler = UpSample1d( + ratio=output_sampling_rate // input_sampling_rate, + window_type="hann", + ) + + @property + def output_sample_rate(self) -> int: + return self.output_sampling_rate + + def _compute_mel(self, audio: mx.array) -> mx.array: + """Compute log-mel spectrogram from waveform. + + Args: + audio: (B, C, T) waveform in channels-first + + Returns: + mel: (B, C, n_mels, T_frames) + """ + batch, n_channels, _ = audio.shape + flat = audio.reshape(batch * n_channels, -1) # (B*C, T) + mel = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) + + def __call__(self, mel_spec: mx.array) -> mx.array: + """Run vocoder + BWE. + + Args: + mel_spec: Mel spectrogram, same format as Vocoder.forward input. + + Returns: + Waveform (B, out_channels, T_audio) at output_sampling_rate. + """ + x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate + _, _, length_low_rate = x.shape + output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate + + # Pad to hop_length multiple + remainder = length_low_rate % self.hop_length + if remainder != 0: + pad_amount = self.hop_length - remainder + x = mx.pad(x, [(0, 0), (0, 0), (0, pad_amount)]) + + # Compute mel from vocoder output: (B, C, n_mels, T_frames) + mel = self._compute_mel(x) + + # BWE expects (B, C, T_frames, mel_bins) -> transpose last two dims + mel_for_bwe = mx.transpose(mel, (0, 1, 3, 2)) # (B, C, T_frames, n_mels) + residual = self.bwe_generator(mel_for_bwe) # (B, C, T_high) + + # Sinc upsample skip connection + # resampler expects (N, L, C): transpose from (B, C, T) -> (B, T, C) + x_for_resample = mx.transpose(x, (0, 2, 1)) + skip = self.resampler(x_for_resample) + skip = mx.transpose(skip, (0, 2, 1)) # back to (B, C, T) + + return mx.clip(residual + skip, -1, 1)[..., :output_length] + + +# --------------------------------------------------------------------------- +# Factory / from_pretrained +# --------------------------------------------------------------------------- + + +def load_vocoder(model_path: Path) -> nn.Module: + """Load vocoder from pretrained model directory. + + Automatically detects whether to load a simple Vocoder or VocoderWithBWE. + """ + import json + + config_path = model_path / "config.json" + if not config_path.exists(): + raise FileNotFoundError(f"No config.json found in {model_path}") + + with open(config_path) as f: + config_dict = json.load(f) + + weights = mx.load(str(model_path / "model.safetensors")) + + has_bwe = config_dict.get("has_bwe_generator", False) + + if has_bwe: + return _load_vocoder_with_bwe(config_dict, weights) + else: + config = VocoderModelConfig.from_dict(config_dict) + model = Vocoder(config) + model.load_weights(list(weights.items()), strict=True) + return model + + +def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE: + """Load VocoderWithBWE from config and weights.""" + # Build vocoder from config + vocoder_cfg = config_dict.get("vocoder", {}) + vocoder_config = VocoderModelConfig.from_dict(vocoder_cfg) + vocoder = Vocoder(vocoder_config) + + # Build BWE generator from config + bwe_cfg = config_dict.get("bwe", {}) + bwe_config = VocoderModelConfig.from_dict(bwe_cfg) + bwe_config.apply_final_activation = False + bwe_generator = Vocoder(bwe_config) + + # MelSTFT from weight shapes + stft_basis = weights.get("mel_stft.stft_fn.forward_basis") + filter_length = stft_basis.shape[2] if stft_basis is not None else 512 + mel_basis = weights.get("mel_stft.mel_basis") + n_mel_channels = mel_basis.shape[0] if mel_basis is not None else 64 + + hop_length = bwe_cfg.get("hop_length", 80) + input_sr = bwe_cfg.get("input_sampling_rate", 16000) + output_sr = bwe_cfg.get("output_sampling_rate", 48000) + + mel_stft = MelSTFT( + filter_length=filter_length, + hop_length=hop_length, + win_length=filter_length, + n_mel_channels=n_mel_channels, + ) + + model = VocoderWithBWE( + vocoder=vocoder, + bwe_generator=bwe_generator, + mel_stft=mel_stft, + input_sampling_rate=input_sr, + output_sampling_rate=output_sr, + hop_length=hop_length, + ) + + model.load_weights(list(weights.items()), strict=False) + return model + + diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index b7dfa0a..009bf62 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -260,9 +260,13 @@ class VocoderModelConfig(BaseModelConfig): stereo: bool = True resblock: str = "1" output_sample_rate: int = 24000 + activation: str = "snake" + use_tanh_at_final: bool = True + apply_final_activation: bool = True + use_bias_at_final: bool = True def __post_init__(self): - + if self.resblock_kernel_sizes is None: self.resblock_kernel_sizes = [3, 7, 11] if self.upsample_rates is None: