Add audio to video conditioning

This commit is contained in:
Prince Canuma
2026-03-16 01:42:11 +01:00
parent f53b9e0807
commit 6f6105b715
7 changed files with 623 additions and 62 deletions

View File

@@ -0,0 +1,135 @@
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
using librosa for macOS/MLX compatibility.
"""
from pathlib import Path
import numpy as np
import mlx.core as mx
def load_audio(
path: str,
target_sr: int = 16000,
start_time: float = 0.0,
max_duration: float | None = None,
mono: bool = False,
) -> tuple[np.ndarray, int]:
"""Load audio file, resample to target sample rate.
Args:
path: Path to audio file (WAV, FLAC, MP3, OGG, or video with audio track).
target_sr: Target sample rate (default 16000 Hz).
start_time: Start time in seconds.
max_duration: Maximum duration in seconds. None = read to end.
mono: If True, convert to mono. Default False (preserve channels).
Returns:
(waveform, sample_rate) where waveform is (channels, samples) float32 numpy array.
"""
import librosa
# librosa.load returns mono by default; we want to preserve stereo
y, sr = librosa.load(
path,
sr=target_sr,
mono=mono,
offset=start_time,
duration=max_duration,
)
# Ensure 2D: (channels, samples)
if y.ndim == 1:
y = y[np.newaxis, :] # (1, samples)
return y.astype(np.float32), sr
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
"""Ensure waveform is stereo (2, samples). Duplicates mono if needed."""
if waveform.ndim == 1:
waveform = waveform[np.newaxis, :]
if waveform.shape[0] == 1:
waveform = np.concatenate([waveform, waveform], axis=0)
elif waveform.shape[0] > 2:
waveform = waveform[:2]
return waveform
def waveform_to_mel(
waveform: np.ndarray,
sample_rate: int = 16000,
n_fft: int = 1024,
hop_length: int = 160,
win_length: int = 1024,
n_mels: int = 64,
fmin: float = 0.0,
fmax: float = 8000.0,
) -> mx.array:
"""Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.
PyTorch reference:
MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")
Args:
waveform: (channels, samples) float32 numpy array.
sample_rate: Sample rate of the waveform.
n_fft: FFT size.
hop_length: Hop length.
win_length: Window length.
n_mels: Number of mel bins.
fmin: Minimum frequency for mel filterbank.
fmax: Maximum frequency for mel filterbank.
Returns:
Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
"""
import librosa
# Ensure 2D
if waveform.ndim == 1:
waveform = waveform[np.newaxis, :]
channels = waveform.shape[0]
mels = []
for ch in range(channels):
# Magnitude spectrogram (power=1.0)
S = np.abs(librosa.stft(
waveform[ch],
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=True,
pad_mode="reflect",
))
# Mel filterbank with slaney normalization
mel_basis = librosa.filters.mel(
sr=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
norm="slaney",
)
mel = mel_basis @ S
# Log scale
mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))
# Transpose: (n_mels, time) -> (time, n_mels)
mel = mel.T
mels.append(mel)
# Stack channels: (channels, time, n_mels)
mel_spec = np.stack(mels, axis=0)
# Add batch dim: (1, channels, time, n_mels)
mel_spec = mel_spec[np.newaxis, ...]
return mx.array(mel_spec, dtype=mx.float32)