Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image
|
||||
to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Any, Optional
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
|
||||
|
||||
|
||||
def load_vae_encoder(model_path: str) -> VideoEncoder:
|
||||
"""Load VAE encoder from safetensors file.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model weights (safetensors file or directory)
|
||||
|
||||
Returns:
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE encoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata
|
||||
encoder_blocks = []
|
||||
norm_layer = NormLayerType.PIXEL_NORM
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
patch_size = 4
|
||||
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
|
||||
# Parse encoder blocks
|
||||
raw_blocks = vae_config.get("encoder_blocks", [])
|
||||
for block in raw_blocks:
|
||||
if isinstance(block, list) and len(block) == 2:
|
||||
name, params = block
|
||||
encoder_blocks.append((name, params))
|
||||
|
||||
# Parse other config
|
||||
norm_str = vae_config.get("norm_layer", "pixel_norm")
|
||||
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
|
||||
|
||||
var_str = vae_config.get("latent_log_var", "uniform")
|
||||
if var_str == "uniform":
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
elif var_str == "per_channel":
|
||||
latent_log_var = LogVarianceType.PER_CHANNEL
|
||||
elif var_str == "constant":
|
||||
latent_log_var = LogVarianceType.CONSTANT
|
||||
else:
|
||||
latent_log_var = LogVarianceType.NONE
|
||||
|
||||
patch_size = vae_config.get("patch_size", 4)
|
||||
|
||||
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}")
|
||||
# Use default config
|
||||
encoder_blocks = [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
|
||||
|
||||
# Create encoder
|
||||
encoder = VideoEncoder(
|
||||
convolution_dimensions=3,
|
||||
in_channels=3,
|
||||
out_channels=128,
|
||||
encoder_blocks=encoder_blocks,
|
||||
patch_size=patch_size,
|
||||
norm_layer=norm_layer,
|
||||
latent_log_var=latent_log_var,
|
||||
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.encoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
else:
|
||||
prefix = "encoder."
|
||||
stats_prefix = "per_channel_statistics."
|
||||
|
||||
# Load per-channel statistics for normalization
|
||||
mean_key = f"{stats_prefix}mean-of-means"
|
||||
std_key = f"{stats_prefix}std-of-means"
|
||||
|
||||
if mean_key in weights:
|
||||
encoder.per_channel_statistics.mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
|
||||
if std_key in weights:
|
||||
encoder.per_channel_statistics.std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {weights[std_key].shape}")
|
||||
|
||||
# Build encoder weights dict with key remapping
|
||||
encoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
encoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(encoder_weights)} encoder weights")
|
||||
|
||||
# Load weights
|
||||
encoder.load_weights(list(encoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE encoder loaded successfully")
|
||||
return encoder
|
||||
|
||||
|
||||
def encode_image(
|
||||
|
||||
Reference in New Issue
Block a user