format
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
@@ -39,7 +39,9 @@ def build_mid_block(
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
mid["attn_1"] = (
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type)
|
||||
if add_attention
|
||||
else None
|
||||
)
|
||||
mid["block_2"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
@@ -93,7 +95,10 @@ class AudioEncoder(nn.Module):
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
||||
config.in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -125,7 +130,10 @@ class AudioEncoder(nn.Module):
|
||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||
self.conv_out = make_conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1,
|
||||
block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -160,7 +168,11 @@ class AudioEncoder(nn.Module):
|
||||
continue
|
||||
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
@@ -168,11 +180,14 @@ class AudioEncoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
|
||||
model_path = Path(model_path)
|
||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
config = AudioEncoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
encoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
encoder.load_weights(list(weights.items()), strict=True)
|
||||
@@ -265,7 +280,6 @@ class AudioDecoder(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
@@ -305,7 +319,11 @@ class AudioDecoder(nn.Module):
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
config.z_channels,
|
||||
base_block_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
@@ -334,9 +352,15 @@ class AudioDecoder(nn.Module):
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.norm_out = build_normalization_layer(
|
||||
final_block_channels, normtype=self.norm_type
|
||||
)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
final_block_channels,
|
||||
config.out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -371,7 +395,11 @@ class AudioDecoder(nn.Module):
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
@@ -380,17 +408,19 @@ class AudioDecoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
@@ -414,7 +444,9 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
return self._adjust_output_shape(h, target_shape)
|
||||
|
||||
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
||||
def _denormalize_latents(
|
||||
self, sample: mx.array
|
||||
) -> tuple[mx.array, AudioLatentShape]:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
# sample shape: (B, H, W, C) in MLX format
|
||||
latent_shape = AudioLatentShape(
|
||||
@@ -436,7 +468,9 @@ class AudioDecoder(nn.Module):
|
||||
batch=latent_shape.batch,
|
||||
channels=self.out_ch,
|
||||
frames=target_frames,
|
||||
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
||||
mel_bins=(
|
||||
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
|
||||
),
|
||||
)
|
||||
|
||||
return sample, target_shape
|
||||
@@ -462,7 +496,10 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
||||
:,
|
||||
: min(current_time, target_time),
|
||||
: min(current_freq, target_freq),
|
||||
:target_channels,
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
@@ -514,7 +551,9 @@ class AudioDecoder(nn.Module):
|
||||
return mx.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
||||
def decode_audio(
|
||||
latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
|
||||
) -> mx.array:
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user