136 lines
3.8 KiB
Python
136 lines
3.8 KiB
Python
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
|
|
|
|
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
|
|
using librosa for macOS/MLX compatibility.
|
|
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import mlx.core as mx
|
|
|
|
|
|
def load_audio(
|
|
path: str,
|
|
target_sr: int = 16000,
|
|
start_time: float = 0.0,
|
|
max_duration: float | None = None,
|
|
mono: bool = False,
|
|
) -> tuple[np.ndarray, int]:
|
|
"""Load audio file, resample to target sample rate.
|
|
|
|
Args:
|
|
path: Path to audio file (WAV, FLAC, MP3, OGG, or video with audio track).
|
|
target_sr: Target sample rate (default 16000 Hz).
|
|
start_time: Start time in seconds.
|
|
max_duration: Maximum duration in seconds. None = read to end.
|
|
mono: If True, convert to mono. Default False (preserve channels).
|
|
|
|
Returns:
|
|
(waveform, sample_rate) where waveform is (channels, samples) float32 numpy array.
|
|
"""
|
|
import librosa
|
|
|
|
# librosa.load returns mono by default; we want to preserve stereo
|
|
y, sr = librosa.load(
|
|
path,
|
|
sr=target_sr,
|
|
mono=mono,
|
|
offset=start_time,
|
|
duration=max_duration,
|
|
)
|
|
|
|
# Ensure 2D: (channels, samples)
|
|
if y.ndim == 1:
|
|
y = y[np.newaxis, :] # (1, samples)
|
|
|
|
return y.astype(np.float32), sr
|
|
|
|
|
|
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
|
|
"""Ensure waveform is stereo (2, samples). Duplicates mono if needed."""
|
|
if waveform.ndim == 1:
|
|
waveform = waveform[np.newaxis, :]
|
|
if waveform.shape[0] == 1:
|
|
waveform = np.concatenate([waveform, waveform], axis=0)
|
|
elif waveform.shape[0] > 2:
|
|
waveform = waveform[:2]
|
|
return waveform
|
|
|
|
|
|
def waveform_to_mel(
|
|
waveform: np.ndarray,
|
|
sample_rate: int = 16000,
|
|
n_fft: int = 1024,
|
|
hop_length: int = 160,
|
|
win_length: int = 1024,
|
|
n_mels: int = 64,
|
|
fmin: float = 0.0,
|
|
fmax: float = 8000.0,
|
|
) -> mx.array:
|
|
"""Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.
|
|
|
|
PyTorch reference:
|
|
MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
|
|
f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
|
|
mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")
|
|
|
|
Args:
|
|
waveform: (channels, samples) float32 numpy array.
|
|
sample_rate: Sample rate of the waveform.
|
|
n_fft: FFT size.
|
|
hop_length: Hop length.
|
|
win_length: Window length.
|
|
n_mels: Number of mel bins.
|
|
fmin: Minimum frequency for mel filterbank.
|
|
fmax: Maximum frequency for mel filterbank.
|
|
|
|
Returns:
|
|
Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
|
|
"""
|
|
import librosa
|
|
|
|
# Ensure 2D
|
|
if waveform.ndim == 1:
|
|
waveform = waveform[np.newaxis, :]
|
|
|
|
channels = waveform.shape[0]
|
|
mels = []
|
|
|
|
for ch in range(channels):
|
|
# Magnitude spectrogram (power=1.0)
|
|
S = np.abs(librosa.stft(
|
|
waveform[ch],
|
|
n_fft=n_fft,
|
|
hop_length=hop_length,
|
|
win_length=win_length,
|
|
center=True,
|
|
pad_mode="reflect",
|
|
))
|
|
|
|
# Mel filterbank with slaney normalization
|
|
mel_basis = librosa.filters.mel(
|
|
sr=sample_rate,
|
|
n_fft=n_fft,
|
|
n_mels=n_mels,
|
|
fmin=fmin,
|
|
fmax=fmax,
|
|
norm="slaney",
|
|
)
|
|
mel = mel_basis @ S
|
|
|
|
# Log scale
|
|
mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))
|
|
|
|
# Transpose: (n_mels, time) -> (time, n_mels)
|
|
mel = mel.T
|
|
mels.append(mel)
|
|
|
|
# Stack channels: (channels, time, n_mels)
|
|
mel_spec = np.stack(mels, axis=0)
|
|
|
|
# Add batch dim: (1, channels, time, n_mels)
|
|
mel_spec = mel_spec[np.newaxis, ...]
|
|
|
|
return mx.array(mel_spec, dtype=mx.float32)
|