Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
"""
import math
from typing import Optional
from typing import Optional, Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights)
self.latents_mean = mx.zeros((in_channels,))
self.latents_std = mx.ones((in_channels,))
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Initial conv: 128 -> 1024
class ConvInWrapper(nn.Module):
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
)
self.last_scale_shift_table = mx.zeros((2, 128))
def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics."""
dtype = x.dtype
# Cast to float32 for precision (statistics may be in bfloat16)
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return (x * std + mean).astype(dtype)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
sanitized = {}
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
if key.startswith("vae.per_channel_statistics."):
# Map per-channel statistics (use exact key matching)
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue # Skip other statistics keys
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
from safetensors import safe_open
import json
weights = mx.load(str(model_path))
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(model_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
model = cls(timestep_conditioning=timestep_conditioning)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv: bool = False,
) -> mx.array:
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
batch_size = sample.shape[0]
if debug:
debug_stats("Input", sample)
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
if debug:
debug_stats("After noise", sample)
if debug:
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
sample = self.denormalize(sample)
if debug:
debug_stats("After denormalize", sample)
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
if debug:
debug_stats("After conv_in", x)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
if debug:
block_type = type(block).__name__
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
x = self.pixel_norm(x)
if debug:
debug_stats("After pixel_norm", x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
if debug:
debug_stats("After timestep modulation", x)
x = self.act(x)
if debug:
debug_stats("After activation", x)
x = self.conv_out(x, causal=causal)
if debug:
debug_stats("After conv_out", x)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
if debug:
debug_stats("After unpatchify", x)
return x
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path
import json
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.decoder."
stats_prefix = "vae.per_channel_statistics."
elif has_decoder_prefix:
prefix = "decoder."
stats_prefix = ""
else:
prefix = ""
stats_prefix = ""
# Load per-channel statistics for denormalization
# Note: use std-of-means (not mean-of-stds) for proper denormalization
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
if mean_key in weights:
decoder.latents_mean = weights[mean_key]
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
if std_key in weights:
decoder.latents_std = weights[std_key]
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
# Build decoder weights dict with key remapping
decoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
decoder_weights[new_key] = value
print(f" Found {len(decoder_weights)} decoder weights")
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
print(f" Found {len(ts_keys)} timestep conditioning weights")
# Load weights
decoder.load_weights(list(decoder_weights.items()), strict=False)
print("VAE decoder loaded successfully")
return decoder