From ce39e744c37428a6b9590763fea3f5ed97ce763d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 23 Jan 2026 17:59:57 +0100 Subject: [PATCH] Refactor VideoEncoder to initialize from VideoEncoderModelConfig, enhancing configuration management. Add methods for weight sanitization and loading from pretrained models, improving model usability and integration with existing workflows. --- mlx_video/models/ltx/config.py | 12 +- mlx_video/models/ltx/video_vae/__init__.py | 4 +- mlx_video/models/ltx/video_vae/video_vae.py | 139 ++++++++++++++------ 3 files changed, 110 insertions(+), 45 deletions(-) diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index 2ca56a9..400a634 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -277,12 +277,12 @@ class VideoDecoderModelConfig(BaseModelConfig): @dataclass class VideoEncoderModelConfig(BaseModelConfig): convolution_dimensions: int = 3 - in_channels : int = 3, - out_channels: int = 128, - patch_size: int = 4, - norm_layer: Enum = None, - latent_log_var: Enum = None, - encoder_spatial_padding_mode: Enum = None, + in_channels: int = 3 + out_channels: int = 128 + patch_size: int = 4 + norm_layer: Enum = None + latent_log_var: Enum = None + encoder_spatial_padding_mode: Enum = None encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}), ("compress_space_res", {"multiplier": 2}), ("res_x", {"num_layers": 6}), diff --git a/mlx_video/models/ltx/video_vae/__init__.py b/mlx_video/models/ltx/video_vae/__init__.py index 79f68cd..3233b75 100644 --- a/mlx_video/models/ltx/video_vae/__init__.py +++ b/mlx_video/models/ltx/video_vae/__init__.py @@ -1,6 +1,6 @@ -from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder +from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder from mlx_video.models.ltx.video_vae.encoder import encode_image -from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder +from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder from mlx_video.models.ltx.video_vae.tiling import ( TilingConfig, SpatialTilingConfig, diff --git a/mlx_video/models/ltx/video_vae/video_vae.py b/mlx_video/models/ltx/video_vae/video_vae.py index af4349e..1b40b1f 100644 --- a/mlx_video/models/ltx/video_vae/video_vae.py +++ b/mlx_video/models/ltx/video_vae/video_vae.py @@ -1,6 +1,7 @@ """Video VAE Encoder and Decoder for LTX-2.""" from enum import Enum +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import mlx.core as mx @@ -221,46 +222,30 @@ class VideoEncoder(nn.Module): _DEFAULT_NORM_NUM_GROUPS = 32 - def __init__( - self, - convolution_dimensions: int = 3, - in_channels: int = 3, - out_channels: int = 128, - encoder_blocks: List[Tuple[str, Any]] = None, - patch_size: int = 4, - norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, - latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, - encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, - ): - """Initialize VideoEncoder. + def __init__(self, config: "VideoEncoderModelConfig"): + """Initialize VideoEncoder from config. Args: - convolution_dimensions: Number of dimensions (3 for video) - in_channels: Input channels (3 for RGB) - out_channels: Output latent channels - encoder_blocks: List of (block_name, config) tuples - patch_size: Spatial patch size - norm_layer: Normalization layer type - latent_log_var: Log variance mode - encoder_spatial_padding_mode: Padding mode + config: VideoEncoderModelConfig with encoder parameters """ super().__init__() + from mlx_video.models.ltx.config import VideoEncoderModelConfig - if encoder_blocks is None: - encoder_blocks = [] - - self.patch_size = patch_size - self.norm_layer = norm_layer - self.latent_channels = out_channels - self.latent_log_var = latent_log_var + self.patch_size = config.patch_size + self.norm_layer = config.norm_layer + self.latent_channels = config.out_channels + self.latent_log_var = config.latent_log_var self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + encoder_blocks = config.encoder_blocks if config.encoder_blocks else [] + encoder_spatial_padding_mode = config.encoder_spatial_padding_mode + # Per-channel statistics for normalizing latents - self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) + self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels) # After patchify, channels increase by patch_size^2 - in_channels = in_channels * patch_size ** 2 - feature_channels = out_channels + in_channels = config.in_channels * config.patch_size ** 2 + feature_channels = config.out_channels # Initial convolution self.conv_in = CausalConv3d( @@ -283,30 +268,30 @@ class VideoEncoder(nn.Module): block_name=block_name, block_config=block_config, in_channels=feature_channels, - convolution_dimensions=convolution_dimensions, - norm_layer=norm_layer, + convolution_dimensions=config.convolution_dimensions, + norm_layer=config.norm_layer, norm_num_groups=self._norm_num_groups, spatial_padding_mode=encoder_spatial_padding_mode, ) self.down_blocks[idx] = block # Output normalization and convolution - if norm_layer == NormLayerType.GROUP_NORM: + if config.norm_layer == NormLayerType.GROUP_NORM: self.conv_norm_out = nn.GroupNorm( num_groups=self._norm_num_groups, dims=feature_channels, eps=1e-6, ) - elif norm_layer == NormLayerType.PIXEL_NORM: + elif config.norm_layer == NormLayerType.PIXEL_NORM: self.conv_norm_out = PixelNorm() self.conv_act = nn.SiLU() # Calculate output convolution channels - conv_out_channels = out_channels - if latent_log_var == LogVarianceType.PER_CHANNEL: + conv_out_channels = config.out_channels + if config.latent_log_var == LogVarianceType.PER_CHANNEL: conv_out_channels *= 2 - elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: conv_out_channels += 1 self.conv_out = CausalConv3d( @@ -373,6 +358,86 @@ class VideoEncoder(nn.Module): means = sample[:, :self.latent_channels, ...] return self.per_channel_statistics.normalize(means) + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Sanitize VAE encoder weights from PyTorch format to MLX format.""" + sanitized = {} + if "per_channel_statistics.mean" in weights: + return weights + + for key, value in weights.items(): + new_key = key + + if "position_ids" in key: + continue + + # Only process VAE encoder weights + if not key.startswith("vae."): + continue + + # Handle per-channel statistics + if "vae.per_channel_statistics" in key: + if key == "vae.per_channel_statistics.mean-of-means": + new_key = "per_channel_statistics.mean" + elif key == "vae.per_channel_statistics.std-of-means": + new_key = "per_channel_statistics.std" + else: + continue + elif key.startswith("vae.encoder."): + new_key = key.replace("vae.encoder.", "") + else: + continue + + # Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I) + if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5: + value = mx.transpose(value, (0, 2, 3, 4, 1)) + + # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I) + if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: + value = mx.transpose(value, (0, 2, 3, 1)) + + sanitized[new_key] = value + return sanitized + + @classmethod + def from_pretrained(cls, model_path: Path) -> "VideoEncoder": + """Load a pretrained VideoEncoder from a directory with weights and config. + + Args: + model_path: Path to directory containing safetensors weights and config.json + + Returns: + Loaded VideoEncoder instance + """ + import json + from mlx_video.models.ltx.config import VideoEncoderModelConfig + + # Load config + config_path = model_path / "config.json" + if config_path.exists(): + with open(config_path) as f: + config_dict = json.load(f) + config = VideoEncoderModelConfig(**config_dict) + else: + config = VideoEncoderModelConfig() + + # Load weights + weight_files = sorted(model_path.glob("*.safetensors")) + if not weight_files: + if model_path.is_file(): + weights = mx.load(str(model_path)) + else: + raise FileNotFoundError(f"No safetensors files found in {model_path}") + else: + weights = {} + for wf in weight_files: + weights.update(mx.load(str(wf))) + + # Create model, sanitize and load weights + model = cls(config) + weights = model.sanitize(weights) + model.load_weights(list(weights.items()), strict=False) + return model + class VideoDecoder(nn.Module):