Refactor VideoEncoder to initialize from VideoEncoderModelConfig, enhancing configuration management. Add methods for weight sanitization and loading from pretrained models, improving model usability and integration with existing workflows.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Video VAE Encoder and Decoder for LTX-2."""
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -221,46 +222,30 @@ class VideoEncoder(nn.Module):
|
||||
|
||||
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
convolution_dimensions: int = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
encoder_blocks: List[Tuple[str, Any]] = None,
|
||||
patch_size: int = 4,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
|
||||
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
"""Initialize VideoEncoder.
|
||||
def __init__(self, config: "VideoEncoderModelConfig"):
|
||||
"""Initialize VideoEncoder from config.
|
||||
|
||||
Args:
|
||||
convolution_dimensions: Number of dimensions (3 for video)
|
||||
in_channels: Input channels (3 for RGB)
|
||||
out_channels: Output latent channels
|
||||
encoder_blocks: List of (block_name, config) tuples
|
||||
patch_size: Spatial patch size
|
||||
norm_layer: Normalization layer type
|
||||
latent_log_var: Log variance mode
|
||||
encoder_spatial_padding_mode: Padding mode
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
||||
|
||||
if encoder_blocks is None:
|
||||
encoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
self.latent_log_var = latent_log_var
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
self.latent_channels = config.out_channels
|
||||
self.latent_log_var = config.latent_log_var
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
|
||||
encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
|
||||
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = in_channels * patch_size ** 2
|
||||
feature_channels = out_channels
|
||||
in_channels = config.in_channels * config.patch_size ** 2
|
||||
feature_channels = config.out_channels
|
||||
|
||||
# Initial convolution
|
||||
self.conv_in = CausalConv3d(
|
||||
@@ -283,30 +268,30 @@ class VideoEncoder(nn.Module):
|
||||
block_name=block_name,
|
||||
block_config=block_config,
|
||||
in_channels=feature_channels,
|
||||
convolution_dimensions=convolution_dimensions,
|
||||
norm_layer=norm_layer,
|
||||
convolution_dimensions=config.convolution_dimensions,
|
||||
norm_layer=config.norm_layer,
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks[idx] = block
|
||||
|
||||
# Output normalization and convolution
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
if config.norm_layer == NormLayerType.GROUP_NORM:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_groups=self._norm_num_groups,
|
||||
dims=feature_channels,
|
||||
eps=1e-6,
|
||||
)
|
||||
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||
elif config.norm_layer == NormLayerType.PIXEL_NORM:
|
||||
self.conv_norm_out = PixelNorm()
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
# Calculate output convolution channels
|
||||
conv_out_channels = out_channels
|
||||
if latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels = config.out_channels
|
||||
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
@@ -373,6 +358,86 @@ class VideoEncoder(nn.Module):
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize VAE encoder weights from PyTorch format to MLX format."""
|
||||
sanitized = {}
|
||||
if "per_channel_statistics.mean" in weights:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Only process VAE encoder weights
|
||||
if not key.startswith("vae."):
|
||||
continue
|
||||
|
||||
# Handle per-channel statistics
|
||||
if "vae.per_channel_statistics" in key:
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue
|
||||
elif key.startswith("vae.encoder."):
|
||||
new_key = key.replace("vae.encoder.", "")
|
||||
else:
|
||||
continue
|
||||
|
||||
# Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
|
||||
"""Load a pretrained VideoEncoder from a directory with weights and config.
|
||||
|
||||
Args:
|
||||
model_path: Path to directory containing safetensors weights and config.json
|
||||
|
||||
Returns:
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
from mlx_video.models.ltx.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
config = VideoEncoderModelConfig(**config_dict)
|
||||
else:
|
||||
config = VideoEncoderModelConfig()
|
||||
|
||||
# Load weights
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
if model_path.is_file():
|
||||
weights = mx.load(str(model_path))
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
else:
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
# Create model, sanitize and load weights
|
||||
model = cls(config)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
class VideoDecoder(nn.Module):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user