# Source listing metadata: 572 lines, 21 KiB, Python
"""Audio VAE encoder and decoder for LTX-2."""
|
|
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from mlx_vlm.models.base import check_array_shape
|
|
|
|
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
|
|
from .attention import AttentionType, make_attn
|
|
from .causal_conv_2d import make_conv2d
|
|
from .downsample import build_downsampling_path
|
|
from .normalization import NormType, build_normalization_layer
|
|
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
|
from .resnet import ResnetBlock
|
|
from .upsample import build_upsampling_path
|
|
|
|
LATENT_DOWNSAMPLE_FACTOR = 4
|
|
|
|
|
|
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> dict:
    """Assemble the middle stage: ResNet block, optional attention, ResNet block.

    Args:
        channels: Channel count used as both input and output of each ResNet block
            and passed to the attention factory.
        temb_channels: Forwarded to each ``ResnetBlock``.
        dropout: Forwarded to each ``ResnetBlock``.
        norm_type: Normalization type for the ResNet blocks and attention layer.
        causality_axis: Forwarded to each ``ResnetBlock``.
        attn_type: Attention variant handed to ``make_attn``.
        add_attention: When false, no attention layer is built.

    Returns:
        A dict with the keys ``"block_1"``, ``"attn_1"`` and ``"block_2"``.
        ``"attn_1"`` is ``None`` when ``add_attention`` is false.
    """

    def _new_resnet():
        # Both ResNet blocks use the same channel-preserving configuration.
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    # Layers are created in the order block_1 -> attn_1 -> block_2.
    first_block = _new_resnet()
    attention = None
    if add_attention:
        attention = make_attn(channels, attn_type=attn_type, norm_type=norm_type)
    second_block = _new_resnet()

    return {
        "block_1": first_block,
        "attn_1": attention,
        "block_2": second_block,
    }
|
|
|
|
|
|
def run_mid_block(mid: dict, features: mx.array) -> mx.array:
    """Apply the middle block to ``features``.

    Args:
        mid: Dict with callables under ``"block_1"`` and ``"block_2"`` and either
            a callable or ``None`` under ``"attn_1"``.
        features: Input passed to ``"block_1"``.

    Returns:
        The output of ``"block_2"``.
    """
    # The ResNet blocks are always invoked with temb=None.
    hidden = mid["block_1"](features, temb=None)

    attention = mid["attn_1"]
    if attention is not None:
        hidden = attention(hidden)

    return mid["block_2"](hidden, temb=None)
|
|
|
|
|
|
class AudioEncoder(nn.Module):
    """Encoder that compresses audio spectrograms into latent representations.

    The forward pass runs ``conv_in`` -> downsampling stages -> middle block ->
    ``norm_out`` / SiLU / ``conv_out`` and then normalizes the leading
    ``z_channels`` output channels with the per-channel statistics.
    """

    def __init__(self, config: AudioEncoderModelConfig) -> None:
        """Build every encoder layer from ``config``.

        Args:
            config: Encoder configuration; the fields read here are ``ch``,
                ``ch_mult``, ``num_res_blocks``, ``resolution``, ``in_channels``,
                ``z_channels``, ``double_z``, ``norm_type``, ``causality_axis``,
                ``attn_type``, ``attn_resolutions``, ``resamp_with_conv``,
                ``dropout``, ``mid_block_add_attention``, ``sample_rate``,
                ``mel_hop_length``, ``is_causal`` and ``mel_bins``.
        """
        super().__init__()

        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
        self.sample_rate = config.sample_rate
        self.mel_hop_length = config.mel_hop_length
        self.is_causal = config.is_causal
        self.mel_bins = config.mel_bins

        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=config.sample_rate,
            hop_length=config.mel_hop_length,
            is_causal=config.is_causal,
        )

        self.ch = config.ch
        # No timestep embedding is used: every block is called with temb=None.
        self.temb_ch = 0
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels
        self.z_channels = config.z_channels
        self.double_z = config.double_z
        self.norm_type = config.norm_type
        self.causality_axis = config.causality_axis
        self.attn_type = config.attn_type

        self.conv_in = make_conv2d(
            config.in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        # build_downsampling_path also reports the channel count leaving the
        # last stage, which sizes the middle block and the output head.
        self.down, block_in = build_downsampling_path(
            ch=config.ch,
            ch_mult=config.ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=config.num_res_blocks,
            resolution=config.resolution,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=config.attn_resolutions or set(),
            resamp_with_conv=config.resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=config.mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        # With double_z the head emits twice z_channels; only the first
        # z_channels are consumed by _normalize_latents.
        out_channels = 2 * config.z_channels if config.double_z else config.z_channels
        self.conv_out = make_conv2d(
            block_in,
            out_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize audio encoder weights from PyTorch format.

        Keeps encoder weights (with the ``audio_vae.encoder.`` / ``encoder.``
        prefix stripped) and the per-channel mean/std statistics (renamed to
        ``per_channel_statistics.mean_of_means`` / ``.std_of_means``). All other
        keys are dropped.

        Args:
            weights: Mapping from checkpoint key to array.

        Returns:
            A new dict with renamed keys; 4-D conv weights are transposed with
            ``(0, 2, 3, 1)`` unless ``check_array_shape`` accepts them as-is.
        """
        sanitized = {}
        for key, value in weights.items():
            new_key = key
            if key.startswith("audio_vae.encoder."):
                # Strip only the leading prefix; str.replace would also remove
                # any later occurrence of the same text inside the key.
                new_key = key[len("audio_vae.encoder.") :]
            elif key.startswith("encoder."):
                new_key = key[len("encoder.") :]
            elif key.startswith("audio_vae.per_channel_statistics."):
                if "mean-of-means" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue
            elif "per_channel_statistics" in key:
                if "mean-of-means" in key or "latents_mean" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key or "latents_std" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue
            elif key == "latents_mean":
                new_key = "per_channel_statistics.mean_of_means"
            elif key == "latents_std":
                new_key = "per_channel_statistics.std_of_means"
            else:
                continue

            # PyTorch conv weights are (out, in, H, W); the transpose moves the
            # input-channel axis last. check_array_shape presumably detects
            # weights that are already channels-last - confirm in mlx_vlm.
            if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
                value = (
                    value
                    if check_array_shape(value)
                    else mx.transpose(value, (0, 2, 3, 1))
                )

            sanitized[new_key] = value
        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
        """Load audio encoder from pretrained weights.

        Args:
            model_path: Directory containing ``config.json`` and
                ``model.safetensors``.

        Returns:
            An encoder with the weights loaded strictly.
        """
        import json

        from mlx_video.models.ltx_2.config import AudioEncoderModelConfig

        model_path = Path(model_path)
        # Use a context manager so the config file handle is always closed.
        with open(model_path / "config.json", encoding="utf-8") as config_file:
            config = AudioEncoderModelConfig.from_dict(json.load(config_file))
        encoder = cls(config)
        # NOTE(review): weights are loaded without passing through sanitize();
        # presumably the checkpoint is already in MLX naming/layout - confirm.
        weights = mx.load(str(model_path / "model.safetensors"))
        encoder.load_weights(list(weights.items()), strict=True)
        return encoder

    def __call__(self, spectrogram: mx.array) -> mx.array:
        """Encode audio spectrogram into normalized latent representation.

        Args:
            spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.
        Returns:
            Normalized latent (B, T', F', z_channels) in MLX channels-last format.
        """
        # Layout is guessed from axis 1: a 4-D input whose axis 1 equals
        # in_channels is treated as channels-first. A channels-last input whose
        # axis 1 happens to equal in_channels would also be transposed.
        if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
            spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))

        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)
        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: mx.array) -> mx.array:
        """Run ``h`` through every downsampling stage in order."""
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage["block"][block_idx](h, temb=None)
                # Attention is applied only for block indices present in
                # stage["attn"].
                if block_idx in stage["attn"]:
                    h = stage["attn"][block_idx](h)
            # The last resolution level is never downsampled.
            if level != self.num_resolutions - 1 and "downsample" in stage:
                h = stage["downsample"](h)
        return h

    def _finalize_output(self, h: mx.array) -> mx.array:
        """Apply the output normalization, SiLU and final convolution."""
        h = self.norm_out(h)
        h = nn.silu(h)
        return self.conv_out(h)

    def _normalize_latents(self, h: mx.array) -> mx.array:
        """Normalize encoder output using per-channel statistics.

        Takes first half of channels (mean) when double_z=True,
        then patchifies, normalizes, and unpatchifies.
        """
        # h shape: (B, T', F', 2*z_channels) in MLX format
        z_channels = self.z_channels
        means = h[..., :z_channels]

        # Record the channels-last layout so unpatchify can restore it.
        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[3],
            frames=means.shape[1],
            mel_bins=means.shape[2],
        )

        patched = self.patchifier.patchify(means)
        normalized = self.per_channel_statistics.normalize(patched)
        return self.patchifier.unpatchify(normalized, latent_shape)
|
|
|
|
|
|
class AudioDecoder(nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.
    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions.
    """

    def __init__(
        self,
        config: AudioDecoderModelConfig,
    ) -> None:
        """
        Initialize the AudioDecoder.
        Args:
            config: Decoder configuration. The fields read here are:
                ch: Base number of feature channels
                out_ch: Number of output channels
                ch_mult: Multiplicative factors for channels at each resolution
                num_res_blocks: Number of residual blocks per resolution
                attn_resolutions: Resolutions at which to apply attention
                resolution: Input spatial resolution
                z_channels: Number of latent channels
                norm_type: Normalization type
                causality_axis: Axis for causal convolutions
                attn_type: Attention variant
                resamp_with_conv: Forwarded to the upsampling path builder
                dropout: Dropout probability
                mid_block_add_attention: Whether to add attention in middle block
                give_pre_end: Skip the output head and return raw features
                tanh_out: Apply tanh to the final output
                sample_rate: Audio sample rate
                mel_hop_length: Hop length for mel spectrogram
                is_causal: Whether to use causal convolutions
                mel_bins: Number of mel frequency bins
        """
        super().__init__()

        # Per-channel statistics for denormalizing latents
        # Uses ch (base channel count) to match the patchified latent dimension
        # Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
        # After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
        # ch=128 matches this dimension, so use ch for per_channel_statistics
        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
        self.sample_rate = config.sample_rate
        self.mel_hop_length = config.mel_hop_length
        self.is_causal = config.is_causal
        self.mel_bins = config.mel_bins

        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=config.sample_rate,
            hop_length=config.mel_hop_length,
            is_causal=config.is_causal,
        )

        self.ch = config.ch
        # No timestep embedding is used: every block is called with temb=None.
        self.temb_ch = 0
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.out_ch = config.out_ch
        self.give_pre_end = config.give_pre_end
        self.tanh_out = config.tanh_out
        self.norm_type = config.norm_type
        self.z_channels = config.z_channels
        self.channel_multipliers = config.ch_mult
        self.attn_resolutions = config.attn_resolutions
        self.causality_axis = config.causality_axis
        self.attn_type = config.attn_type

        # The decoder starts at the widest channel count / lowest resolution.
        base_block_channels = config.ch * self.channel_multipliers[-1]
        base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, config.z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            config.z_channels,
            base_block_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=config.mid_block_add_attention,
        )

        # build_upsampling_path also reports the channel count leaving the
        # final stage, which sizes the output head.
        self.up, final_block_channels = build_upsampling_path(
            ch=config.ch,
            ch_mult=config.ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=config.num_res_blocks,
            resolution=config.resolution,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=config.attn_resolutions,
            resamp_with_conv=config.resamp_with_conv,
            initial_block_channels=base_block_channels,
        )

        self.norm_out = build_normalization_layer(
            final_block_channels, normtype=self.norm_type
        )
        self.conv_out = make_conv2d(
            final_block_channels,
            config.out_ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize audio VAE weight names from PyTorch format to MLX format.

        Args:
            weights: Dictionary of weights with PyTorch naming

        Returns:
            Dictionary with MLX-compatible naming for audio VAE decoder
        """
        sanitized = {}

        for key, value in weights.items():
            new_key = key

            # Handle audio_vae.decoder weights
            if key.startswith("audio_vae.decoder."):
                # Strip only the leading prefix; str.replace would also remove
                # any later occurrence of the same text inside the key.
                new_key = key[len("audio_vae.decoder.") :]
            elif key.startswith("audio_vae.per_channel_statistics."):
                # Map per-channel statistics
                if "mean-of-means" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue  # Skip other statistics keys
            else:
                continue  # Skip non-decoder keys

            # Handle Conv2d weight shape conversion
            # PyTorch: (out_channels, in_channels, H, W)
            # MLX: (out_channels, H, W, in_channels)
            if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
                value = (
                    value
                    if check_array_shape(value)
                    else mx.transpose(value, (0, 2, 3, 1))
                )

            sanitized[new_key] = value

        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
        """Load audio VAE decoder from pretrained model.

        Args:
            model_path: Directory containing ``config.json`` and
                ``model.safetensors``. A plain string is accepted as well.

        Returns:
            A decoder with the weights loaded strictly.
        """
        import json

        from mlx_video.models.ltx_2.config import AudioDecoderModelConfig

        # Coerce so that string paths support the "/" operator below.
        model_path = Path(model_path)
        # Use a context manager so the config file handle is always closed.
        with open(model_path / "config.json", encoding="utf-8") as config_file:
            config = AudioDecoderModelConfig.from_dict(json.load(config_file))
        decoder = cls(config)
        # NOTE(review): weights are loaded without passing through sanitize();
        # presumably the checkpoint is already in MLX naming/layout - confirm.
        weights = mx.load(str(model_path / "model.safetensors"))
        decoder.load_weights(list(weights.items()), strict=True)
        return decoder

    def __call__(self, sample: mx.array) -> mx.array:
        """
        Decode latent features back to audio spectrograms.
        Args:
            sample: Encoded latent representation of shape (B, H, W, C) in MLX format
                or (B, C, H, W) in PyTorch format (will be transposed)
        Returns:
            Reconstructed audio spectrogram, transposed to (B, C, H, W) by
            ``_adjust_output_shape``.
        """
        # Handle input format - if channels are in dim 1, transpose to channels-last.
        # The rank is checked first so shape[1] is only read on 4-D inputs.
        if sample.ndim == 4 and sample.shape[1] == self.z_channels:
            # PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
            sample = mx.transpose(sample, (0, 2, 3, 1))

        sample, target_shape = self._denormalize_latents(sample)

        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)

        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(
        self, sample: mx.array
    ) -> tuple[mx.array, AudioLatentShape]:
        """Denormalize latents using per-channel statistics.

        Returns:
            The denormalized latent in its original layout, and the target
            output shape the decoded spectrogram is cropped/padded to.
        """
        # sample shape: (B, H, W, C) in MLX format
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[3],  # channels last
            frames=sample.shape[1],  # height = frames
            mel_bins=sample.shape[2],  # width = mel_bins
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        # With a causal axis the target is shortened by (factor - 1) frames,
        # but never below one frame.
        if self.causality_axis != CausalityAxis.NONE:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=(
                self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
            ),
        )

        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: mx.array,
        target_shape: AudioLatentShape,
    ) -> mx.array:
        """
        Adjust output shape to match target dimensions for variable-length audio.
        Args:
            decoded_output: Tensor of shape (B, H, W, C) in MLX format
            target_shape: AudioLatentShape describing target dimensions
        Returns:
            Tensor cropped/padded on the time and frequency axes to match
            target_shape, transposed to (B, C, H, W).
        """
        # Current output shape: (batch, frames, mel_bins, channels) in MLX format
        _, current_time, current_freq, _ = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :,
            : min(current_time, target_time),
            : min(current_freq, target_freq),
            :target_channels,
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[1]
        freq_padding_needed = target_freq - decoded_output.shape[2]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # MLX pad: [(before_0, after_0), ...]
            # For (B, H, W, C): H=time, W=freq
            padding = [
                (0, 0),  # batch
                (0, max(time_padding_needed, 0)),  # time
                (0, max(freq_padding_needed, 0)),  # freq
                (0, 0),  # channels
            ]
            decoded_output = mx.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]

        # Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
        decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2))

        return decoded_output

    def _run_upsampling_path(self, h: mx.array) -> mx.array:
        """Run through upsampling path."""
        # Stages are visited from the highest level index down to level 0.
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx in range(len(stage["block"])):
                h = stage["block"][block_idx](h, temb=None)
                # Attention is applied only for block indices present in
                # stage["attn"].
                if block_idx in stage["attn"]:
                    h = stage["attn"][block_idx](h)

            # Level 0 is the last stage visited and is never upsampled.
            if level != 0 and "upsample" in stage:
                h = stage["upsample"](h)

        return h

    def _finalize_output(self, h: mx.array) -> mx.array:
        """Apply final normalization and convolution.

        When ``give_pre_end`` is set, ``h`` is returned untouched. Otherwise the
        output head is applied, followed by ``tanh`` when ``tanh_out`` is set.
        """
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nn.silu(h)
        h = self.conv_out(h)
        return mx.tanh(h) if self.tanh_out else h
|
|
|
|
|
|
def decode_audio(
    latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
) -> mx.array:
    """
    Turn an audio latent into a waveform by chaining the decoder and the vocoder.
    Args:
        latent: Input audio latent tensor
        audio_decoder: Callable model mapping the latent to a spectrogram
        vocoder: Callable model mapping the spectrogram to an audio waveform
    Returns:
        The vocoder output cast to float32. The leading axis is dropped when
        its size is 1.
    """
    spectrogram = audio_decoder(latent)
    waveform = vocoder(spectrogram)

    # A leading axis of size 1 is treated as the batch axis and removed.
    is_single_batch = waveform.shape[0] == 1
    result = waveform[0] if is_single_batch else waveform

    return result.astype(mx.float32)
|