Refactor LTX-2 model structure
This commit is contained in:
48
mlx_video/models/ltx_2/audio_vae/__init__.py
Normal file
48
mlx_video/models/ltx_2/audio_vae/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Audio VAE module for LTX-2 audio generation."""
|
||||
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock
|
||||
from .upsample import Upsample, build_upsampling_path
|
||||
from .vocoder import Vocoder, load_vocoder
|
||||
|
||||
__all__ = [
|
||||
# Main components
|
||||
"AudioEncoder",
|
||||
"AudioDecoder",
|
||||
"Vocoder",
|
||||
"load_vocoder",
|
||||
"decode_audio",
|
||||
# Audio processing
|
||||
"load_audio",
|
||||
"ensure_stereo",
|
||||
"waveform_to_mel",
|
||||
# Ops
|
||||
"AudioLatentShape",
|
||||
"AudioPatchifier",
|
||||
"PerChannelStatistics",
|
||||
# Building blocks
|
||||
"AttentionType",
|
||||
"AttnBlock",
|
||||
"make_attn",
|
||||
"CausalConv2d",
|
||||
"make_conv2d",
|
||||
"CausalityAxis",
|
||||
"Downsample",
|
||||
"build_downsampling_path",
|
||||
"NormType",
|
||||
"PixelNorm",
|
||||
"build_normalization_layer",
|
||||
"ResBlock1",
|
||||
"ResBlock2",
|
||||
"ResnetBlock",
|
||||
"LRELU_SLOPE",
|
||||
"Upsample",
|
||||
"build_upsampling_path",
|
||||
]
|
||||
108
mlx_video/models/ltx_2/audio_vae/attention.py
Normal file
108
mlx_video/models/ltx_2/audio_vae/attention.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Attention blocks for audio VAE."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
|
||||
class AttentionType(Enum):
|
||||
"""Enum for specifying the attention mechanism type."""
|
||||
|
||||
VANILLA = "vanilla"
|
||||
LINEAR = "linear"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""Self-attention block for audio VAE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
norm_type: NormType = NormType.GROUP,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
# Using Conv2d with kernel_size=1 for Q, K, V projections
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass through attention block.
|
||||
Args:
|
||||
x: Input tensor of shape (B, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Output tensor with attention applied (residual connection)
|
||||
"""
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# Compute attention
|
||||
# x shape: (B, H, W, C)
|
||||
b, h, w, c = q.shape
|
||||
|
||||
# Reshape for attention: (B, H*W, C)
|
||||
q = q.reshape(b, h * w, c)
|
||||
k = k.reshape(b, h * w, c)
|
||||
v = v.reshape(b, h * w, c)
|
||||
|
||||
# Attention: Q @ K^T / sqrt(d)
|
||||
# q: (B, HW, C), k: (B, HW, C) -> k^T: (B, C, HW)
|
||||
# w_: (B, HW, HW)
|
||||
scale = float(c) ** (-0.5)
|
||||
w_ = mx.matmul(q, k.transpose(0, 2, 1)) * scale
|
||||
w_ = mx.softmax(w_, axis=-1)
|
||||
|
||||
# Attend to values
|
||||
# w_: (B, HW, HW), v: (B, HW, C) -> h_: (B, HW, C)
|
||||
h_ = mx.matmul(w_, v)
|
||||
|
||||
# Reshape back to spatial dims
|
||||
h_ = h_.reshape(b, h, w, c)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
"""Identity module that returns input unchanged."""
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return x
|
||||
|
||||
|
||||
def make_attn(
|
||||
in_channels: int,
|
||||
attn_type: AttentionType = AttentionType.VANILLA,
|
||||
norm_type: NormType = NormType.GROUP,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create an attention module based on type.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
attn_type: Type of attention mechanism
|
||||
norm_type: Type of normalization
|
||||
Returns:
|
||||
Attention module
|
||||
"""
|
||||
if attn_type == AttentionType.VANILLA:
|
||||
return AttnBlock(in_channels, norm_type=norm_type)
|
||||
elif attn_type == AttentionType.NONE:
|
||||
return Identity()
|
||||
elif attn_type == AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
else:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
135
mlx_video/models/ltx_2/audio_vae/audio_processor.py
Normal file
135
mlx_video/models/ltx_2/audio_vae/audio_processor.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
|
||||
|
||||
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
|
||||
using librosa for macOS/MLX compatibility.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def load_audio(
|
||||
path: str,
|
||||
target_sr: int = 16000,
|
||||
start_time: float = 0.0,
|
||||
max_duration: float | None = None,
|
||||
mono: bool = False,
|
||||
) -> tuple[np.ndarray, int]:
|
||||
"""Load audio file, resample to target sample rate.
|
||||
|
||||
Args:
|
||||
path: Path to audio file (WAV, FLAC, MP3, OGG, or video with audio track).
|
||||
target_sr: Target sample rate (default 16000 Hz).
|
||||
start_time: Start time in seconds.
|
||||
max_duration: Maximum duration in seconds. None = read to end.
|
||||
mono: If True, convert to mono. Default False (preserve channels).
|
||||
|
||||
Returns:
|
||||
(waveform, sample_rate) where waveform is (channels, samples) float32 numpy array.
|
||||
"""
|
||||
import librosa
|
||||
|
||||
# librosa.load returns mono by default; we want to preserve stereo
|
||||
y, sr = librosa.load(
|
||||
path,
|
||||
sr=target_sr,
|
||||
mono=mono,
|
||||
offset=start_time,
|
||||
duration=max_duration,
|
||||
)
|
||||
|
||||
# Ensure 2D: (channels, samples)
|
||||
if y.ndim == 1:
|
||||
y = y[np.newaxis, :] # (1, samples)
|
||||
|
||||
return y.astype(np.float32), sr
|
||||
|
||||
|
||||
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
|
||||
"""Ensure waveform is stereo (2, samples). Duplicates mono if needed."""
|
||||
if waveform.ndim == 1:
|
||||
waveform = waveform[np.newaxis, :]
|
||||
if waveform.shape[0] == 1:
|
||||
waveform = np.concatenate([waveform, waveform], axis=0)
|
||||
elif waveform.shape[0] > 2:
|
||||
waveform = waveform[:2]
|
||||
return waveform
|
||||
|
||||
|
||||
def waveform_to_mel(
|
||||
waveform: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
n_fft: int = 1024,
|
||||
hop_length: int = 160,
|
||||
win_length: int = 1024,
|
||||
n_mels: int = 64,
|
||||
fmin: float = 0.0,
|
||||
fmax: float = 8000.0,
|
||||
) -> mx.array:
|
||||
"""Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.
|
||||
|
||||
PyTorch reference:
|
||||
MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
|
||||
f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
|
||||
mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")
|
||||
|
||||
Args:
|
||||
waveform: (channels, samples) float32 numpy array.
|
||||
sample_rate: Sample rate of the waveform.
|
||||
n_fft: FFT size.
|
||||
hop_length: Hop length.
|
||||
win_length: Window length.
|
||||
n_mels: Number of mel bins.
|
||||
fmin: Minimum frequency for mel filterbank.
|
||||
fmax: Maximum frequency for mel filterbank.
|
||||
|
||||
Returns:
|
||||
Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
|
||||
"""
|
||||
import librosa
|
||||
|
||||
# Ensure 2D
|
||||
if waveform.ndim == 1:
|
||||
waveform = waveform[np.newaxis, :]
|
||||
|
||||
channels = waveform.shape[0]
|
||||
mels = []
|
||||
|
||||
for ch in range(channels):
|
||||
# Magnitude spectrogram (power=1.0)
|
||||
S = np.abs(librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
))
|
||||
|
||||
# Mel filterbank with slaney normalization
|
||||
mel_basis = librosa.filters.mel(
|
||||
sr=sample_rate,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
norm="slaney",
|
||||
)
|
||||
mel = mel_basis @ S
|
||||
|
||||
# Log scale
|
||||
mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))
|
||||
|
||||
# Transpose: (n_mels, time) -> (time, n_mels)
|
||||
mel = mel.T
|
||||
mels.append(mel)
|
||||
|
||||
# Stack channels: (channels, time, n_mels)
|
||||
mel_spec = np.stack(mels, axis=0)
|
||||
|
||||
# Add batch dim: (1, channels, time, n_mels)
|
||||
mel_spec = mel_spec[np.newaxis, ...]
|
||||
|
||||
return mx.array(mel_spec, dtype=mx.float32)
|
||||
532
mlx_video/models/ltx_2/audio_vae/audio_vae.py
Normal file
532
mlx_video/models/ltx_2/audio_vae/audio_vae.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import ResnetBlock
|
||||
from .upsample import build_upsampling_path
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
def build_mid_block(
|
||||
channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float,
|
||||
norm_type: NormType,
|
||||
causality_axis: CausalityAxis,
|
||||
attn_type: AttentionType,
|
||||
add_attention: bool,
|
||||
) -> dict:
|
||||
"""Build the middle block with two ResNet blocks and optional attention."""
|
||||
mid = {}
|
||||
mid["block_1"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
mid["attn_1"] = (
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
||||
)
|
||||
mid["block_2"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
return mid
|
||||
|
||||
|
||||
def run_mid_block(mid: dict, features: mx.array) -> mx.array:
|
||||
"""Run features through the middle block."""
|
||||
features = mid["block_1"](features, temb=None)
|
||||
if mid["attn_1"] is not None:
|
||||
features = mid["attn_1"](features)
|
||||
return mid["block_2"](features, temb=None)
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
"""Encoder that compresses audio spectrograms into latent representations."""
|
||||
|
||||
def __init__(self, config: AudioEncoderModelConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||
self.sample_rate = config.sample_rate
|
||||
self.mel_hop_length = config.mel_hop_length
|
||||
self.is_causal = config.is_causal
|
||||
self.mel_bins = config.mel_bins
|
||||
|
||||
self.patchifier = AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=config.sample_rate,
|
||||
hop_length=config.mel_hop_length,
|
||||
is_causal=config.is_causal,
|
||||
)
|
||||
|
||||
self.ch = config.ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(config.ch_mult)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
self.resolution = config.resolution
|
||||
self.in_channels = config.in_channels
|
||||
self.z_channels = config.z_channels
|
||||
self.double_z = config.double_z
|
||||
self.norm_type = config.norm_type
|
||||
self.causality_axis = config.causality_axis
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.down, block_in = build_downsampling_path(
|
||||
ch=config.ch,
|
||||
ch_mult=config.ch_mult,
|
||||
num_resolutions=self.num_resolutions,
|
||||
num_res_blocks=config.num_res_blocks,
|
||||
resolution=config.resolution,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
attn_resolutions=config.attn_resolutions or set(),
|
||||
resamp_with_conv=config.resamp_with_conv,
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
add_attention=config.mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||
self.conv_out = make_conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio encoder weights from PyTorch format."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
if key.startswith("audio_vae.encoder."):
|
||||
new_key = key.replace("audio_vae.encoder.", "")
|
||||
elif key.startswith("encoder."):
|
||||
new_key = key.replace("encoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
elif "per_channel_statistics" in key:
|
||||
if "mean-of-means" in key or "latents_mean" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key or "latents_std" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
elif key == "latents_mean":
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif key == "latents_std":
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
encoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
encoder.load_weights(list(weights.items()), strict=True)
|
||||
return encoder
|
||||
|
||||
def __call__(self, spectrogram: mx.array) -> mx.array:
|
||||
"""Encode audio spectrogram into normalized latent representation.
|
||||
|
||||
Args:
|
||||
spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.
|
||||
Returns:
|
||||
Normalized latent (B, T', F', z_channels) in MLX channels-last format.
|
||||
"""
|
||||
if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
|
||||
spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))
|
||||
|
||||
h = self.conv_in(spectrogram)
|
||||
h = self._run_downsampling_path(h)
|
||||
h = run_mid_block(self.mid, h)
|
||||
h = self._finalize_output(h)
|
||||
return self._normalize_latents(h)
|
||||
|
||||
def _run_downsampling_path(self, h: mx.array) -> mx.array:
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx in range(self.num_res_blocks):
|
||||
h = stage["block"][block_idx](h, temb=None)
|
||||
if block_idx in stage["attn"]:
|
||||
h = stage["attn"][block_idx](h)
|
||||
if level != self.num_resolutions - 1 and "downsample" in stage:
|
||||
h = stage["downsample"](h)
|
||||
return h
|
||||
|
||||
def _finalize_output(self, h: mx.array) -> mx.array:
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
return self.conv_out(h)
|
||||
|
||||
def _normalize_latents(self, h: mx.array) -> mx.array:
|
||||
"""Normalize encoder output using per-channel statistics.
|
||||
|
||||
Takes first half of channels (mean) when double_z=True,
|
||||
then patchifies, normalizes, and unpatchifies.
|
||||
"""
|
||||
# h shape: (B, T', F', 2*z_channels) in MLX format
|
||||
z_channels = self.z_channels
|
||||
means = h[..., :z_channels]
|
||||
|
||||
latent_shape = AudioLatentShape(
|
||||
batch=means.shape[0],
|
||||
channels=means.shape[3],
|
||||
frames=means.shape[1],
|
||||
mel_bins=means.shape[2],
|
||||
)
|
||||
|
||||
patched = self.patchifier.patchify(means)
|
||||
normalized = self.per_channel_statistics.normalize(patched)
|
||||
return self.patchifier.unpatchify(normalized, latent_shape)
|
||||
|
||||
|
||||
class AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
The decoder mirrors the encoder structure with configurable channel multipliers,
|
||||
attention resolutions, and causal convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AudioDecoderModelConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the AudioDecoder.
|
||||
Args:
|
||||
ch: Base number of feature channels
|
||||
out_ch: Number of output channels (2 for stereo)
|
||||
ch_mult: Multiplicative factors for channels at each resolution
|
||||
num_res_blocks: Number of residual blocks per resolution
|
||||
attn_resolutions: Resolutions at which to apply attention
|
||||
resolution: Input spatial resolution
|
||||
z_channels: Number of latent channels
|
||||
norm_type: Normalization type
|
||||
causality_axis: Axis for causal convolutions
|
||||
dropout: Dropout probability
|
||||
mid_block_add_attention: Whether to add attention in middle block
|
||||
sample_rate: Audio sample rate
|
||||
mel_hop_length: Hop length for mel spectrogram
|
||||
is_causal: Whether to use causal convolutions
|
||||
mel_bins: Number of mel frequency bins
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
|
||||
# ch=128 matches this dimension, so use ch for per_channel_statistics
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||
self.sample_rate = config.sample_rate
|
||||
self.mel_hop_length = config.mel_hop_length
|
||||
self.is_causal = config.is_causal
|
||||
self.mel_bins = config.mel_bins
|
||||
|
||||
self.patchifier = AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=config.sample_rate,
|
||||
hop_length=config.mel_hop_length,
|
||||
is_causal=config.is_causal,
|
||||
)
|
||||
|
||||
self.ch = config.ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(config.ch_mult)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
self.resolution = config.resolution
|
||||
self.out_ch = config.out_ch
|
||||
self.give_pre_end = config.give_pre_end
|
||||
self.tanh_out = config.tanh_out
|
||||
self.norm_type = config.norm_type
|
||||
self.z_channels = config.z_channels
|
||||
self.channel_multipliers = config.ch_mult
|
||||
self.attn_resolutions = config.attn_resolutions
|
||||
self.causality_axis = config.causality_axis
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
base_block_channels = config.ch * self.channel_multipliers[-1]
|
||||
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
add_attention=config.mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.up, final_block_channels = build_upsampling_path(
|
||||
ch=config.ch,
|
||||
ch_mult=config.ch_mult,
|
||||
num_resolutions=self.num_resolutions,
|
||||
num_res_blocks=config.num_res_blocks,
|
||||
resolution=config.resolution,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
attn_resolutions=config.attn_resolutions,
|
||||
resamp_with_conv=config.resamp_with_conv,
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for audio VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle audio_vae.decoder weights
|
||||
if key.startswith("audio_vae.decoder."):
|
||||
new_key = key.replace("audio_vae.decoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
# Map per-channel statistics
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
else:
|
||||
continue # Skip non-decoder keys
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
Args:
|
||||
sample: Encoded latent representation of shape (B, H, W, C) in MLX format
|
||||
or (B, C, H, W) in PyTorch format (will be transposed)
|
||||
Returns:
|
||||
Reconstructed audio spectrogram
|
||||
"""
|
||||
# Handle input format - if channels are in dim 1, transpose to channels-last
|
||||
if sample.shape[1] == self.z_channels and sample.ndim == 4:
|
||||
# PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
|
||||
sample = mx.transpose(sample, (0, 2, 3, 1))
|
||||
|
||||
sample, target_shape = self._denormalize_latents(sample)
|
||||
|
||||
h = self.conv_in(sample)
|
||||
h = run_mid_block(self.mid, h)
|
||||
h = self._run_upsampling_path(h)
|
||||
h = self._finalize_output(h)
|
||||
|
||||
return self._adjust_output_shape(h, target_shape)
|
||||
|
||||
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
# sample shape: (B, H, W, C) in MLX format
|
||||
latent_shape = AudioLatentShape(
|
||||
batch=sample.shape[0],
|
||||
channels=sample.shape[3], # channels last
|
||||
frames=sample.shape[1], # height = frames
|
||||
mel_bins=sample.shape[2], # width = mel_bins
|
||||
)
|
||||
|
||||
sample_patched = self.patchifier.patchify(sample)
|
||||
sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
|
||||
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
|
||||
|
||||
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
|
||||
if self.causality_axis != CausalityAxis.NONE:
|
||||
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||
|
||||
target_shape = AudioLatentShape(
|
||||
batch=latent_shape.batch,
|
||||
channels=self.out_ch,
|
||||
frames=target_frames,
|
||||
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
||||
)
|
||||
|
||||
return sample, target_shape
|
||||
|
||||
def _adjust_output_shape(
|
||||
self,
|
||||
decoded_output: mx.array,
|
||||
target_shape: AudioLatentShape,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Adjust output shape to match target dimensions for variable-length audio.
|
||||
Args:
|
||||
decoded_output: Tensor of shape (B, H, W, C) in MLX format
|
||||
target_shape: AudioLatentShape describing target dimensions
|
||||
Returns:
|
||||
Tensor adjusted to match target_shape exactly
|
||||
"""
|
||||
# Current output shape: (batch, frames, mel_bins, channels) in MLX format
|
||||
_, current_time, current_freq, _ = decoded_output.shape
|
||||
target_channels = target_shape.channels
|
||||
target_time = target_shape.frames
|
||||
target_freq = target_shape.mel_bins
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
time_padding_needed = target_time - decoded_output.shape[1]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[2]
|
||||
|
||||
# Step 3: Apply padding if needed
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
# MLX pad: [(before_0, after_0), ...]
|
||||
# For (B, H, W, C): H=time, W=freq
|
||||
padding = [
|
||||
(0, 0), # batch
|
||||
(0, max(time_padding_needed, 0)), # time
|
||||
(0, max(freq_padding_needed, 0)), # freq
|
||||
(0, 0), # channels
|
||||
]
|
||||
decoded_output = mx.pad(decoded_output, padding)
|
||||
|
||||
# Step 4: Final safety crop to ensure exact target shape
|
||||
decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]
|
||||
|
||||
# Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
|
||||
decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2))
|
||||
|
||||
return decoded_output
|
||||
|
||||
def _run_upsampling_path(self, h: mx.array) -> mx.array:
|
||||
"""Run through upsampling path."""
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = self.up[level]
|
||||
for block_idx in range(len(stage["block"])):
|
||||
h = stage["block"][block_idx](h, temb=None)
|
||||
if block_idx in stage["attn"]:
|
||||
h = stage["attn"][block_idx](h)
|
||||
|
||||
if level != 0 and "upsample" in stage:
|
||||
h = stage["upsample"](h)
|
||||
|
||||
return h
|
||||
|
||||
def _finalize_output(self, h: mx.array) -> mx.array:
|
||||
"""Apply final normalization and convolution."""
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return mx.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
latent: Input audio latent tensor
|
||||
audio_decoder: Model to decode the latent to spectrogram
|
||||
vocoder: Model to convert spectrogram to audio waveform
|
||||
Returns:
|
||||
Decoded audio as a float tensor
|
||||
"""
|
||||
decoded_audio = audio_decoder(latent)
|
||||
decoded_audio = vocoder(decoded_audio)
|
||||
# Remove batch dimension if present
|
||||
if decoded_audio.shape[0] == 1:
|
||||
decoded_audio = decoded_audio[0]
|
||||
return decoded_audio.astype(mx.float32)
|
||||
146
mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py
Normal file
146
mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Causal 2D convolutions for audio VAE."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import CausalityAxis
|
||||
|
||||
|
||||
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
||||
"""Convert int or tuple to tuple pair."""
|
||||
if isinstance(x, int):
|
||||
return (x, x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution.
|
||||
This layer ensures that the output at time `t` only depends on inputs
|
||||
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||
to the time dimension before the convolution.
|
||||
|
||||
Note: MLX Conv2d expects input shape (N, H, W, C) - channels last.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
# Ensure kernel_size and dilation are tuples
|
||||
kernel_size = _pair(kernel_size)
|
||||
dilation = _pair(dilation)
|
||||
|
||||
# Calculate padding dimensions
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
# Store padding for manual application
|
||||
# MLX pad order: [(before_axis0, after_axis0), (before_axis1, after_axis1), ...]
|
||||
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
# Non-causal: symmetric padding
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
|
||||
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
|
||||
# Causal on width: pad left (before width axis)
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
# Causal on height: pad top (before height axis)
|
||||
self.padding = (pad_h, 0, pad_w // 2, pad_w - pad_w // 2)
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
# The internal convolution layer uses no padding, as we handle it manually
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass with causal padding.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Output tensor after causal convolution
|
||||
"""
|
||||
# Apply causal padding before convolution
|
||||
# padding format: (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
|
||||
pad_h_top, pad_h_bottom, pad_w_left, pad_w_right = self.padding
|
||||
|
||||
if any(p > 0 for p in self.padding):
|
||||
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
||||
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
||||
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
|
||||
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def make_conv2d(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
padding: Union[int, Tuple[int, int], None] = None,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: CausalityAxis | None = None,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create a 2D convolution layer that can be either causal or non-causal.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels
|
||||
kernel_size: Size of the convolution kernel
|
||||
stride: Convolution stride
|
||||
padding: Padding (if None, will be calculated based on causal flag)
|
||||
dilation: Dilation rate
|
||||
groups: Number of groups for grouped convolution
|
||||
bias: Whether to use bias
|
||||
causality_axis: Dimension along which to apply causality.
|
||||
Returns:
|
||||
Either a regular Conv2d or CausalConv2d layer
|
||||
"""
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(
|
||||
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
|
||||
)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
if padding is None:
|
||||
if isinstance(kernel_size, int):
|
||||
padding = kernel_size // 2
|
||||
else:
|
||||
padding = tuple(k // 2 for k in kernel_size)
|
||||
|
||||
return nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
)
|
||||
127
mlx_video/models/ltx_2/audio_vae/downsample.py
Normal file
127
mlx_video/models/ltx_2/audio_vae/downsample.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Downsampling layers for audio VAE."""
|
||||
|
||||
from typing import Set, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer that can use either a strided convolution
|
||||
or average pooling. Supports standard and causal padding for the
|
||||
convolutional mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
with_conv: bool,
|
||||
causality_axis: CausalityAxis = CausalityAxis.WIDTH,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in MLX conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass with downsampling.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Downsampled tensor
|
||||
"""
|
||||
if self.with_conv:
|
||||
# Padding tuple is in the order: (left, right, top, bottom) for PyTorch
|
||||
# For MLX pad: [(before_axis0, after_axis0), ...]
|
||||
# x shape: (N, H, W, C) -> pad on H and W axes
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
# pad: (left=0, right=1, top=0, bottom=1)
|
||||
pad = [(0, 0), (0, 1), (0, 1), (0, 0)]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH:
|
||||
# pad: (left=2, right=0, top=0, bottom=1)
|
||||
pad = [(0, 0), (0, 1), (2, 0), (0, 0)]
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
# pad: (left=0, right=1, top=2, bottom=0)
|
||||
pad = [(0, 0), (2, 0), (0, 1), (0, 0)]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
# pad: (left=1, right=0, top=0, bottom=1)
|
||||
pad = [(0, 0), (0, 1), (1, 0), (0, 0)]
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
x = mx.pad(x, pad, constant_values=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# Average pooling with 2x2 kernel and stride 2
|
||||
# MLX doesn't have built-in avg_pool2d, implement manually
|
||||
# x shape: (N, H, W, C)
|
||||
n, h, w, c = x.shape
|
||||
# Reshape to (N, H//2, 2, W//2, 2, C) and mean over pooling dims
|
||||
x = x.reshape(n, h // 2, 2, w // 2, 2, c)
|
||||
x = mx.mean(x, axis=(2, 4))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_downsampling_path(
|
||||
*,
|
||||
ch: int,
|
||||
ch_mult: Tuple[int, ...],
|
||||
num_resolutions: int,
|
||||
num_res_blocks: int,
|
||||
resolution: int,
|
||||
temb_channels: int,
|
||||
dropout: float,
|
||||
norm_type: NormType,
|
||||
causality_axis: CausalityAxis,
|
||||
attn_type: AttentionType,
|
||||
attn_resolutions: Set[int],
|
||||
resamp_with_conv: bool,
|
||||
) -> tuple[dict, int]:
|
||||
"""Build the downsampling path with residual blocks, attention, and downsampling layers."""
|
||||
down_modules = {}
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, *tuple(ch_mult))
|
||||
block_in = ch
|
||||
|
||||
for i_level in range(num_resolutions):
|
||||
stage = {}
|
||||
stage["block"] = {}
|
||||
stage["attn"] = {}
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for i_block in range(num_res_blocks):
|
||||
stage["block"][i_block] = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
|
||||
if i_level != num_resolutions - 1:
|
||||
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
down_modules[i_level] = stage
|
||||
|
||||
return down_modules, block_in
|
||||
59
mlx_video/models/ltx_2/audio_vae/normalization.py
Normal file
59
mlx_video/models/ltx_2/audio_vae/normalization.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Normalization layers for audio VAE."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class NormType(Enum):
|
||||
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||
|
||||
GROUP = "group"
|
||||
PIXEL = "pixel"
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
For each element along the chosen dimension, this layer normalizes the tensor
|
||||
by the root-mean-square of its values across that dimension:
|
||||
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim: Dimension along which to compute the RMS (typically channels).
|
||||
eps: Small constant added for numerical stability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Apply RMS normalization along the configured dimension."""
|
||||
mean_sq = mx.mean(x**2, axis=self.dim, keepdims=True)
|
||||
rms = mx.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
def build_normalization_layer(
|
||||
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create a normalization layer based on the normalization type.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
num_groups: Number of groups for group normalization
|
||||
normtype: Type of normalization: "group" or "pixel"
|
||||
Returns:
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
||||
if normtype == NormType.PIXEL:
|
||||
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||
return PixelNorm(dim=-1, eps=1e-6)
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
98
mlx_video/models/ltx_2/audio_vae/ops.py
Normal file
98
mlx_video/models/ltx_2/audio_vae/ops.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Audio processing utilities for audio VAE."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioLatentShape:
|
||||
"""Shape descriptor for audio latent representations."""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
mel_bins: int
|
||||
|
||||
|
||||
class PerChannelStatistics(nn.Module):
|
||||
"""
|
||||
Per-channel statistics for normalizing and denormalizing the latent representation.
|
||||
This statistics is computed over the entire dataset and stored in model's checkpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_channels: int = 128) -> None:
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
# Initialize buffers - will be loaded from weights
|
||||
# Using underscores for MLX compatibility with weight loading
|
||||
self.std_of_means = mx.ones((latent_channels,))
|
||||
self.mean_of_means = mx.zeros((latent_channels,))
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latent representation."""
|
||||
# Broadcast statistics to match x shape
|
||||
# x shape: (B, C, ...) or (B, ..., C)
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x * std) + mean
|
||||
|
||||
def normalize(self, x: mx.array) -> mx.array:
|
||||
"""Normalize latent representation."""
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x - mean) / std
|
||||
|
||||
|
||||
class AudioPatchifier:
|
||||
"""
|
||||
Audio patchifier for converting between audio latents and patches.
|
||||
Combines channels and mel_bins dimensions for per-channel statistics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.sample_rate = sample_rate
|
||||
self.hop_length = hop_length
|
||||
self.is_causal = is_causal
|
||||
|
||||
def patchify(self, x: mx.array) -> mx.array:
|
||||
"""Convert audio latents to patches.
|
||||
|
||||
Input shape: (B, T, F, C) in MLX format (channels last)
|
||||
Output shape: (B, T, C*F) - flattened for per-channel statistics
|
||||
|
||||
The output order is (c f) to match PyTorch's "b c t f -> b t (c f)".
|
||||
"""
|
||||
# x shape: (B, T, F, C) e.g., (1, 68, 16, 8)
|
||||
b, t, f, c = x.shape
|
||||
# Transpose to (B, T, C, F) for correct (c f) ordering
|
||||
x = mx.transpose(x, (0, 1, 3, 2))
|
||||
# Reshape to (B, T, C*F) e.g., (1, 68, 128)
|
||||
return x.reshape(b, t, c * f)
|
||||
|
||||
def unpatchify(self, x: mx.array, latent_shape: AudioLatentShape) -> mx.array:
|
||||
"""Convert patches back to audio latents.
|
||||
|
||||
Input shape: (B, T, C*F)
|
||||
Output shape: (B, T, F, C) in MLX format
|
||||
|
||||
Reverses patchify's "b t (c f) -> b c t f" then transposes to MLX format.
|
||||
"""
|
||||
# x shape: (B, T, C*F) e.g., (1, 68, 128)
|
||||
b, t, cf = x.shape
|
||||
c = latent_shape.channels
|
||||
f = latent_shape.mel_bins
|
||||
# Reshape to (B, T, C, F)
|
||||
x = x.reshape(b, t, c, f)
|
||||
# Transpose to MLX format (B, T, F, C)
|
||||
return mx.transpose(x, (0, 1, 3, 2))
|
||||
185
mlx_video/models/ltx_2/audio_vae/resnet.py
Normal file
185
mlx_video/models/ltx_2/audio_vae/resnet.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""ResNet blocks for audio VAE and vocoder."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array:
|
||||
"""Leaky ReLU activation."""
|
||||
return mx.maximum(x, x * negative_slope)
|
||||
|
||||
|
||||
class ResBlock1(nn.Module):
|
||||
"""1D ResNet block for vocoder with dilated convolutions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: Tuple[int, int, int] = (1, 3, 5),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# First set of convolutions with different dilations
|
||||
self.convs1 = {
|
||||
i: nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=(kernel_size - 1) * d // 2,
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
# Second set of convolutions with dilation=1
|
||||
self.convs2 = {
|
||||
i: nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
for i in range(len(dilation))
|
||||
}
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass through residual blocks."""
|
||||
for i in range(len(self.convs1)):
|
||||
xt = leaky_relu(x, LRELU_SLOPE)
|
||||
xt = self.convs1[i](xt)
|
||||
xt = leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = self.convs2[i](xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock2(nn.Module):
|
||||
"""1D ResNet block for vocoder (alternative version)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: Tuple[int, int] = (1, 3),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.convs = {
|
||||
i: nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=(kernel_size - 1) * d // 2,
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass through residual blocks."""
|
||||
for i in range(len(self.convs)):
|
||||
xt = leaky_relu(x, LRELU_SLOPE)
|
||||
xt = self.convs[i](xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
"""2D ResNet block for audio VAE encoder/decoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
out_channels: int | None = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
norm_type: NormType = NormType.GROUP,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.temb_channels = temb_channels
|
||||
|
||||
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
self.conv1 = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
|
||||
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||
self.dropout_rate = dropout
|
||||
self.conv2 = make_conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
temb: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Forward pass through ResNet block.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
temb: Optional time embedding tensor
|
||||
Returns:
|
||||
Output tensor
|
||||
"""
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None and self.temb_channels > 0:
|
||||
# temb: (B, temb_channels) -> (B, out_channels)
|
||||
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
if self.dropout_rate > 0:
|
||||
h = nn.Dropout(self.dropout_rate)(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
135
mlx_video/models/ltx_2/audio_vae/upsample.py
Normal file
135
mlx_video/models/ltx_2/audio_vae/upsample.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Upsampling layers for audio VAE."""
|
||||
|
||||
from typing import Set, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
def nearest_neighbor_upsample(x: mx.array, scale_factor: int = 2) -> mx.array:
|
||||
"""
|
||||
Nearest neighbor upsampling for 4D tensors.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C)
|
||||
scale_factor: Upsampling factor
|
||||
Returns:
|
||||
Upsampled tensor of shape (N, H*scale_factor, W*scale_factor, C)
|
||||
"""
|
||||
n, h, w, c = x.shape
|
||||
# Repeat along height and width
|
||||
x = mx.repeat(x, scale_factor, axis=1)
|
||||
x = mx.repeat(x, scale_factor, axis=2)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""Upsampling layer with optional convolution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
with_conv: bool,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass with upsampling.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Upsampled tensor
|
||||
"""
|
||||
# Nearest neighbor 2x upsampling
|
||||
x = nearest_neighbor_upsample(x, scale_factor=2)
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||
# So the output elements rely on the following windows:
|
||||
# 0: [-,-,0]
|
||||
# 1: [-,0,0]
|
||||
# 2: [0,0,1]
|
||||
# 3: [0,1,1]
|
||||
# 4: [1,1,2]
|
||||
# 5: [1,2,2]
|
||||
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||
# while all other elements rely on two elements in the input.
|
||||
# So we can drop the first element to undo the padding (rather than the last element).
|
||||
# This is a no-op for non-causal convolutions.
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
pass # x remains unchanged
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
x = x[:, 1:, :, :]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH:
|
||||
x = x[:, :, 1:, :]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pass # x remains unchanged
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_upsampling_path(
|
||||
*,
|
||||
ch: int,
|
||||
ch_mult: Tuple[int, ...],
|
||||
num_resolutions: int,
|
||||
num_res_blocks: int,
|
||||
resolution: int,
|
||||
temb_channels: int,
|
||||
dropout: float,
|
||||
norm_type: NormType,
|
||||
causality_axis: CausalityAxis,
|
||||
attn_type: AttentionType,
|
||||
attn_resolutions: Set[int],
|
||||
resamp_with_conv: bool,
|
||||
initial_block_channels: int,
|
||||
) -> tuple[dict, int]:
|
||||
"""Build the upsampling path with residual blocks, attention, and upsampling layers."""
|
||||
up_modules = {}
|
||||
block_in = initial_block_channels
|
||||
curr_res = resolution // (2 ** (num_resolutions - 1))
|
||||
|
||||
for level in reversed(range(num_resolutions)):
|
||||
stage = {}
|
||||
stage["block"] = {}
|
||||
stage["attn"] = {}
|
||||
block_out = ch * ch_mult[level]
|
||||
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
stage["block"][i_block] = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
|
||||
if level != 0:
|
||||
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res *= 2
|
||||
|
||||
up_modules[level] = stage
|
||||
|
||||
return up_modules, block_in
|
||||
689
mlx_video/models/ltx_2/audio_vae/vocoder.py
Normal file
689
mlx_video/models/ltx_2/audio_vae/vocoder.py
Normal file
@@ -0,0 +1,689 @@
|
||||
"""Vocoder for converting mel spectrograms to audio waveforms.
|
||||
|
||||
Supports:
|
||||
- HiFi-GAN (LTX-2): ResBlock1 with LeakyReLU
|
||||
- BigVGAN v2 (LTX-2.3): AMPBlock1 with Snake/SnakeBeta + anti-aliased resampling
|
||||
- VocoderWithBWE (LTX-2.3): Base vocoder + bandwidth extension (16kHz -> 48kHz)
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import VocoderModelConfig
|
||||
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
|
||||
|
||||
|
||||
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snake / SnakeBeta activations (BigVGAN v2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
"""Snake activation: x + (1/alpha) * sin^2(alpha * x)."""
|
||||
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
alpha = self.alpha # (C,)
|
||||
if self.alpha_logscale:
|
||||
alpha = mx.exp(alpha)
|
||||
return x + (1.0 / (alpha + 1e-9)) * mx.power(mx.sin(x * alpha), 2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
"""SnakeBeta activation: x + (1/beta) * sin^2(alpha * x)."""
|
||||
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
alpha = self.alpha
|
||||
beta = self.beta
|
||||
if self.alpha_logscale:
|
||||
alpha = mx.exp(alpha)
|
||||
beta = mx.exp(beta)
|
||||
return x + (1.0 / (beta + 1e-9)) * mx.power(mx.sin(x * alpha), 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling (Kaiser-sinc filters)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: mx.array) -> mx.array:
|
||||
return mx.where(
|
||||
x == 0,
|
||||
mx.ones_like(x),
|
||||
mx.sin(mx.array(math.pi) * x) / (mx.array(math.pi) * x),
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
|
||||
"""Compute a Kaiser-windowed sinc filter."""
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if amplitude > 50.0:
|
||||
beta = 0.1102 * (amplitude - 8.7)
|
||||
elif amplitude >= 21.0:
|
||||
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
|
||||
# Kaiser window - compute using scipy-compatible formula
|
||||
import numpy as np
|
||||
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
|
||||
|
||||
if even:
|
||||
time = mx.arange(-half_size, half_size).astype(mx.float32) + 0.5
|
||||
else:
|
||||
time = mx.arange(kernel_size).astype(mx.float32) - half_size
|
||||
|
||||
if cutoff == 0:
|
||||
filter_ = mx.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ = filter_ / mx.sum(filter_)
|
||||
|
||||
return filter_.reshape(1, 1, kernel_size)
|
||||
|
||||
|
||||
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
|
||||
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
|
||||
import numpy as np
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
kernel_size = 2 * width * ratio + 1
|
||||
pad = width
|
||||
pad_left = 2 * width * ratio
|
||||
pad_right = kernel_size - ratio
|
||||
|
||||
time = (np.arange(kernel_size) / ratio - width) * rolloff
|
||||
time_clamped = np.clip(time, -lowpass_filter_width, lowpass_filter_width)
|
||||
window = np.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
sinc_filter = np.sinc(time) * window * rolloff / ratio
|
||||
|
||||
filter_ = mx.array(sinc_filter.astype(np.float32)).reshape(1, 1, kernel_size)
|
||||
return filter_, pad, pad_left, pad_right
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
"""Low-pass filter using depthwise convolution with Kaiser-sinc kernel."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cutoff: float = 0.5,
|
||||
half_width: float = 0.6,
|
||||
stride: int = 1,
|
||||
kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
# Filter buffer - shape (1, 1, K) in PyTorch format, loaded from weights
|
||||
self.filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
n, l, c = x.shape
|
||||
|
||||
# Pad with edge values: replicate first/last value
|
||||
first = mx.repeat(x[:, :1, :], self.pad_left, axis=1)
|
||||
last = mx.repeat(x[:, -1:, :], self.pad_right, axis=1)
|
||||
x = mx.concatenate([first, x, last], axis=1)
|
||||
|
||||
# Expand filter for depthwise conv: (1, 1, K) -> (C, K, 1) for groups=C
|
||||
# Filter is stored in PyTorch format (1, 1, K), need (C, K, 1) MLX format
|
||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||
filt = mx.repeat(filt, c, axis=0) # (C, K, 1)
|
||||
|
||||
# Transpose x for depthwise conv: (N, L, C) -> (N*C, L, 1) then conv
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, C, L)
|
||||
x = x.reshape(n * c, -1, 1) # (N*C, L, 1)
|
||||
|
||||
x = mx.conv1d(x, filt[:1], stride=self.stride, groups=1) # (N*C, L', 1)
|
||||
|
||||
x = x.reshape(n, c, -1) # (N, C, L')
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||
return x
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
"""Anti-aliased upsampling using transposed convolution with sinc filter."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int = None,
|
||||
window_type: str = "kaiser",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
filt, self.pad, self.pad_left, self.pad_right = hann_sinc_filter1d(ratio)
|
||||
self.kernel_size = filt.shape[2]
|
||||
self.filter = filt
|
||||
else:
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
self.filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
n, l, c = x.shape
|
||||
|
||||
# Pad with edge values
|
||||
first = mx.repeat(x[:, :1, :], self.pad, axis=1)
|
||||
last = mx.repeat(x[:, -1:, :], self.pad, axis=1)
|
||||
x = mx.concatenate([first, x, last], axis=1)
|
||||
|
||||
# Process per-channel via reshape: (N, L, C) -> (N*C, L, 1)
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, C, L)
|
||||
x = x.reshape(n * c, -1, 1) # (N*C, L, 1)
|
||||
|
||||
# Transposed conv for upsampling
|
||||
# Filter: (1, 1, K) PyTorch -> (1, K, 1) MLX format for conv_transpose1d
|
||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||
|
||||
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
|
||||
|
||||
# Trim padding
|
||||
x = x[:, self.pad_left:-self.pad_right, :]
|
||||
|
||||
x = x.reshape(n, c, -1) # (N, C, L')
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
"""Anti-aliased downsampling using low-pass filter."""
|
||||
|
||||
def __init__(self, ratio: int = 2, kernel_size: int = None) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=kernel_size,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
"""Anti-aliased activation: upsample -> activate -> downsample."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation: nn.Module,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
return self.downsample(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AMPBlock1 (BigVGAN v2 residual block)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AMPBlock1(nn.Module):
|
||||
"""BigVGAN v2 residual block with anti-aliased Snake activations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: Tuple[int, int, int] = (1, 3, 5),
|
||||
activation: str = "snakebeta",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
|
||||
self.convs1 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=d, padding=get_padding(kernel_size, d),
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
self.convs2 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=1, padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
for i in range(len(dilation))
|
||||
}
|
||||
|
||||
self.acts1 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))}
|
||||
self.acts2 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))}
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
for i in range(len(self.convs1)):
|
||||
xt = self.acts1[i](x)
|
||||
xt = self.convs1[i](xt)
|
||||
xt = self.acts2[i](xt)
|
||||
xt = self.convs2[i](xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# STFT and MelSTFT (for BWE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class STFTFn(nn.Module):
|
||||
"""STFT via conv1d with precomputed DFT x window bases (loaded from checkpoint)."""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
# Buffers loaded from checkpoint - PyTorch format (n_freqs*2, 1, filter_length)
|
||||
self.forward_basis = mx.zeros((n_freqs * 2, 1, filter_length))
|
||||
self.inverse_basis = mx.zeros((n_freqs * 2, 1, filter_length))
|
||||
|
||||
def __call__(self, y: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute magnitude and phase from waveform.
|
||||
|
||||
Args:
|
||||
y: (B, T) waveform
|
||||
|
||||
Returns:
|
||||
magnitude: (B, n_freqs, T_frames)
|
||||
phase: (B, n_freqs, T_frames)
|
||||
"""
|
||||
if y.ndim == 2:
|
||||
y = mx.expand_dims(y, -1) # (B, T, 1)
|
||||
|
||||
left_pad = max(0, self.win_length - self.hop_length)
|
||||
if left_pad > 0:
|
||||
first = mx.repeat(y[:, :1, :], left_pad, axis=1)
|
||||
y = mx.concatenate([first, y], axis=1)
|
||||
|
||||
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
|
||||
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
|
||||
|
||||
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
|
||||
spec = mx.conv1d(y, basis, stride=self.hop_length)
|
||||
|
||||
# Split real and imaginary
|
||||
n_freqs = spec.shape[-1] // 2
|
||||
real = spec[..., :n_freqs]
|
||||
imag = spec[..., n_freqs:]
|
||||
|
||||
magnitude = mx.sqrt(real ** 2 + imag ** 2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
|
||||
|
||||
# Output: (B, T_frames, n_freqs) in MLX channels-last
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram from precomputed STFT bases."""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
|
||||
super().__init__()
|
||||
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.mel_basis = mx.zeros((n_mel_channels, n_freqs))
|
||||
|
||||
def mel_spectrogram(self, y: mx.array) -> mx.array:
|
||||
"""Compute log-mel spectrogram.
|
||||
|
||||
Args:
|
||||
y: (B, T) waveform
|
||||
|
||||
Returns:
|
||||
log_mel: (B, n_mels, T_frames) in channels-first for compatibility
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
# magnitude: (B, T_frames, n_freqs)
|
||||
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
|
||||
log_mel = mx.log(mx.clip(mel, 1e-5, None))
|
||||
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
|
||||
return mx.transpose(log_mel, (0, 2, 1))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vocoder (supports both HiFi-GAN and BigVGAN v2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Vocoder(nn.Module):
|
||||
"""Vocoder for mel-to-waveform synthesis.
|
||||
|
||||
Supports resblock="1" (HiFi-GAN / LTX-2) and resblock="AMP1" (BigVGAN v2 / LTX-2.3).
|
||||
"""
|
||||
|
||||
def __init__(self, config: VocoderModelConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.output_sampling_rate = config.output_sample_rate
|
||||
self.num_kernels = len(config.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(config.upsample_rates)
|
||||
self.upsample_rates = config.upsample_rates
|
||||
self.is_amp = config.resblock == "AMP1"
|
||||
self.use_tanh_at_final = config.use_tanh_at_final
|
||||
self.apply_final_activation = config.apply_final_activation
|
||||
|
||||
in_channels = 128 if config.stereo else 64
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels, config.upsample_initial_channel,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
)
|
||||
|
||||
# Upsampling layers
|
||||
self.ups = {}
|
||||
for i, (stride, kernel_size) in enumerate(
|
||||
zip(config.upsample_rates, config.upsample_kernel_sizes)
|
||||
):
|
||||
in_ch = config.upsample_initial_channel // (2 ** i)
|
||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups[i] = nn.ConvTranspose1d(
|
||||
in_ch, out_ch,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
|
||||
# Residual blocks
|
||||
if self.is_amp:
|
||||
self.resblocks = {}
|
||||
block_idx = 0
|
||||
for i in range(len(self.ups)):
|
||||
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = AMPBlock1(
|
||||
ch, kernel_size, tuple(dilations),
|
||||
activation=config.activation,
|
||||
)
|
||||
block_idx += 1
|
||||
else:
|
||||
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
|
||||
self.resblocks = {}
|
||||
block_idx = 0
|
||||
for i in range(len(self.ups)):
|
||||
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||
block_idx += 1
|
||||
|
||||
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
|
||||
|
||||
# Post-activation
|
||||
if self.is_amp:
|
||||
act_cls = SnakeBeta if config.activation == "snakebeta" else Snake
|
||||
self.act_post = Activation1d(act_cls(final_channels))
|
||||
|
||||
# Final conv
|
||||
out_channels = 2 if config.stereo else 1
|
||||
self.conv_post = nn.Conv1d(
|
||||
final_channels, out_channels,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
bias=config.use_bias_at_final,
|
||||
)
|
||||
|
||||
self.upsample_factor = math.prod(config.upsample_rates)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Mel spectrogram (B, C, T, mel_bins) for stereo or (B, T, mel_bins) mono.
|
||||
|
||||
Returns:
|
||||
Waveform (B, out_channels, T_audio) in channels-first format.
|
||||
"""
|
||||
# (B, C, T, mel) -> (B, C, mel, T)
|
||||
x = mx.transpose(x, (0, 1, 3, 2))
|
||||
|
||||
if x.ndim == 4: # stereo: (B, 2, mel, T) -> (B, 2*mel, T)
|
||||
b, s, c, t = x.shape
|
||||
x = x.reshape(b, s * c, t)
|
||||
|
||||
# Channels-first (B, C, T) -> channels-last (B, T, C) for MLX conv
|
||||
x = mx.transpose(x, (0, 2, 1))
|
||||
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
if not self.is_amp:
|
||||
x = leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
|
||||
start = i * self.num_kernels
|
||||
end = start + self.num_kernels
|
||||
|
||||
block_outputs = mx.stack(
|
||||
[self.resblocks[idx](x) for idx in range(start, end)],
|
||||
axis=0,
|
||||
)
|
||||
x = mx.mean(block_outputs, axis=0)
|
||||
|
||||
if self.is_amp:
|
||||
x = self.act_post(x)
|
||||
else:
|
||||
x = nn.leaky_relu(x)
|
||||
|
||||
x = self.conv_post(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
x = mx.tanh(x) if self.use_tanh_at_final else mx.clip(x, -1, 1)
|
||||
|
||||
# Back to channels-first (B, T, C) -> (B, C, T)
|
||||
x = mx.transpose(x, (0, 2, 1))
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VocoderWithBWE (Bandwidth Extension)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VocoderWithBWE(nn.Module):
|
||||
"""Vocoder + bandwidth extension upsampling (16kHz -> 48kHz).
|
||||
|
||||
Chains a base vocoder with a BWE generator that predicts a residual
|
||||
added to a sinc-resampled skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocoder: Vocoder,
|
||||
bwe_generator: Vocoder,
|
||||
mel_stft: MelSTFT,
|
||||
input_sampling_rate: int = 16000,
|
||||
output_sampling_rate: int = 48000,
|
||||
hop_length: int = 80,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocoder = vocoder
|
||||
self.bwe_generator = bwe_generator
|
||||
self.mel_stft = mel_stft
|
||||
self.input_sampling_rate = input_sampling_rate
|
||||
self.output_sampling_rate = output_sampling_rate
|
||||
self.hop_length = hop_length
|
||||
# Hann-windowed sinc resampler (not stored in checkpoint)
|
||||
self.resampler = UpSample1d(
|
||||
ratio=output_sampling_rate // input_sampling_rate,
|
||||
window_type="hann",
|
||||
)
|
||||
|
||||
@property
|
||||
def output_sample_rate(self) -> int:
|
||||
return self.output_sampling_rate
|
||||
|
||||
def _compute_mel(self, audio: mx.array) -> mx.array:
|
||||
"""Compute log-mel spectrogram from waveform.
|
||||
|
||||
Args:
|
||||
audio: (B, C, T) waveform in channels-first
|
||||
|
||||
Returns:
|
||||
mel: (B, C, n_mels, T_frames)
|
||||
"""
|
||||
batch, n_channels, _ = audio.shape
|
||||
flat = audio.reshape(batch * n_channels, -1) # (B*C, T)
|
||||
mel = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||
return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])
|
||||
|
||||
def __call__(self, mel_spec: mx.array) -> mx.array:
|
||||
"""Run vocoder + BWE.
|
||||
|
||||
Args:
|
||||
mel_spec: Mel spectrogram, same format as Vocoder.forward input.
|
||||
|
||||
Returns:
|
||||
Waveform (B, out_channels, T_audio) at output_sampling_rate.
|
||||
"""
|
||||
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
|
||||
_, _, length_low_rate = x.shape
|
||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
|
||||
# Pad to hop_length multiple
|
||||
remainder = length_low_rate % self.hop_length
|
||||
if remainder != 0:
|
||||
pad_amount = self.hop_length - remainder
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_amount)])
|
||||
|
||||
# Compute mel from vocoder output: (B, C, n_mels, T_frames)
|
||||
mel = self._compute_mel(x)
|
||||
|
||||
# BWE expects (B, C, T_frames, mel_bins) -> transpose last two dims
|
||||
mel_for_bwe = mx.transpose(mel, (0, 1, 3, 2)) # (B, C, T_frames, n_mels)
|
||||
residual = self.bwe_generator(mel_for_bwe) # (B, C, T_high)
|
||||
|
||||
# Sinc upsample skip connection
|
||||
# resampler expects (N, L, C): transpose from (B, C, T) -> (B, T, C)
|
||||
x_for_resample = mx.transpose(x, (0, 2, 1))
|
||||
skip = self.resampler(x_for_resample)
|
||||
skip = mx.transpose(skip, (0, 2, 1)) # back to (B, C, T)
|
||||
|
||||
return mx.clip(residual + skip, -1, 1)[..., :output_length]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory / from_pretrained
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_vocoder(model_path: Path) -> nn.Module:
|
||||
"""Load vocoder from pretrained model directory.
|
||||
|
||||
Automatically detects whether to load a simple Vocoder or VocoderWithBWE.
|
||||
"""
|
||||
import json
|
||||
|
||||
config_path = model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"No config.json found in {model_path}")
|
||||
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
|
||||
has_bwe = config_dict.get("has_bwe_generator", False)
|
||||
|
||||
if has_bwe:
|
||||
return _load_vocoder_with_bwe(config_dict, weights)
|
||||
else:
|
||||
config = VocoderModelConfig.from_dict(config_dict)
|
||||
model = Vocoder(config)
|
||||
model.load_weights(list(weights.items()), strict=True)
|
||||
return model
|
||||
|
||||
|
||||
def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
|
||||
"""Load VocoderWithBWE from config and weights."""
|
||||
# Build vocoder from config
|
||||
vocoder_cfg = config_dict.get("vocoder", {})
|
||||
vocoder_config = VocoderModelConfig.from_dict(vocoder_cfg)
|
||||
vocoder = Vocoder(vocoder_config)
|
||||
|
||||
# Build BWE generator from config
|
||||
bwe_cfg = config_dict.get("bwe", {})
|
||||
bwe_config = VocoderModelConfig.from_dict(bwe_cfg)
|
||||
bwe_config.apply_final_activation = False
|
||||
bwe_generator = Vocoder(bwe_config)
|
||||
|
||||
# MelSTFT from weight shapes
|
||||
stft_basis = weights.get("mel_stft.stft_fn.forward_basis")
|
||||
filter_length = stft_basis.shape[2] if stft_basis is not None else 512
|
||||
mel_basis = weights.get("mel_stft.mel_basis")
|
||||
n_mel_channels = mel_basis.shape[0] if mel_basis is not None else 64
|
||||
|
||||
hop_length = bwe_cfg.get("hop_length", 80)
|
||||
input_sr = bwe_cfg.get("input_sampling_rate", 16000)
|
||||
output_sr = bwe_cfg.get("output_sampling_rate", 48000)
|
||||
|
||||
mel_stft = MelSTFT(
|
||||
filter_length=filter_length,
|
||||
hop_length=hop_length,
|
||||
win_length=filter_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
)
|
||||
|
||||
model = VocoderWithBWE(
|
||||
vocoder=vocoder,
|
||||
bwe_generator=bwe_generator,
|
||||
mel_stft=mel_stft,
|
||||
input_sampling_rate=input_sr,
|
||||
output_sampling_rate=output_sr,
|
||||
hop_length=hop_length,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user