format
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
"""Audio VAE module for LTX-2 audio generation."""
|
||||
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_processor import ensure_stereo, load_audio, waveform_to_mel
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
|
||||
@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -103,6 +105,8 @@ def make_attn(
|
||||
elif attn_type == AttentionType.NONE:
|
||||
return Identity()
|
||||
elif attn_type == AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
raise NotImplementedError(
|
||||
f"Attention type {attn_type.value} is not supported yet."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
@@ -4,10 +4,9 @@ Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrog
|
||||
using librosa for macOS/MLX compatibility.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_audio(
|
||||
@@ -99,14 +98,16 @@ def waveform_to_mel(
|
||||
|
||||
for ch in range(channels):
|
||||
# Magnitude spectrogram (power=1.0)
|
||||
S = np.abs(librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
))
|
||||
S = np.abs(
|
||||
librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
)
|
||||
)
|
||||
|
||||
# Mel filterbank with slaney normalization
|
||||
mel_basis = librosa.filters.mel(
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
@@ -39,7 +39,9 @@ def build_mid_block(
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
mid["attn_1"] = (
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type)
|
||||
if add_attention
|
||||
else None
|
||||
)
|
||||
mid["block_2"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
@@ -93,7 +95,10 @@ class AudioEncoder(nn.Module):
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
||||
config.in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -125,7 +130,10 @@ class AudioEncoder(nn.Module):
|
||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||
self.conv_out = make_conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1,
|
||||
block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -160,7 +168,11 @@ class AudioEncoder(nn.Module):
|
||||
continue
|
||||
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
@@ -168,11 +180,14 @@ class AudioEncoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
|
||||
model_path = Path(model_path)
|
||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
config = AudioEncoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
encoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
encoder.load_weights(list(weights.items()), strict=True)
|
||||
@@ -265,7 +280,6 @@ class AudioDecoder(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
@@ -305,7 +319,11 @@ class AudioDecoder(nn.Module):
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
config.z_channels,
|
||||
base_block_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
@@ -334,9 +352,15 @@ class AudioDecoder(nn.Module):
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.norm_out = build_normalization_layer(
|
||||
final_block_channels, normtype=self.norm_type
|
||||
)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
final_block_channels,
|
||||
config.out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -371,7 +395,11 @@ class AudioDecoder(nn.Module):
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
@@ -380,17 +408,19 @@ class AudioDecoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
@@ -414,7 +444,9 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
return self._adjust_output_shape(h, target_shape)
|
||||
|
||||
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
||||
def _denormalize_latents(
|
||||
self, sample: mx.array
|
||||
) -> tuple[mx.array, AudioLatentShape]:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
# sample shape: (B, H, W, C) in MLX format
|
||||
latent_shape = AudioLatentShape(
|
||||
@@ -436,7 +468,9 @@ class AudioDecoder(nn.Module):
|
||||
batch=latent_shape.batch,
|
||||
channels=self.out_ch,
|
||||
frames=target_frames,
|
||||
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
||||
mel_bins=(
|
||||
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
|
||||
),
|
||||
)
|
||||
|
||||
return sample, target_shape
|
||||
@@ -462,7 +496,10 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
||||
:,
|
||||
: min(current_time, target_time),
|
||||
: min(current_freq, target_freq),
|
||||
:target_channels,
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
@@ -514,7 +551,9 @@ class AudioDecoder(nn.Module):
|
||||
return mx.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
||||
def decode_audio(
|
||||
latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
|
||||
) -> mx.array:
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
|
||||
@@ -53,8 +53,16 @@ class CausalConv2d(nn.Module):
|
||||
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
# Non-causal: symmetric padding
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
|
||||
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
|
||||
self.padding = (
|
||||
pad_h // 2,
|
||||
pad_h - pad_h // 2,
|
||||
pad_w // 2,
|
||||
pad_w - pad_w // 2,
|
||||
)
|
||||
elif self.causality_axis in (
|
||||
CausalityAxis.WIDTH,
|
||||
CausalityAxis.WIDTH_COMPATIBILITY,
|
||||
):
|
||||
# Causal on width: pad left (before width axis)
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
@@ -90,7 +98,10 @@ class CausalConv2d(nn.Module):
|
||||
if any(p > 0 for p in self.padding):
|
||||
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
||||
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
||||
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
|
||||
x = mx.pad(
|
||||
x,
|
||||
[(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)],
|
||||
)
|
||||
|
||||
return self.conv(x)
|
||||
|
||||
@@ -124,7 +135,14 @@ def make_conv2d(
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(
|
||||
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
causality_axis,
|
||||
)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in MLX conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -116,10 +118,14 @@ def build_downsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if i_level != num_resolutions - 1:
|
||||
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["downsample"] = Downsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
down_modules[i_level] = stage
|
||||
|
||||
@@ -51,7 +51,9 @@ def build_normalization_layer(
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
||||
return nn.GroupNorm(
|
||||
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
if normtype == NormType.PIXEL:
|
||||
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""ResNet blocks for audio VAE and vocoder."""
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
self.conv1 = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if temb_channels > 0:
|
||||
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
|
||||
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||
self.dropout_rate = dropout
|
||||
self.conv2 = make_conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
|
||||
if temb is not None and self.temb_channels > 0:
|
||||
# temb: (B, temb_channels) -> (B, out_channels)
|
||||
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
||||
h = h + mx.expand_dims(
|
||||
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
|
||||
)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
@@ -124,10 +128,14 @@ def build_upsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if level != 0:
|
||||
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["upsample"] = Upsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res *= 2
|
||||
|
||||
up_modules[level] = stage
|
||||
|
||||
@@ -7,8 +7,8 @@ Supports:
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -32,7 +32,9 @@ class Snake(nn.Module):
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.alpha = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
@@ -48,8 +50,12 @@ class SnakeBeta(nn.Module):
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.alpha = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
self.beta = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
alpha = self.alpha
|
||||
@@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array:
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
|
||||
def kaiser_sinc_filter1d(
|
||||
cutoff: float, half_width: float, kernel_size: int
|
||||
) -> mx.array:
|
||||
"""Compute a Kaiser-windowed sinc filter."""
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
@@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
||||
|
||||
# Kaiser window - compute using scipy-compatible formula
|
||||
import numpy as np
|
||||
|
||||
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
|
||||
|
||||
if even:
|
||||
@@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
||||
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
|
||||
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
|
||||
import numpy as np
|
||||
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
@@ -187,10 +197,16 @@ class UpSample1d(nn.Module):
|
||||
self.kernel_size = filt.shape[2]
|
||||
self.filter = filt
|
||||
else:
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
self.pad_left = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
)
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
self.filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
@@ -215,10 +231,12 @@ class UpSample1d(nn.Module):
|
||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||
|
||||
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
|
||||
x = self.ratio * mx.conv_transpose1d(
|
||||
x, filt, stride=self.stride
|
||||
) # (N*C, L', 1)
|
||||
|
||||
# Trim padding
|
||||
x = x[:, self.pad_left:-self.pad_right, :]
|
||||
x = x[:, self.pad_left : -self.pad_right, :]
|
||||
|
||||
x = x.reshape(n, c, -1) # (N, C, L')
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||
@@ -285,16 +303,24 @@ class AMPBlock1(nn.Module):
|
||||
|
||||
self.convs1 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=d, padding=get_padding(kernel_size, d),
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=get_padding(kernel_size, d),
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
self.convs2 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=1, padding=get_padding(kernel_size, 1),
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
for i in range(len(dilation))
|
||||
}
|
||||
@@ -348,7 +374,9 @@ class STFTFn(nn.Module):
|
||||
y = mx.concatenate([first, y], axis=1)
|
||||
|
||||
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
|
||||
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
|
||||
basis = mx.transpose(
|
||||
self.forward_basis.astype(y.dtype), (0, 2, 1)
|
||||
) # (514, K, 1)
|
||||
|
||||
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
|
||||
spec = mx.conv1d(y, basis, stride=self.hop_length)
|
||||
@@ -358,8 +386,10 @@ class STFTFn(nn.Module):
|
||||
real = spec[..., :n_freqs]
|
||||
imag = spec[..., n_freqs:]
|
||||
|
||||
magnitude = mx.sqrt(real ** 2 + imag ** 2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
|
||||
magnitude = mx.sqrt(real**2 + imag**2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
|
||||
real.dtype
|
||||
)
|
||||
|
||||
# Output: (B, T_frames, n_freqs) in MLX channels-last
|
||||
return magnitude, phase
|
||||
@@ -368,7 +398,9 @@ class STFTFn(nn.Module):
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram from precomputed STFT bases."""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
|
||||
def __init__(
|
||||
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
|
||||
n_freqs = filter_length // 2 + 1
|
||||
@@ -385,7 +417,9 @@ class MelSTFT(nn.Module):
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
# magnitude: (B, T_frames, n_freqs)
|
||||
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
|
||||
mel = (
|
||||
magnitude @ self.mel_basis.astype(magnitude.dtype).T
|
||||
) # (B, T_frames, n_mels)
|
||||
log_mel = mx.log(mx.clip(mel, 1e-5, None))
|
||||
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
|
||||
return mx.transpose(log_mel, (0, 2, 1))
|
||||
@@ -415,8 +449,11 @@ class Vocoder(nn.Module):
|
||||
|
||||
in_channels = 128 if config.stereo else 64
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels, config.upsample_initial_channel,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
in_channels,
|
||||
config.upsample_initial_channel,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
)
|
||||
|
||||
# Upsampling layers
|
||||
@@ -424,11 +461,13 @@ class Vocoder(nn.Module):
|
||||
for i, (stride, kernel_size) in enumerate(
|
||||
zip(config.upsample_rates, config.upsample_kernel_sizes)
|
||||
):
|
||||
in_ch = config.upsample_initial_channel // (2 ** i)
|
||||
in_ch = config.upsample_initial_channel // (2**i)
|
||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups[i] = nn.ConvTranspose1d(
|
||||
in_ch, out_ch,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
|
||||
@@ -442,7 +481,9 @@ class Vocoder(nn.Module):
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = AMPBlock1(
|
||||
ch, kernel_size, tuple(dilations),
|
||||
ch,
|
||||
kernel_size,
|
||||
tuple(dilations),
|
||||
activation=config.activation,
|
||||
)
|
||||
block_idx += 1
|
||||
@@ -455,10 +496,14 @@ class Vocoder(nn.Module):
|
||||
for kernel_size, dilations in zip(
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||
self.resblocks[block_idx] = resblock_class(
|
||||
ch, kernel_size, tuple(dilations)
|
||||
)
|
||||
block_idx += 1
|
||||
|
||||
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
|
||||
final_channels = config.upsample_initial_channel // (
|
||||
2 ** len(config.upsample_rates)
|
||||
)
|
||||
|
||||
# Post-activation
|
||||
if self.is_amp:
|
||||
@@ -468,8 +513,11 @@ class Vocoder(nn.Module):
|
||||
# Final conv
|
||||
out_channels = 2 if config.stereo else 1
|
||||
self.conv_post = nn.Conv1d(
|
||||
final_channels, out_channels,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
final_channels,
|
||||
out_channels,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
bias=config.use_bias_at_final,
|
||||
)
|
||||
|
||||
@@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module):
|
||||
"""
|
||||
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
|
||||
_, _, length_low_rate = x.shape
|
||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
output_length = (
|
||||
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
)
|
||||
|
||||
# Pad to hop_length multiple
|
||||
remainder = length_low_rate % self.hop_length
|
||||
@@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user