Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, decode_audio
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Set, Tuple
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import ResnetBlock
|
||||
@@ -67,22 +68,7 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch: int = 128,
|
||||
out_ch: int = 2,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Set[int] = None,
|
||||
resolution: int = 256,
|
||||
z_channels: int = 8,
|
||||
norm_type: NormType = NormType.PIXEL,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = True,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: int | None = None,
|
||||
config: AudioDecoderModelConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the AudioDecoder.
|
||||
@@ -105,86 +91,132 @@ class AudioDecoder(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if attn_resolutions is None:
|
||||
attn_resolutions = {8, 16, 32}
|
||||
|
||||
# Internal behavioral defaults
|
||||
resamp_with_conv = True
|
||||
attn_type = AttentionType.VANILLA
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
|
||||
# ch=128 matches this dimension, so use ch for per_channel_statistics
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||
self.sample_rate = config.sample_rate
|
||||
self.mel_hop_length = config.mel_hop_length
|
||||
self.is_causal = config.is_causal
|
||||
self.mel_bins = config.mel_bins
|
||||
|
||||
self.patchifier = AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
sample_rate=config.sample_rate,
|
||||
hop_length=config.mel_hop_length,
|
||||
is_causal=config.is_causal,
|
||||
)
|
||||
|
||||
self.ch = ch
|
||||
self.ch = config.ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.out_ch = out_ch
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.z_channels = z_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
self.attn_type = attn_type
|
||||
self.num_resolutions = len(config.ch_mult)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
self.resolution = config.resolution
|
||||
self.out_ch = config.out_ch
|
||||
self.give_pre_end = config.give_pre_end
|
||||
self.tanh_out = config.tanh_out
|
||||
self.norm_type = config.norm_type
|
||||
self.z_channels = config.z_channels
|
||||
self.channel_multipliers = config.ch_mult
|
||||
self.attn_resolutions = config.attn_resolutions
|
||||
self.causality_axis = config.causality_axis
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
base_block_channels = ch * self.channel_multipliers[-1]
|
||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, z_channels, base_resolution, base_resolution)
|
||||
base_block_channels = config.ch * self.channel_multipliers[-1]
|
||||
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
add_attention=mid_block_add_attention,
|
||||
add_attention=config.mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.up, final_block_channels = build_upsampling_path(
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
ch=config.ch,
|
||||
ch_mult=config.ch_mult,
|
||||
num_resolutions=self.num_resolutions,
|
||||
num_res_blocks=num_res_blocks,
|
||||
resolution=resolution,
|
||||
num_res_blocks=config.num_res_blocks,
|
||||
resolution=config.resolution,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
attn_resolutions=attn_resolutions,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
attn_resolutions=config.attn_resolutions,
|
||||
resamp_with_conv=config.resamp_with_conv,
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for audio VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle audio_vae.decoder weights
|
||||
if key.startswith("audio_vae.decoder."):
|
||||
new_key = key.replace("audio_vae.decoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
# Map per-channel statistics
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
else:
|
||||
continue # Skip non-decoder keys
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
|
||||
|
||||
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Causality axis enum for specifying causal convolution dimensions."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
@@ -6,7 +6,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
|
||||
self.latent_channels = latent_channels
|
||||
# Initialize buffers - will be loaded from weights
|
||||
# Using underscores for MLX compatibility with weight loading
|
||||
self._std_of_means = mx.ones((latent_channels,))
|
||||
self._mean_of_means = mx.zeros((latent_channels,))
|
||||
self.std_of_means = mx.ones((latent_channels,))
|
||||
self.mean_of_means = mx.zeros((latent_channels,))
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latent representation."""
|
||||
# Broadcast statistics to match x shape
|
||||
# x shape: (B, C, ...) or (B, ..., C)
|
||||
std = self._std_of_means.astype(x.dtype)
|
||||
mean = self._mean_of_means.astype(x.dtype)
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x * std) + mean
|
||||
|
||||
def normalize(self, x: mx.array) -> mx.array:
|
||||
"""Normalize latent representation."""
|
||||
std = self._std_of_means.astype(x.dtype)
|
||||
mean = self._mean_of_means.astype(x.dtype)
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x - mean) / std
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
@@ -7,7 +7,7 @@ import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""Vocoder for converting mel spectrograms to audio waveforms."""
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import VocoderModelConfig
|
||||
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
|
||||
|
||||
|
||||
@@ -27,44 +28,29 @@ class Vocoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resblock_kernel_sizes: List[int] | None = None,
|
||||
upsample_rates: List[int] | None = None,
|
||||
upsample_kernel_sizes: List[int] | None = None,
|
||||
resblock_dilation_sizes: List[List[int]] | None = None,
|
||||
upsample_initial_channel: int = 1024,
|
||||
stereo: bool = True,
|
||||
resblock: str = "1",
|
||||
output_sample_rate: int = 24000,
|
||||
config: VocoderModelConfig
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Initialize default values if not provided
|
||||
if resblock_kernel_sizes is None:
|
||||
resblock_kernel_sizes = [3, 7, 11]
|
||||
if upsample_rates is None:
|
||||
upsample_rates = [6, 5, 2, 2, 2]
|
||||
if upsample_kernel_sizes is None:
|
||||
upsample_kernel_sizes = [16, 15, 8, 4, 4]
|
||||
if resblock_dilation_sizes is None:
|
||||
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
|
||||
self.output_sample_rate = config.output_sample_rate
|
||||
self.num_kernels = len(config.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(config.upsample_rates)
|
||||
self.upsample_rates = config.upsample_rates
|
||||
self.upsample_kernel_sizes = config.upsample_kernel_sizes
|
||||
self.upsample_initial_channel = config.upsample_initial_channel
|
||||
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
|
||||
in_channels = 128 if config.stereo else 64
|
||||
self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
|
||||
|
||||
# Upsampling layers using ConvTranspose1d
|
||||
self.ups = {}
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
in_ch = upsample_initial_channel // (2**i)
|
||||
out_ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
|
||||
in_ch = config.upsample_initial_channel // (2**i)
|
||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups[i] = nn.ConvTranspose1d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
@@ -77,16 +63,67 @@ class Vocoder(nn.Module):
|
||||
self.resblocks = {}
|
||||
block_idx = 0
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
|
||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||
block_idx += 1
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
final_channels = upsample_initial_channel // (2**self.num_upsamples)
|
||||
out_channels = 2 if config.stereo else 1
|
||||
final_channels = config.upsample_initial_channel // (2**self.num_upsamples)
|
||||
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsample_factor = math.prod(upsample_rates)
|
||||
self.upsample_factor = math.prod(config.upsample_rates)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
sanitized = {}
|
||||
|
||||
if "vocoder." not in weights:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle vocoder weights
|
||||
if key.startswith("vocoder."):
|
||||
new_key = key.replace("vocoder.", "")
|
||||
|
||||
# Handle ModuleList indices -> dict keys
|
||||
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
|
||||
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
|
||||
|
||||
# Handle Conv1d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, kernel)
|
||||
# MLX: (out_channels, kernel, in_channels)
|
||||
if "weight" in new_key and value.ndim == 3:
|
||||
if "ups" in new_key:
|
||||
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0))
|
||||
else:
|
||||
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
|
||||
"""Load vocoder from pretrained model."""
|
||||
from mlx_video.models.ltx.config import VocoderModelConfig
|
||||
import json
|
||||
|
||||
config_dict = {}
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
config = VocoderModelConfig.from_dict(config_dict)
|
||||
model = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
|
||||
# weights = vocoder.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user