Refactor VideoEncoder to initialize from VideoEncoderModelConfig, enhancing configuration management. Add methods for weight sanitization and loading from pretrained models, improving model usability and integration with existing workflows.

This commit is contained in:
Prince Canuma
2026-01-23 17:59:57 +01:00
parent f8f78aeab5
commit ce39e744c3
3 changed files with 110 additions and 45 deletions

View File

@@ -1,6 +1,7 @@
"""Video VAE Encoder and Decoder for LTX-2."""
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
@@ -221,46 +222,30 @@ class VideoEncoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 3,
out_channels: int = 128,
encoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize VideoEncoder.
def __init__(self, config: "VideoEncoderModelConfig"):
"""Initialize VideoEncoder from config.
Args:
convolution_dimensions: Number of dimensions (3 for video)
in_channels: Input channels (3 for RGB)
out_channels: Output latent channels
encoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
latent_log_var: Log variance mode
encoder_spatial_padding_mode: Padding mode
config: VideoEncoderModelConfig with encoder parameters
"""
super().__init__()
from mlx_video.models.ltx.config import VideoEncoderModelConfig
if encoder_blocks is None:
encoder_blocks = []
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self.patch_size = config.patch_size
self.norm_layer = config.norm_layer
self.latent_channels = config.out_channels
self.latent_log_var = config.latent_log_var
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
# After patchify, channels increase by patch_size^2
in_channels = in_channels * patch_size ** 2
feature_channels = out_channels
in_channels = config.in_channels * config.patch_size ** 2
feature_channels = config.out_channels
# Initial convolution
self.conv_in = CausalConv3d(
@@ -283,30 +268,30 @@ class VideoEncoder(nn.Module):
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
convolution_dimensions=config.convolution_dimensions,
norm_layer=config.norm_layer,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks[idx] = block
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
if config.norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif norm_layer == NormLayerType.PIXEL_NORM:
elif config.norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
# Calculate output convolution channels
conv_out_channels = out_channels
if latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels = config.out_channels
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
@@ -373,6 +358,86 @@ class VideoEncoder(nn.Module):
means = sample[:, :self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize VAE encoder weights from PyTorch format to MLX format."""
sanitized = {}
if "per_channel_statistics.mean" in weights:
return weights
for key, value in weights.items():
new_key = key
if "position_ids" in key:
continue
# Only process VAE encoder weights
if not key.startswith("vae."):
continue
# Handle per-channel statistics
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue
elif key.startswith("vae.encoder."):
new_key = key.replace("vae.encoder.", "")
else:
continue
# Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
"""Load a pretrained VideoEncoder from a directory with weights and config.
Args:
model_path: Path to directory containing safetensors weights and config.json
Returns:
Loaded VideoEncoder instance
"""
import json
from mlx_video.models.ltx.config import VideoEncoderModelConfig
# Load config
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
config = VideoEncoderModelConfig(**config_dict)
else:
config = VideoEncoderModelConfig()
# Load weights
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
if model_path.is_file():
weights = mx.load(str(model_path))
else:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
else:
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Create model, sanitize and load weights
model = cls(config)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=False)
return model
class VideoDecoder(nn.Module):