Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -1,6 +1,6 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import encode_image
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,

View File

@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
"""
import math
from typing import Optional
from typing import Optional, Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights)
self.latents_mean = mx.zeros((in_channels,))
self.latents_std = mx.ones((in_channels,))
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Initial conv: 128 -> 1024
class ConvInWrapper(nn.Module):
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
)
self.last_scale_shift_table = mx.zeros((2, 128))
def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics."""
dtype = x.dtype
# Cast to float32 for precision (statistics may be in bfloat16)
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return (x * std + mean).astype(dtype)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
sanitized = {}
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
if key.startswith("vae.per_channel_statistics."):
# Map per-channel statistics (use exact key matching)
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue # Skip other statistics keys
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
from safetensors import safe_open
import json
weights = mx.load(str(model_path))
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(model_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
model = cls(timestep_conditioning=timestep_conditioning)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv: bool = False,
) -> mx.array:
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
batch_size = sample.shape[0]
if debug:
debug_stats("Input", sample)
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
if debug:
debug_stats("After noise", sample)
if debug:
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
sample = self.denormalize(sample)
if debug:
debug_stats("After denormalize", sample)
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
if debug:
debug_stats("After conv_in", x)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
if debug:
block_type = type(block).__name__
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
x = self.pixel_norm(x)
if debug:
debug_stats("After pixel_norm", x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
if debug:
debug_stats("After timestep modulation", x)
x = self.act(x)
if debug:
debug_stats("After activation", x)
x = self.conv_out(x, causal=causal)
if debug:
debug_stats("After conv_out", x)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
if debug:
debug_stats("After unpatchify", x)
return x
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path
import json
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.decoder."
stats_prefix = "vae.per_channel_statistics."
elif has_decoder_prefix:
prefix = "decoder."
stats_prefix = ""
else:
prefix = ""
stats_prefix = ""
# Load per-channel statistics for denormalization
# Note: use std-of-means (not mean-of-stds) for proper denormalization
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
if mean_key in weights:
decoder.latents_mean = weights[mean_key]
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
if std_key in weights:
decoder.latents_std = weights[std_key]
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
# Build decoder weights dict with key remapping
decoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
decoder_weights[new_key] = value
print(f" Found {len(decoder_weights)} decoder weights")
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
print(f" Found {len(ts_keys)} timestep conditioning weights")
# Load weights
decoder.load_weights(list(decoder_weights.items()), strict=False)
print("VAE decoder loaded successfully")
return decoder

View File

@@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
from pathlib import Path
from typing import List, Tuple, Any, Optional
import json
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
def load_vae_encoder(model_path: str) -> VideoEncoder:
"""Load VAE encoder from safetensors file.
Args:
model_path: Path to the model weights (safetensors file or directory)
Returns:
Loaded VideoEncoder instance
"""
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE encoder from {weights_path}...")
# Read config from safetensors metadata
encoder_blocks = []
norm_layer = NormLayerType.PIXEL_NORM
latent_log_var = LogVarianceType.UNIFORM
patch_size = 4
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
# Parse encoder blocks
raw_blocks = vae_config.get("encoder_blocks", [])
for block in raw_blocks:
if isinstance(block, list) and len(block) == 2:
name, params = block
encoder_blocks.append((name, params))
# Parse other config
norm_str = vae_config.get("norm_layer", "pixel_norm")
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
var_str = vae_config.get("latent_log_var", "uniform")
if var_str == "uniform":
latent_log_var = LogVarianceType.UNIFORM
elif var_str == "per_channel":
latent_log_var = LogVarianceType.PER_CHANNEL
elif var_str == "constant":
latent_log_var = LogVarianceType.CONSTANT
else:
latent_log_var = LogVarianceType.NONE
patch_size = vae_config.get("patch_size", 4)
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
except Exception as e:
print(f" Could not read config from metadata: {e}")
# Use default config
encoder_blocks = [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
# Create encoder
encoder = VideoEncoder(
convolution_dimensions=3,
in_channels=3,
out_channels=128,
encoder_blocks=encoder_blocks,
patch_size=patch_size,
norm_layer=norm_layer,
latent_log_var=latent_log_var,
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
)
# Load weights
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.encoder."
stats_prefix = "vae.per_channel_statistics."
else:
prefix = "encoder."
stats_prefix = "per_channel_statistics."
# Load per-channel statistics for normalization
mean_key = f"{stats_prefix}mean-of-means"
std_key = f"{stats_prefix}std-of-means"
if mean_key in weights:
encoder.per_channel_statistics.mean = weights[mean_key]
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
if std_key in weights:
encoder.per_channel_statistics.std = weights[std_key]
print(f" Loaded latent std: shape {weights[std_key].shape}")
# Build encoder weights dict with key remapping
encoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
encoder_weights[new_key] = value
print(f" Found {len(encoder_weights)} encoder weights")
# Load weights
encoder.load_weights(list(encoder_weights.items()), strict=False)
print("VAE encoder loaded successfully")
return encoder
def encode_image(

View File

@@ -273,9 +273,10 @@ class VideoEncoder(nn.Module):
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks - use dict with int keys for MLX parameter tracking
# Build encoder blocks
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {}
for i, (block_name, block_params) in enumerate(encoder_blocks):
for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
@@ -287,7 +288,7 @@ class VideoEncoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks[i] = block
self.down_blocks[idx] = block
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
@@ -341,7 +342,8 @@ class VideoEncoder(nn.Module):
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks.values():
for i in range(len(self.down_blocks)):
down_block = self.down_blocks[i]
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else:
@@ -440,8 +442,9 @@ class VideoDecoder(nn.Module):
)
# Build decoder blocks (reversed order)
self.up_blocks = []
for block_name, block_params in list(reversed(decoder_blocks)):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_decoder_block(
@@ -454,7 +457,7 @@ class VideoDecoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks.append(block)
self.up_blocks[idx] = block
# Output normalization
if norm_layer == NormLayerType.GROUP_NORM:
@@ -509,7 +512,8 @@ class VideoDecoder(nn.Module):
sample = self.conv_in(sample, causal=self.causal)
# Process through decoder blocks
for up_block in self.up_blocks:
for i in range(len(self.up_blocks)):
up_block = self.up_blocks[i]
if isinstance(up_block, UNetMidBlock3D):
sample = up_block(sample, causal=self.causal)
elif isinstance(up_block, ResnetBlock3D):