Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -1,11 +1,12 @@
"""Vocoder for converting mel spectrograms to audio waveforms."""
import math
from typing import List
from typing import Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import VocoderModelConfig
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
@@ -27,44 +28,29 @@ class Vocoder(nn.Module):
def __init__(
self,
resblock_kernel_sizes: List[int] | None = None,
upsample_rates: List[int] | None = None,
upsample_kernel_sizes: List[int] | None = None,
resblock_dilation_sizes: List[List[int]] | None = None,
upsample_initial_channel: int = 1024,
stereo: bool = True,
resblock: str = "1",
output_sample_rate: int = 24000,
config: VocoderModelConfig
):
super().__init__()
# Initialize default values if not provided
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
upsample_rates = [6, 5, 2, 2, 2]
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [16, 15, 8, 4, 4]
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sample_rate = output_sample_rate
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.upsample_rates = upsample_rates
self.upsample_kernel_sizes = upsample_kernel_sizes
self.upsample_initial_channel = upsample_initial_channel
self.output_sample_rate = config.output_sample_rate
self.num_kernels = len(config.resblock_kernel_sizes)
self.num_upsamples = len(config.upsample_rates)
self.upsample_rates = config.upsample_rates
self.upsample_kernel_sizes = config.upsample_kernel_sizes
self.upsample_initial_channel = config.upsample_initial_channel
in_channels = 128 if stereo else 64
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
# Upsampling layers using ConvTranspose1d
self.ups = {}
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
in_ch = upsample_initial_channel // (2**i)
out_ch = upsample_initial_channel // (2 ** (i + 1))
for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
in_ch = config.upsample_initial_channel // (2**i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch,
out_ch,
@@ -77,16 +63,67 @@ class Vocoder(nn.Module):
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
block_idx += 1
out_channels = 2 if stereo else 1
final_channels = upsample_initial_channel // (2**self.num_upsamples)
out_channels = 2 if config.stereo else 1
final_channels = config.upsample_initial_channel // (2**self.num_upsamples)
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
self.upsample_factor = math.prod(upsample_rates)
self.upsample_factor = math.prod(config.upsample_rates)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
if "vocoder." not in weights:
return weights
for key, value in weights.items():
new_key = key
# Handle vocoder weights
if key.startswith("vocoder."):
new_key = key.replace("vocoder.", "")
# Handle ModuleList indices -> dict keys
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
# Handle Conv1d weight shape conversion
# PyTorch: (out_channels, in_channels, kernel)
# MLX: (out_channels, kernel, in_channels)
if "weight" in new_key and value.ndim == 3:
if "ups" in new_key:
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0))
else:
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
"""Load vocoder from pretrained model."""
from mlx_video.models.ltx.config import VocoderModelConfig
import json
config_dict = {}
with open(model_path / "config.json", "r") as f:
config_dict = json.load(f)
config = VocoderModelConfig.from_dict(config_dict)
model = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
# weights = vocoder.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
def __call__(self, x: mx.array) -> mx.array:
"""