Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -3,7 +3,7 @@
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, decode_audio
from .causal_conv_2d import CausalConv2d, make_conv2d
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics

View File

@@ -1,14 +1,15 @@
"""Audio VAE encoder and decoder for LTX-2."""
from typing import Set, Tuple
from typing import Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .downsample import build_downsampling_path
from ..config import CausalityAxis
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
@@ -67,22 +68,7 @@ class AudioDecoder(nn.Module):
def __init__(
self,
*,
ch: int = 128,
out_ch: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Set[int] = None,
resolution: int = 256,
z_channels: int = 8,
norm_type: NormType = NormType.PIXEL,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
dropout: float = 0.0,
mid_block_add_attention: bool = True,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: int | None = None,
config: AudioDecoderModelConfig,
) -> None:
"""
Initialize the AudioDecoder.
@@ -105,86 +91,132 @@ class AudioDecoder(nn.Module):
"""
super().__init__()
if attn_resolutions is None:
attn_resolutions = {8, 16, 32}
# Internal behavioral defaults
resamp_with_conv = True
attn_type = AttentionType.VANILLA
# Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
# ch=128 matches this dimension, so use ch for per_channel_statistics
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
self.sample_rate = config.sample_rate
self.mel_hop_length = config.mel_hop_length
self.is_causal = config.is_causal
self.mel_bins = config.mel_bins
self.patchifier = AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
sample_rate=config.sample_rate,
hop_length=config.mel_hop_length,
is_causal=config.is_causal,
)
self.ch = ch
self.ch = config.ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.out_ch = out_ch
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.z_channels = z_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
self.attn_type = attn_type
self.num_resolutions = len(config.ch_mult)
self.num_res_blocks = config.num_res_blocks
self.resolution = config.resolution
self.out_ch = config.out_ch
self.give_pre_end = config.give_pre_end
self.tanh_out = config.tanh_out
self.norm_type = config.norm_type
self.z_channels = config.z_channels
self.channel_multipliers = config.ch_mult
self.attn_resolutions = config.attn_resolutions
self.causality_axis = config.causality_axis
self.attn_type = config.attn_type
base_block_channels = ch * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, z_channels, base_resolution, base_resolution)
base_block_channels = config.ch * self.channel_multipliers[-1]
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d(
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
self.mid = build_mid_block(
channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
add_attention=mid_block_add_attention,
add_attention=config.mid_block_add_attention,
)
self.up, final_block_channels = build_upsampling_path(
ch=ch,
ch_mult=ch_mult,
ch=config.ch,
ch_mult=config.ch_mult,
num_resolutions=self.num_resolutions,
num_res_blocks=num_res_blocks,
resolution=resolution,
num_res_blocks=config.num_res_blocks,
resolution=config.resolution,
temb_channels=self.temb_ch,
dropout=dropout,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
attn_resolutions=config.attn_resolutions,
resamp_with_conv=config.resamp_with_conv,
initial_block_channels=base_block_channels,
)
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
self.conv_out = make_conv2d(
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename and re-layout audio VAE checkpoint weights for this decoder.

    Keys under ``audio_vae.decoder.`` have that prefix removed. Keys under
    ``audio_vae.per_channel_statistics.`` are kept only when they contain
    ``mean-of-means`` or ``std-of-means`` and are renamed to the matching
    ``per_channel_statistics.*`` attribute. Every other key is dropped.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        A new dictionary with MLX-compatible naming for the audio VAE
        decoder.
    """
    decoder_prefix = "audio_vae.decoder."
    stats_prefix = "audio_vae.per_channel_statistics."
    # (marker found in the checkpoint key, target name); checked in order.
    stat_targets = (
        ("mean-of-means", "per_channel_statistics.mean_of_means"),
        ("std-of-means", "per_channel_statistics.std_of_means"),
    )

    remapped = {}
    for name, tensor in weights.items():
        if name.startswith(decoder_prefix):
            target = name.replace(decoder_prefix, "")
        elif name.startswith(stats_prefix):
            target = next(
                (mlx_name for marker, mlx_name in stat_targets if marker in name),
                None,
            )
            if target is None:
                # Other statistics entries are not used.
                continue
        else:
            # Not an audio VAE decoder weight.
            continue

        # Conv2d kernels: PyTorch (out_channels, in_channels, H, W)
        #              -> MLX     (out_channels, H, W, in_channels).
        # check_array_shape() is treated as the "already converted" test.
        is_conv_kernel = (
            "conv" in target.lower() and "weight" in target and tensor.ndim == 4
        )
        if is_conv_kernel and not check_array_shape(tensor):
            tensor = mx.transpose(tensor, (0, 2, 3, 1))

        remapped[target] = tensor
    return remapped
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
    """Load audio VAE decoder from pretrained model.

    Reads ``config.json`` and ``model.safetensors`` from ``model_path``,
    builds the decoder from the parsed config and loads the weights with
    ``strict=True``.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``.

    Returns:
        The decoder instance with weights loaded.
    """
    import json

    from mlx_video.models.ltx.config import AudioDecoderModelConfig

    # Use a context manager so the config file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(model_path / "config.json", "r") as config_file:
        config = AudioDecoderModelConfig.from_dict(json.load(config_file))

    # Passed by keyword so the call is valid whether or not __init__
    # declares ``config`` as keyword-only.
    decoder = cls(config=config)

    weights = mx.load(str(model_path / "model.safetensors"))
    # NOTE: weights are loaded as stored; sanitize() is not applied here.
    decoder.load_weights(list(weights.items()), strict=True)
    return decoder
def __call__(self, sample: mx.array) -> mx.array:
"""
Decode latent features back to audio spectrograms.

View File

@@ -5,7 +5,7 @@ from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:

View File

@@ -1,12 +0,0 @@
"""Causality axis enum for specifying causal convolution dimensions."""
from enum import Enum
class CausalityAxis(Enum):
    """Selects which axis a causal convolution treats as causal.

    ``NONE`` carries Python ``None`` as its value; the remaining members
    carry the string identifiers listed below.
    """

    NONE = None
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"

View File

@@ -6,7 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock

View File

@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
self.latent_channels = latent_channels
# Initialize buffers - will be loaded from weights
# Using underscores for MLX compatibility with weight loading
self._std_of_means = mx.ones((latent_channels,))
self._mean_of_means = mx.zeros((latent_channels,))
self.std_of_means = mx.ones((latent_channels,))
self.mean_of_means = mx.zeros((latent_channels,))
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latent representation."""
# Broadcast statistics to match x shape
# x shape: (B, C, ...) or (B, ..., C)
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
std = self.std_of_means.astype(x.dtype)
mean = self.mean_of_means.astype(x.dtype)
return (x * std) + mean
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latent representation."""
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
std = self.std_of_means.astype(x.dtype)
mean = self.mean_of_means.astype(x.dtype)
return (x - mean) / std

View File

@@ -6,7 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1

View File

@@ -7,7 +7,7 @@ import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock

View File

@@ -1,11 +1,12 @@
"""Vocoder for converting mel spectrograms to audio waveforms."""
import math
from typing import List
from typing import Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import VocoderModelConfig
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
@@ -27,44 +28,29 @@ class Vocoder(nn.Module):
def __init__(
self,
resblock_kernel_sizes: List[int] | None = None,
upsample_rates: List[int] | None = None,
upsample_kernel_sizes: List[int] | None = None,
resblock_dilation_sizes: List[List[int]] | None = None,
upsample_initial_channel: int = 1024,
stereo: bool = True,
resblock: str = "1",
output_sample_rate: int = 24000,
config: VocoderModelConfig
):
super().__init__()
# Initialize default values if not provided
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
upsample_rates = [6, 5, 2, 2, 2]
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [16, 15, 8, 4, 4]
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sample_rate = output_sample_rate
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.upsample_rates = upsample_rates
self.upsample_kernel_sizes = upsample_kernel_sizes
self.upsample_initial_channel = upsample_initial_channel
self.output_sample_rate = config.output_sample_rate
self.num_kernels = len(config.resblock_kernel_sizes)
self.num_upsamples = len(config.upsample_rates)
self.upsample_rates = config.upsample_rates
self.upsample_kernel_sizes = config.upsample_kernel_sizes
self.upsample_initial_channel = config.upsample_initial_channel
in_channels = 128 if stereo else 64
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
# Upsampling layers using ConvTranspose1d
self.ups = {}
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
in_ch = upsample_initial_channel // (2**i)
out_ch = upsample_initial_channel // (2 ** (i + 1))
for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
in_ch = config.upsample_initial_channel // (2**i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch,
out_ch,
@@ -77,16 +63,67 @@ class Vocoder(nn.Module):
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
block_idx += 1
out_channels = 2 if stereo else 1
final_channels = upsample_initial_channel // (2**self.num_upsamples)
out_channels = 2 if config.stereo else 1
final_channels = config.upsample_initial_channel // (2**self.num_upsamples)
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
self.upsample_factor = math.prod(upsample_rates)
self.upsample_factor = math.prod(config.upsample_rates)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert vocoder checkpoint weights from PyTorch to MLX conventions.

    If no key starts with ``vocoder.`` the input is returned unchanged.
    Otherwise the leading ``vocoder.`` prefix is stripped from each key
    that has it, and 3-D ``weight`` tensors are transposed to the MLX
    layout unless ``check_array_shape`` reports them as already converted.
    Keys without the prefix keep their name.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        The original mapping when nothing is prefixed with ``vocoder.``,
        otherwise a new dictionary with MLX-compatible names and layouts.
    """
    prefix = "vocoder."

    # The previous guard (`"vocoder." not in weights`) looked for a dict key
    # that is exactly "vocoder.", so it was always true for real checkpoints
    # and the method never converted anything. Test for the prefix instead.
    if not any(key.startswith(prefix) for key in weights):
        return weights

    sanitized = {}
    for key, value in weights.items():
        # Strip only the leading prefix, not every occurrence in the key.
        new_key = key[len(prefix):] if key.startswith(prefix) else key

        # ModuleList indices need no renaming: ups.0 / resblocks.0 keep
        # the same names as dict keys.

        if "weight" in new_key and value.ndim == 3 and not check_array_shape(value):
            if "ups" in new_key:
                # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel)
                #               -> MLX     (out_ch, kernel, in_ch)
                value = mx.transpose(value, (1, 2, 0))
            else:
                # Conv1d: PyTorch (out_ch, in_ch, kernel)
                #      -> MLX     (out_ch, kernel, in_ch)
                value = mx.transpose(value, (0, 2, 1))

        sanitized[new_key] = value
    return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
    """Build a vocoder from a directory holding its config and weights.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``.
        strict: Forwarded to ``load_weights``.

    Returns:
        The vocoder instance with weights loaded.
    """
    import json

    from mlx_video.models.ltx.config import VocoderModelConfig

    with open(model_path / "config.json", "r") as config_file:
        raw_config = json.load(config_file)

    vocoder = cls(VocoderModelConfig.from_dict(raw_config))

    state = mx.load(str(model_path / "model.safetensors"))
    # NOTE: weights are loaded as stored; sanitize() is not applied here.
    vocoder.load_weights(list(state.items()), strict=strict)
    return vocoder
def __call__(self, x: mx.array) -> mx.array:
"""