Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Per-channel statistics for denormalization (loaded from weights)
|
||||
self.latents_mean = mx.zeros((in_channels,))
|
||||
self.latents_std = mx.ones((in_channels,))
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
# Initial conv: 128 -> 1024
|
||||
class ConvInWrapper(nn.Module):
|
||||
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
|
||||
)
|
||||
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
return (x * std + mean).astype(dtype)
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
# Build decoder weights dict with key remapping
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||
continue
|
||||
|
||||
if key.startswith("vae.per_channel_statistics."):
|
||||
# Map per-channel statistics (use exact key matching)
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(model_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
|
||||
model = cls(timestep_conditioning=timestep_conditioning)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
if debug:
|
||||
debug_stats("Input", sample)
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
if debug:
|
||||
debug_stats("After noise", sample)
|
||||
|
||||
|
||||
if debug:
|
||||
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
|
||||
sample = self.denormalize(sample)
|
||||
if debug:
|
||||
debug_stats("After denormalize", sample)
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_in", x)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
if debug:
|
||||
block_type = type(block).__name__
|
||||
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
if debug:
|
||||
debug_stats("After pixel_norm", x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
if debug:
|
||||
debug_stats("After timestep modulation", x)
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
if debug:
|
||||
debug_stats("After activation", x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_out", x)
|
||||
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
if debug:
|
||||
debug_stats("After unpatchify", x)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
import json
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE decoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
|
||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.decoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
elif has_decoder_prefix:
|
||||
prefix = "decoder."
|
||||
stats_prefix = ""
|
||||
else:
|
||||
prefix = ""
|
||||
stats_prefix = ""
|
||||
|
||||
# Load per-channel statistics for denormalization
|
||||
# Note: use std-of-means (not mean-of-stds) for proper denormalization
|
||||
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
|
||||
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
|
||||
|
||||
if mean_key in weights:
|
||||
decoder.latents_mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
|
||||
if std_key in weights:
|
||||
decoder.latents_std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
|
||||
|
||||
# Build decoder weights dict with key remapping
|
||||
decoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
decoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(decoder_weights)} decoder weights")
|
||||
|
||||
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
|
||||
print(f" Found {len(ts_keys)} timestep conditioning weights")
|
||||
|
||||
# Load weights
|
||||
decoder.load_weights(list(decoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE decoder loaded successfully")
|
||||
return decoder
|
||||
|
||||
Reference in New Issue
Block a user