Add audio to video conditioning

This commit is contained in:
Prince Canuma
2026-03-16 01:42:11 +01:00
parent f53b9e0807
commit 6f6105b715
7 changed files with 623 additions and 62 deletions

View File

@@ -6,10 +6,11 @@ from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
@@ -59,6 +60,179 @@ def run_mid_block(mid: dict, features: mx.array) -> mx.array:
return mid["block_2"](features, temb=None)
class AudioEncoder(nn.Module):
"""Encoder that compresses audio spectrograms into latent representations."""
def __init__(self, config: AudioEncoderModelConfig) -> None:
super().__init__()
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
self.sample_rate = config.sample_rate
self.mel_hop_length = config.mel_hop_length
self.is_causal = config.is_causal
self.mel_bins = config.mel_bins
self.patchifier = AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=config.sample_rate,
hop_length=config.mel_hop_length,
is_causal=config.is_causal,
)
self.ch = config.ch
self.temb_ch = 0
self.num_resolutions = len(config.ch_mult)
self.num_res_blocks = config.num_res_blocks
self.resolution = config.resolution
self.in_channels = config.in_channels
self.z_channels = config.z_channels
self.double_z = config.double_z
self.norm_type = config.norm_type
self.causality_axis = config.causality_axis
self.attn_type = config.attn_type
self.conv_in = make_conv2d(
config.in_channels, self.ch, kernel_size=3, stride=1,
causality_axis=self.causality_axis,
)
self.down, block_in = build_downsampling_path(
ch=config.ch,
ch_mult=config.ch_mult,
num_resolutions=self.num_resolutions,
num_res_blocks=config.num_res_blocks,
resolution=config.resolution,
temb_channels=self.temb_ch,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
attn_resolutions=config.attn_resolutions or set(),
resamp_with_conv=config.resamp_with_conv,
)
self.mid = build_mid_block(
channels=block_in,
temb_channels=self.temb_ch,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
add_attention=config.mid_block_add_attention,
)
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
self.conv_out = make_conv2d(
block_in, out_channels, kernel_size=3, stride=1,
causality_axis=self.causality_axis,
)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio encoder weights from PyTorch format."""
sanitized = {}
for key, value in weights.items():
new_key = key
if key.startswith("audio_vae.encoder."):
new_key = key.replace("audio_vae.encoder.", "")
elif key.startswith("encoder."):
new_key = key.replace("encoder.", "")
elif key.startswith("audio_vae.per_channel_statistics."):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue
elif "per_channel_statistics" in key:
if "mean-of-means" in key or "latents_mean" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key or "latents_std" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue
elif key == "latents_mean":
new_key = "per_channel_statistics.mean_of_means"
elif key == "latents_std":
new_key = "per_channel_statistics.std_of_means"
else:
continue
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
"""Load audio encoder from pretrained weights."""
from mlx_video.models.ltx.config import AudioEncoderModelConfig
import json
model_path = Path(model_path)
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
encoder = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
encoder.load_weights(list(weights.items()), strict=True)
return encoder
def __call__(self, spectrogram: mx.array) -> mx.array:
"""Encode audio spectrogram into normalized latent representation.
Args:
spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.
Returns:
Normalized latent (B, T', F', z_channels) in MLX channels-last format.
"""
if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))
h = self.conv_in(spectrogram)
h = self._run_downsampling_path(h)
h = run_mid_block(self.mid, h)
h = self._finalize_output(h)
return self._normalize_latents(h)
def _run_downsampling_path(self, h: mx.array) -> mx.array:
for level in range(self.num_resolutions):
stage = self.down[level]
for block_idx in range(self.num_res_blocks):
h = stage["block"][block_idx](h, temb=None)
if block_idx in stage["attn"]:
h = stage["attn"][block_idx](h)
if level != self.num_resolutions - 1 and "downsample" in stage:
h = stage["downsample"](h)
return h
def _finalize_output(self, h: mx.array) -> mx.array:
h = self.norm_out(h)
h = nn.silu(h)
return self.conv_out(h)
def _normalize_latents(self, h: mx.array) -> mx.array:
"""Normalize encoder output using per-channel statistics.
Takes first half of channels (mean) when double_z=True,
then patchifies, normalizes, and unpatchifies.
"""
# h shape: (B, T', F', 2*z_channels) in MLX format
z_channels = self.z_channels
means = h[..., :z_channels]
latent_shape = AudioLatentShape(
batch=means.shape[0],
channels=means.shape[3],
frames=means.shape[1],
mel_bins=means.shape[2],
)
patched = self.patchifier.patchify(means)
normalized = self.per_channel_statistics.normalize(patched)
return self.patchifier.unpatchify(normalized, latent_shape)
class AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.