Refactor VideoEncoder to initialize from VideoEncoderModelConfig, enhancing configuration management. Add methods for weight sanitization and loading from pretrained models, improving model usability and integration with existing workflows.

This commit is contained in:
Prince Canuma
2026-01-23 17:59:57 +01:00
parent f8f78aeab5
commit ce39e744c3
3 changed files with 110 additions and 45 deletions

View File

@@ -277,12 +277,12 @@ class VideoDecoderModelConfig(BaseModelConfig):
@dataclass @dataclass
class VideoEncoderModelConfig(BaseModelConfig): class VideoEncoderModelConfig(BaseModelConfig):
convolution_dimensions: int = 3 convolution_dimensions: int = 3
in_channels : int = 3, in_channels: int = 3
out_channels: int = 128, out_channels: int = 128
patch_size: int = 4, patch_size: int = 4
norm_layer: Enum = None, norm_layer: Enum = None
latent_log_var: Enum = None, latent_log_var: Enum = None
encoder_spatial_padding_mode: Enum = None, encoder_spatial_padding_mode: Enum = None
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}), encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}), ("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}), ("res_x", {"num_layers": 6}),

View File

@@ -1,6 +1,6 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx.video_vae.encoder import encode_image from mlx_video.models.ltx.video_vae.encoder import encode_image
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import ( from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig, TilingConfig,
SpatialTilingConfig, SpatialTilingConfig,

View File

@@ -1,6 +1,7 @@
"""Video VAE Encoder and Decoder for LTX-2.""" """Video VAE Encoder and Decoder for LTX-2."""
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
@@ -221,46 +222,30 @@ class VideoEncoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32 _DEFAULT_NORM_NUM_GROUPS = 32
def __init__( def __init__(self, config: "VideoEncoderModelConfig"):
self, """Initialize VideoEncoder from config.
convolution_dimensions: int = 3,
in_channels: int = 3,
out_channels: int = 128,
encoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize VideoEncoder.
Args: Args:
convolution_dimensions: Number of dimensions (3 for video) config: VideoEncoderModelConfig with encoder parameters
in_channels: Input channels (3 for RGB)
out_channels: Output latent channels
encoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
latent_log_var: Log variance mode
encoder_spatial_padding_mode: Padding mode
""" """
super().__init__() super().__init__()
from mlx_video.models.ltx.config import VideoEncoderModelConfig
if encoder_blocks is None: self.patch_size = config.patch_size
encoder_blocks = [] self.norm_layer = config.norm_layer
self.latent_channels = config.out_channels
self.patch_size = patch_size self.latent_log_var = config.latent_log_var
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents # Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
# After patchify, channels increase by patch_size^2 # After patchify, channels increase by patch_size^2
in_channels = in_channels * patch_size ** 2 in_channels = config.in_channels * config.patch_size ** 2
feature_channels = out_channels feature_channels = config.out_channels
# Initial convolution # Initial convolution
self.conv_in = CausalConv3d( self.conv_in = CausalConv3d(
@@ -283,30 +268,30 @@ class VideoEncoder(nn.Module):
block_name=block_name, block_name=block_name,
block_config=block_config, block_config=block_config,
in_channels=feature_channels, in_channels=feature_channels,
convolution_dimensions=convolution_dimensions, convolution_dimensions=config.convolution_dimensions,
norm_layer=norm_layer, norm_layer=config.norm_layer,
norm_num_groups=self._norm_num_groups, norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode, spatial_padding_mode=encoder_spatial_padding_mode,
) )
self.down_blocks[idx] = block self.down_blocks[idx] = block
# Output normalization and convolution # Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM: if config.norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm( self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups, num_groups=self._norm_num_groups,
dims=feature_channels, dims=feature_channels,
eps=1e-6, eps=1e-6,
) )
elif norm_layer == NormLayerType.PIXEL_NORM: elif config.norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm() self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
# Calculate output convolution channels # Calculate output convolution channels
conv_out_channels = out_channels conv_out_channels = config.out_channels
if latent_log_var == LogVarianceType.PER_CHANNEL: if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2 conv_out_channels *= 2
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
conv_out_channels += 1 conv_out_channels += 1
self.conv_out = CausalConv3d( self.conv_out = CausalConv3d(
@@ -373,6 +358,86 @@ class VideoEncoder(nn.Module):
means = sample[:, :self.latent_channels, ...] means = sample[:, :self.latent_channels, ...]
return self.per_channel_statistics.normalize(means) return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert PyTorch-format VAE encoder weights to this module's MLX layout.

    Only keys under ``vae.`` are considered. Encoder keys have their
    ``vae.encoder.`` prefix stripped, the two per-channel statistics keys
    are renamed, and everything else (including any key containing
    ``position_ids``) is dropped.

    Args:
        weights: Mapping from checkpoint key to array.

    Returns:
        A new dict of renamed/transposed weights, or ``weights`` itself
        (same object, untouched) when it already contains
        ``per_channel_statistics.mean`` and is therefore treated as
        already sanitized.
    """
    # A checkpoint that already carries the renamed statistics key is
    # returned as-is so the conv transposes are not applied a second time.
    if "per_channel_statistics.mean" in weights:
        return weights

    encoder_prefix = "vae.encoder."
    statistics_renames = {
        "vae.per_channel_statistics.mean-of-means": "per_channel_statistics.mean",
        "vae.per_channel_statistics.std-of-means": "per_channel_statistics.std",
    }
    # Axis permutations that move the input-channel axis last, by rank:
    #   Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
    #   Conv2d: PyTorch (O, I, H, W)    -> MLX (O, H, W, I)
    channels_last = {5: (0, 2, 3, 4, 1), 4: (0, 2, 3, 1)}

    sanitized = {}
    for key, value in weights.items():
        if "position_ids" in key or not key.startswith("vae."):
            continue

        if "vae.per_channel_statistics" in key:
            new_key = statistics_renames.get(key)
            if new_key is None:
                # Other statistics entries are not loaded by the encoder.
                continue
        elif key.startswith(encoder_prefix):
            # Strip only the leading prefix; str.replace would also remove
            # any later occurrence of the same substring inside the key.
            new_key = key[len(encoder_prefix):]
        else:
            continue

        if "conv" in new_key.lower() and "weight" in new_key:
            permutation = channels_last.get(value.ndim)
            if permutation is not None:
                value = mx.transpose(value, permutation)

        sanitized[new_key] = value

    return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
    """Load a pretrained VideoEncoder from a directory with weights and config.

    Args:
        model_path: Directory containing ``*.safetensors`` weights and an
            optional ``config.json``, or the path of a single weights
            file. A plain ``str`` is accepted and converted to ``Path``.

    Returns:
        Loaded VideoEncoder instance.

    Raises:
        FileNotFoundError: If ``model_path`` contains no ``*.safetensors``
            files and is not itself a file.
    """
    import json

    from mlx_video.models.ltx.config import VideoEncoderModelConfig

    # Callers frequently pass plain strings; `/` and `.glob` need a Path.
    model_path = Path(model_path)

    # Load config, falling back to the dataclass defaults when the
    # directory has no config.json (always the case for a single file).
    config_path = model_path / "config.json"
    if config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            config_dict = json.load(f)
        config = VideoEncoderModelConfig(**config_dict)
    else:
        config = VideoEncoderModelConfig()

    # Load weights: merge every shard in sorted filename order, so a key
    # present in several shards takes its value from the last one.
    weight_files = sorted(model_path.glob("*.safetensors"))
    if weight_files:
        weights = {}
        for weight_file in weight_files:
            weights.update(mx.load(str(weight_file)))
    elif model_path.is_file():
        weights = mx.load(str(model_path))
    else:
        raise FileNotFoundError(f"No safetensors files found in {model_path}")

    # Create model, sanitize and load weights.
    model = cls(config)
    weights = model.sanitize(weights)
    # NOTE(review): strict=False presumably tolerates key mismatches between
    # checkpoint and module — confirm against the MLX load_weights docs.
    model.load_weights(list(weights.items()), strict=False)

    return model
class VideoDecoder(nn.Module): class VideoDecoder(nn.Module):