Add audio to video conditioning
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""Audio VAE module for LTX-2 audio generation."""
|
||||
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, decode_audio
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
@@ -13,10 +14,15 @@ from .vocoder import Vocoder, load_vocoder
|
||||
|
||||
__all__ = [
|
||||
# Main components
|
||||
"AudioEncoder",
|
||||
"AudioDecoder",
|
||||
"Vocoder",
|
||||
"load_vocoder",
|
||||
"decode_audio",
|
||||
# Audio processing
|
||||
"load_audio",
|
||||
"ensure_stereo",
|
||||
"waveform_to_mel",
|
||||
# Ops
|
||||
"AudioLatentShape",
|
||||
"AudioPatchifier",
|
||||
|
||||
135
mlx_video/models/ltx/audio_vae/audio_processor.py
Normal file
135
mlx_video/models/ltx/audio_vae/audio_processor.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
|
||||
|
||||
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
|
||||
using librosa for macOS/MLX compatibility.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def load_audio(
|
||||
path: str,
|
||||
target_sr: int = 16000,
|
||||
start_time: float = 0.0,
|
||||
max_duration: float | None = None,
|
||||
mono: bool = False,
|
||||
) -> tuple[np.ndarray, int]:
|
||||
"""Load audio file, resample to target sample rate.
|
||||
|
||||
Args:
|
||||
path: Path to audio file (WAV, FLAC, MP3, OGG, or video with audio track).
|
||||
target_sr: Target sample rate (default 16000 Hz).
|
||||
start_time: Start time in seconds.
|
||||
max_duration: Maximum duration in seconds. None = read to end.
|
||||
mono: If True, convert to mono. Default False (preserve channels).
|
||||
|
||||
Returns:
|
||||
(waveform, sample_rate) where waveform is (channels, samples) float32 numpy array.
|
||||
"""
|
||||
import librosa
|
||||
|
||||
# librosa.load returns mono by default; we want to preserve stereo
|
||||
y, sr = librosa.load(
|
||||
path,
|
||||
sr=target_sr,
|
||||
mono=mono,
|
||||
offset=start_time,
|
||||
duration=max_duration,
|
||||
)
|
||||
|
||||
# Ensure 2D: (channels, samples)
|
||||
if y.ndim == 1:
|
||||
y = y[np.newaxis, :] # (1, samples)
|
||||
|
||||
return y.astype(np.float32), sr
|
||||
|
||||
|
||||
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
|
||||
"""Ensure waveform is stereo (2, samples). Duplicates mono if needed."""
|
||||
if waveform.ndim == 1:
|
||||
waveform = waveform[np.newaxis, :]
|
||||
if waveform.shape[0] == 1:
|
||||
waveform = np.concatenate([waveform, waveform], axis=0)
|
||||
elif waveform.shape[0] > 2:
|
||||
waveform = waveform[:2]
|
||||
return waveform
|
||||
|
||||
|
||||
def waveform_to_mel(
|
||||
waveform: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
n_fft: int = 1024,
|
||||
hop_length: int = 160,
|
||||
win_length: int = 1024,
|
||||
n_mels: int = 64,
|
||||
fmin: float = 0.0,
|
||||
fmax: float = 8000.0,
|
||||
) -> mx.array:
|
||||
"""Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.
|
||||
|
||||
PyTorch reference:
|
||||
MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
|
||||
f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
|
||||
mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")
|
||||
|
||||
Args:
|
||||
waveform: (channels, samples) float32 numpy array.
|
||||
sample_rate: Sample rate of the waveform.
|
||||
n_fft: FFT size.
|
||||
hop_length: Hop length.
|
||||
win_length: Window length.
|
||||
n_mels: Number of mel bins.
|
||||
fmin: Minimum frequency for mel filterbank.
|
||||
fmax: Maximum frequency for mel filterbank.
|
||||
|
||||
Returns:
|
||||
Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
|
||||
"""
|
||||
import librosa
|
||||
|
||||
# Ensure 2D
|
||||
if waveform.ndim == 1:
|
||||
waveform = waveform[np.newaxis, :]
|
||||
|
||||
channels = waveform.shape[0]
|
||||
mels = []
|
||||
|
||||
for ch in range(channels):
|
||||
# Magnitude spectrogram (power=1.0)
|
||||
S = np.abs(librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
))
|
||||
|
||||
# Mel filterbank with slaney normalization
|
||||
mel_basis = librosa.filters.mel(
|
||||
sr=sample_rate,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
norm="slaney",
|
||||
)
|
||||
mel = mel_basis @ S
|
||||
|
||||
# Log scale
|
||||
mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))
|
||||
|
||||
# Transpose: (n_mels, time) -> (time, n_mels)
|
||||
mel = mel.T
|
||||
mels.append(mel)
|
||||
|
||||
# Stack channels: (channels, time, n_mels)
|
||||
mel_spec = np.stack(mels, axis=0)
|
||||
|
||||
# Add batch dim: (1, channels, time, n_mels)
|
||||
mel_spec = mel_spec[np.newaxis, ...]
|
||||
|
||||
return mx.array(mel_spec, dtype=mx.float32)
|
||||
@@ -6,10 +6,11 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import ResnetBlock
|
||||
@@ -59,6 +60,179 @@ def run_mid_block(mid: dict, features: mx.array) -> mx.array:
|
||||
return mid["block_2"](features, temb=None)
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
"""Encoder that compresses audio spectrograms into latent representations."""
|
||||
|
||||
def __init__(self, config: AudioEncoderModelConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||
self.sample_rate = config.sample_rate
|
||||
self.mel_hop_length = config.mel_hop_length
|
||||
self.is_causal = config.is_causal
|
||||
self.mel_bins = config.mel_bins
|
||||
|
||||
self.patchifier = AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=config.sample_rate,
|
||||
hop_length=config.mel_hop_length,
|
||||
is_causal=config.is_causal,
|
||||
)
|
||||
|
||||
self.ch = config.ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(config.ch_mult)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
self.resolution = config.resolution
|
||||
self.in_channels = config.in_channels
|
||||
self.z_channels = config.z_channels
|
||||
self.double_z = config.double_z
|
||||
self.norm_type = config.norm_type
|
||||
self.causality_axis = config.causality_axis
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.down, block_in = build_downsampling_path(
|
||||
ch=config.ch,
|
||||
ch_mult=config.ch_mult,
|
||||
num_resolutions=self.num_resolutions,
|
||||
num_res_blocks=config.num_res_blocks,
|
||||
resolution=config.resolution,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
attn_resolutions=config.attn_resolutions or set(),
|
||||
resamp_with_conv=config.resamp_with_conv,
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
add_attention=config.mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||
self.conv_out = make_conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio encoder weights from PyTorch format."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
if key.startswith("audio_vae.encoder."):
|
||||
new_key = key.replace("audio_vae.encoder.", "")
|
||||
elif key.startswith("encoder."):
|
||||
new_key = key.replace("encoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
elif "per_channel_statistics" in key:
|
||||
if "mean-of-means" in key or "latents_mean" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key or "latents_std" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
elif key == "latents_mean":
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif key == "latents_std":
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
encoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
encoder.load_weights(list(weights.items()), strict=True)
|
||||
return encoder
|
||||
|
||||
def __call__(self, spectrogram: mx.array) -> mx.array:
|
||||
"""Encode audio spectrogram into normalized latent representation.
|
||||
|
||||
Args:
|
||||
spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.
|
||||
Returns:
|
||||
Normalized latent (B, T', F', z_channels) in MLX channels-last format.
|
||||
"""
|
||||
if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
|
||||
spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))
|
||||
|
||||
h = self.conv_in(spectrogram)
|
||||
h = self._run_downsampling_path(h)
|
||||
h = run_mid_block(self.mid, h)
|
||||
h = self._finalize_output(h)
|
||||
return self._normalize_latents(h)
|
||||
|
||||
def _run_downsampling_path(self, h: mx.array) -> mx.array:
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx in range(self.num_res_blocks):
|
||||
h = stage["block"][block_idx](h, temb=None)
|
||||
if block_idx in stage["attn"]:
|
||||
h = stage["attn"][block_idx](h)
|
||||
if level != self.num_resolutions - 1 and "downsample" in stage:
|
||||
h = stage["downsample"](h)
|
||||
return h
|
||||
|
||||
def _finalize_output(self, h: mx.array) -> mx.array:
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
return self.conv_out(h)
|
||||
|
||||
def _normalize_latents(self, h: mx.array) -> mx.array:
|
||||
"""Normalize encoder output using per-channel statistics.
|
||||
|
||||
Takes first half of channels (mean) when double_z=True,
|
||||
then patchifies, normalizes, and unpatchifies.
|
||||
"""
|
||||
# h shape: (B, T', F', 2*z_channels) in MLX format
|
||||
z_channels = self.z_channels
|
||||
means = h[..., :z_channels]
|
||||
|
||||
latent_shape = AudioLatentShape(
|
||||
batch=means.shape[0],
|
||||
channels=means.shape[3],
|
||||
frames=means.shape[1],
|
||||
mel_bins=means.shape[2],
|
||||
)
|
||||
|
||||
patched = self.patchifier.patchify(means)
|
||||
normalized = self.per_channel_statistics.normalize(patched)
|
||||
return self.patchifier.unpatchify(normalized, latent_shape)
|
||||
|
||||
|
||||
class AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Tuple, Set
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
|
||||
class LTXModelType(Enum):
|
||||
@@ -252,6 +252,47 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
@dataclass
|
||||
class AudioEncoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
in_channels: int = 2
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||
num_res_blocks: int = 2
|
||||
attn_resolutions: Optional[List[int]] = None
|
||||
resolution: int = 256
|
||||
z_channels: int = 8
|
||||
double_z: bool = True
|
||||
n_fft: int = 1024
|
||||
norm_type: Enum = None
|
||||
causality_axis: Enum = None
|
||||
dropout: float = 0.0
|
||||
mid_block_add_attention: bool = True
|
||||
sample_rate: int = 16000
|
||||
mel_hop_length: int = 160
|
||||
is_causal: bool = True
|
||||
mel_bins: int = 64
|
||||
resamp_with_conv: bool = True
|
||||
attn_type: str = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.attn_resolutions is not None:
|
||||
result["attn_resolutions"] = list(self.attn_resolutions)
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
if isinstance(self.norm_type, str):
|
||||
self.norm_type = NormType(self.norm_type)
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocoderModelConfig(BaseModelConfig):
|
||||
resblock_kernel_sizes: Optional[List[int]] = None
|
||||
|
||||
Reference in New Issue
Block a user