"""Video VAE encoder and decoder modules for LTX-2 (MLX implementation)."""
|
|
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
|
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
|
from mlx_video.models.ltx.video_vae.resnet import (
|
|
NormLayerType,
|
|
ResnetBlock3D,
|
|
UNetMidBlock3D,
|
|
get_norm_layer,
|
|
)
|
|
from mlx_video.models.ltx.video_vae.sampling import (
|
|
DepthToSpaceUpsample,
|
|
SpaceToDepthDownsample,
|
|
)
|
|
from mlx_video.utils import PixelNorm
|
|
|
|
|
|
class LogVarianceType(Enum):
    """Selects how the VAE encoder's output convolution carries log-variance.

    ``VideoEncoder`` sizes its output convolution and post-processes the
    result according to the chosen member.
    """

    # One log-variance channel per latent channel (output channels doubled).
    PER_CHANNEL = "per_channel"
    # A single extra log-variance channel, tiled across all latent channels.
    UNIFORM = "uniform"
    # A single extra channel that is discarded and replaced by a constant.
    CONSTANT = "constant"
    # No extra log-variance channels are added to the output convolution.
    NONE = "none"
|
|
|
|
|
|
def _make_encoder_block(
    block_name: str,
    block_config: Dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Build a single encoder stage and report its output width.

    Args:
        block_name: Identifier selecting which kind of block to build
        block_config: Per-block options ("num_layers", "multiplier")
        in_channels: Channel count entering the block
        convolution_dimensions: Dimensionality forwarded as ``dims``
        norm_layer: Normalization layer type for residual blocks
        norm_num_groups: Group count used by group normalization
        spatial_padding_mode: Padding mode forwarded to the convolutions

    Returns:
        Tuple of (block, output_channels)

    Raises:
        ValueError: If ``block_name`` is not a known encoder block.
    """
    # Plain strided causal convolutions, keyed by block name.
    strided_conv = {
        "compress_time": (2, 1, 1),
        "compress_space": (1, 2, 2),
        "compress_all": (2, 2, 2),
        "compress_all_x_y": (2, 2, 2),
    }
    # Space-to-depth downsamplers; these always widen by "multiplier".
    space_to_depth = {
        "compress_all_res": (2, 2, 2),
        "compress_space_res": (1, 2, 2),
        "compress_time_res": (2, 1, 1),
    }

    if block_name == "res_x":
        mid_block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
        return mid_block, in_channels

    if block_name == "res_x_y":
        widened = in_channels * block_config.get("multiplier", 2)
        resnet = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=widened,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
        return resnet, widened

    if block_name in strided_conv:
        # Only the "_x_y" variant changes the channel count.
        if block_name == "compress_all_x_y":
            conv_width = in_channels * block_config.get("multiplier", 2)
        else:
            conv_width = in_channels
        conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=conv_width,
            kernel_size=3,
            stride=strided_conv[block_name],
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        return conv, conv_width

    if block_name in space_to_depth:
        widened = in_channels * block_config.get("multiplier", 2)
        downsample = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=widened,
            stride=space_to_depth[block_name],
            spatial_padding_mode=spatial_padding_mode,
        )
        return downsample, widened

    raise ValueError(f"Unknown encoder block: {block_name}")
|
|
|
|
|
|
def _make_decoder_block(
    block_name: str,
    block_config: Dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    timestep_conditioning: bool,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Build a single decoder stage and report its output width.

    Args:
        block_name: Identifier selecting which kind of block to build
        block_config: Per-block options ("num_layers", "multiplier",
            "inject_noise", "residual")
        in_channels: Channel count entering the block
        convolution_dimensions: Dimensionality forwarded as ``dims``
        norm_layer: Normalization layer type for residual blocks
        timestep_conditioning: Forwarded to "res_x" blocks only
        norm_num_groups: Group count used by group normalization
        spatial_padding_mode: Padding mode forwarded to the convolutions

    Returns:
        Tuple of (block, output_channels)

    Raises:
        ValueError: If ``block_name`` is not a known decoder block.
    """
    # Upsamplers that leave the reported channel count unchanged.
    plain_upsample = {
        "compress_time": (2, 1, 1),
        "compress_space": (1, 2, 2),
    }

    if block_name == "res_x":
        mid_block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=timestep_conditioning,
            spatial_padding_mode=spatial_padding_mode,
        )
        return mid_block, in_channels

    if block_name == "res_x_y":
        narrowed = in_channels // block_config.get("multiplier", 2)
        resnet = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=narrowed,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            # Channel-changing residual blocks are never timestep-conditioned.
            timestep_conditioning=False,
            spatial_padding_mode=spatial_padding_mode,
        )
        return resnet, narrowed

    if block_name in plain_upsample:
        upsample = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=plain_upsample[block_name],
            spatial_padding_mode=spatial_padding_mode,
        )
        return upsample, in_channels

    if block_name == "compress_all":
        reduction = block_config.get("multiplier", 1)
        upsample = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(2, 2, 2),
            residual=block_config.get("residual", False),
            out_channels_reduction_factor=reduction,
            spatial_padding_mode=spatial_padding_mode,
        )
        return upsample, in_channels // reduction

    raise ValueError(f"Unknown decoder block: {block_name}")
|
|
|
|
|
|
class VideoEncoder(nn.Module):
    """Video VAE encoder: video in, normalized latent means out.

    The network is ``conv_in`` -> configurable ``down_blocks`` -> output
    norm -> SiLU -> ``conv_out``. Every layer is invoked in causal mode.
    """

    _DEFAULT_NORM_NUM_GROUPS = 32

    def __init__(self, config: "VideoEncoderModelConfig"):
        """Initialize VideoEncoder from config.

        Args:
            config: VideoEncoderModelConfig with encoder parameters

        Raises:
            ValueError: If ``config.norm_layer`` is neither GROUP_NORM nor
                PIXEL_NORM, or if an encoder block name is unknown.
        """
        super().__init__()

        self.patch_size = config.patch_size
        self.norm_layer = config.norm_layer
        self.latent_channels = config.out_channels
        self.latent_log_var = config.latent_log_var
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
        encoder_spatial_padding_mode = config.encoder_spatial_padding_mode

        # Per-channel statistics for normalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)

        # After patchify, channels increase by patch_size^2
        in_channels = config.in_channels * config.patch_size ** 2
        feature_channels = config.out_channels

        # Initial convolution
        self.conv_in = CausalConv3d(
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

        # Build encoder blocks
        # Use dict with int keys for MLX to track parameters (lists are NOT tracked)
        self.down_blocks = {}
        for idx, (block_name, block_params) in enumerate(encoder_blocks):
            # A bare int is shorthand for {"num_layers": <int>}.
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params

            block, feature_channels = _make_encoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=config.convolution_dimensions,
                norm_layer=config.norm_layer,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=encoder_spatial_padding_mode,
            )
            self.down_blocks[idx] = block

        # Output normalization and convolution
        if config.norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(
                num_groups=self._norm_num_groups,
                dims=feature_channels,
                eps=1e-6,
            )
        elif config.norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()
        else:
            # Fail here rather than with an AttributeError on the first call.
            raise ValueError(f"Unsupported norm layer: {config.norm_layer}")

        self.conv_act = nn.SiLU()

        # Calculate output convolution channels
        conv_out_channels = config.out_channels
        if config.latent_log_var == LogVarianceType.PER_CHANNEL:
            conv_out_channels *= 2
        elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
            conv_out_channels += 1

        self.conv_out = CausalConv3d(
            in_channels=feature_channels,
            out_channels=conv_out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

    def __call__(self, sample: mx.array) -> mx.array:
        """Encode video to latent representation.

        Args:
            sample: Input video of shape (B, C, F, H, W).
                F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)

        Returns:
            Normalized latent means of shape (B, 128, F', H', W')

        Raises:
            ValueError: If the frame count is not of the form 1 + 8*k.
        """
        # Validate frame count
        frames_count = sample.shape[2]
        if ((frames_count - 1) % 8) != 0:
            raise ValueError(
                "Invalid number of frames: Encode input must have 1 + 8 * x frames "
                f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
            )

        # Initial patchify
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample, causal=True)

        # Process through encoder blocks. Every block type takes the same
        # call, so no per-type dispatch is needed. Blocks are fetched by
        # integer key to follow construction order.
        for i in range(len(self.down_blocks)):
            sample = self.down_blocks[i](sample, causal=True)

        # Output processing
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=True)

        # Handle log variance modes
        if self.latent_log_var == LogVarianceType.UNIFORM:
            # Expand the single shared log-variance channel to one per mean.
            means = sample[:, :-1, ...]
            logvar = sample[:, -1:, ...]
            num_channels = means.shape[1]
            repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
            sample = mx.concatenate([means, repeated_logvar], axis=1)
        elif self.latent_log_var == LogVarianceType.CONSTANT:
            # Drop the extra channel and substitute a constant log-variance.
            sample = sample[:, :-1, ...]
            approx_ln_0 = -30
            sample = mx.concatenate([
                sample,
                mx.full_like(sample, approx_ln_0),
            ], axis=1)

        # Only the means are returned; the log-variance half is discarded.
        means = sample[:, :self.latent_channels, ...]
        return self.per_channel_statistics.normalize(means)

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize VAE encoder weights from PyTorch format to MLX format.

        Args:
            weights: Mapping of checkpoint key to array

        Returns:
            ``weights`` unchanged if it already has a
            "per_channel_statistics.mean" key; otherwise a new dict holding
            only the encoder and per-channel-statistics entries, with the
            "vae.encoder." prefix removed and conv weights transposed.
        """
        # Already-converted checkpoints are passed through untouched.
        if "per_channel_statistics.mean" in weights:
            return weights

        sanitized = {}
        for key, value in weights.items():
            new_key = key

            if "position_ids" in key:
                continue

            # Only process VAE encoder weights
            if not key.startswith("vae."):
                continue

            # Handle per-channel statistics
            if "vae.per_channel_statistics" in key:
                if key == "vae.per_channel_statistics.mean-of-means":
                    new_key = "per_channel_statistics.mean"
                elif key == "vae.per_channel_statistics.std-of-means":
                    new_key = "per_channel_statistics.std"
                else:
                    continue
            elif key.startswith("vae.encoder."):
                new_key = key.replace("vae.encoder.", "")
            else:
                continue

            if "conv" in new_key.lower() and "weight" in new_key:
                if value.ndim == 5:
                    # Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
                    value = mx.transpose(value, (0, 2, 3, 4, 1))
                elif value.ndim == 4:
                    # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
                    value = mx.transpose(value, (0, 2, 3, 1))

            sanitized[new_key] = value
        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
        """Load a pretrained VideoEncoder from a directory with weights and config.

        Args:
            model_path: Path to directory containing safetensors weights and
                config.json, or path to a single weights file

        Returns:
            Loaded VideoEncoder instance

        Raises:
            FileNotFoundError: If no weights can be located at ``model_path``.
        """
        import json
        from mlx_video.models.ltx.config import VideoEncoderModelConfig

        # Accept plain strings as well as Path objects.
        model_path = Path(model_path)

        # Load config, falling back to defaults when config.json is absent
        config_path = model_path / "config.json"
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                config_dict = json.load(f)
            config = VideoEncoderModelConfig(**config_dict)
        else:
            config = VideoEncoderModelConfig()

        # Load weights: merge all shards in a directory, or read a single file
        weight_files = sorted(model_path.glob("*.safetensors"))
        if weight_files:
            weights = {}
            for wf in weight_files:
                weights.update(mx.load(str(wf)))
        elif model_path.is_file():
            weights = mx.load(str(model_path))
        else:
            raise FileNotFoundError(f"No safetensors files found in {model_path}")

        # Create model, sanitize and load weights
        model = cls(config)
        weights = model.sanitize(weights)
        # strict=False tolerates keys that do not line up with the model.
        model.load_weights(list(weights.items()), strict=False)
        return model
|
|
|
|
|
|
class VideoDecoder(nn.Module):
    """Video VAE decoder: normalized latents in, unpatchified video out.

    The network is ``conv_in`` -> configurable ``up_blocks`` -> output
    norm -> SiLU -> ``conv_out`` -> unpatchify.
    """

    _DEFAULT_NORM_NUM_GROUPS = 32

    def __init__(
        self,
        convolution_dimensions: int = 3,
        in_channels: int = 128,
        out_channels: int = 3,
        decoder_blocks: Optional[List[Tuple[str, Any]]] = None,
        patch_size: int = 4,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        causal: bool = False,
        timestep_conditioning: bool = False,
        decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
    ):
        """Initialize VideoDecoder.

        Args:
            convolution_dimensions: Number of dimensions
            in_channels: Input latent channels
            out_channels: Output channels (3 for RGB)
            decoder_blocks: List of (block_name, config) tuples; None means
                no blocks
            patch_size: Spatial patch size
            norm_layer: Normalization layer type
            causal: Whether to use causal convolutions
            timestep_conditioning: Whether to use timestep conditioning
            decoder_spatial_padding_mode: Padding mode

        Raises:
            ValueError: If ``norm_layer`` is neither GROUP_NORM nor
                PIXEL_NORM, or if a decoder block name is unknown.
        """
        super().__init__()

        if decoder_blocks is None:
            decoder_blocks = []

        self.patch_size = patch_size
        # conv_out emits patchified pixels; unpatchify restores resolution.
        out_channels = out_channels * patch_size ** 2
        self.causal = causal
        self.timestep_conditioning = timestep_conditioning
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        # Per-channel statistics for denormalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)

        # Noise and timestep parameters
        self.decode_noise_scale = 0.025
        self.decode_timestep = 0.05

        # Compute initial feature channels: each channel-reducing block
        # contributes its multiplier to the starting width.
        feature_channels = in_channels
        for block_name, block_params in reversed(decoder_blocks):
            block_config = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                feature_channels = feature_channels * block_config.get("multiplier", 2)
            if block_name == "compress_all":
                feature_channels = feature_channels * block_config.get("multiplier", 1)

        # Initial convolution
        self.conv_in = CausalConv3d(
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=decoder_spatial_padding_mode,
        )

        # Build decoder blocks (reversed order)
        # Use dict with int keys for MLX to track parameters (lists are NOT tracked)
        self.up_blocks = {}
        for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
            # A bare int is shorthand for {"num_layers": <int>}.
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params

            block, feature_channels = _make_decoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=convolution_dimensions,
                norm_layer=norm_layer,
                timestep_conditioning=timestep_conditioning,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=decoder_spatial_padding_mode,
            )
            self.up_blocks[idx] = block

        # Output normalization
        if norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(
                num_groups=self._norm_num_groups,
                dims=feature_channels,
                eps=1e-6,
            )
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()
        else:
            # Fail here rather than with an AttributeError on the first call.
            raise ValueError(f"Unsupported norm layer: {norm_layer}")

        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(
            in_channels=feature_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=decoder_spatial_padding_mode,
        )

    def __call__(
        self,
        sample: mx.array,
        timestep: Optional[mx.array] = None,
    ) -> mx.array:
        """Decode latent to video.

        Args:
            sample: Latent tensor of shape (B, 128, F', H', W')
            timestep: Optional timestep for conditioning

        Returns:
            Decoded video of shape (B, 3, F, H, W)
        """
        batch_size = sample.shape[0]

        # Add noise if timestep conditioning is enabled
        if self.timestep_conditioning:
            noise = mx.random.normal(sample.shape) * self.decode_noise_scale
            sample = noise + (1.0 - self.decode_noise_scale) * sample

        # Denormalize latents
        sample = self.per_channel_statistics.un_normalize(sample)

        # Use default timestep if not provided
        if timestep is None and self.timestep_conditioning:
            timestep = mx.full((batch_size,), self.decode_timestep)
        # NOTE(review): `timestep` is never forwarded to any block below, so
        # it currently has no effect on the output. Confirm whether the
        # timestep-conditioned blocks expect it as a call argument.

        # Initial convolution
        sample = self.conv_in(sample, causal=self.causal)

        # Process through decoder blocks. Every block type takes the same
        # call, so no per-type dispatch is needed. Blocks are fetched by
        # integer key to follow construction order.
        for i in range(len(self.up_blocks)):
            sample = self.up_blocks[i](sample, causal=self.causal)

        # Output processing
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        # Unpatchify to restore spatial resolution
        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample
|