Refactor LTX-2 model structure

This commit is contained in:
Prince Canuma
2026-03-16 14:50:01 +01:00
parent decb3eb9e5
commit 3a0da19adb
50 changed files with 3882 additions and 3365 deletions

View File

@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

View File

@@ -0,0 +1,294 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
class PaddingModeType(Enum):
ZEROS = "zeros"
REFLECT = "reflect"
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
"""Apply reflect padding to spatial dimensions of a 5D tensor.
Args:
x: Input tensor of shape (B, D, H, W, C) - channels last
pad_h: Padding for height dimension
pad_w: Padding for width dimension
Returns:
Padded tensor
"""
if pad_h == 0 and pad_w == 0:
return x
# Height padding (axis 2)
if pad_h > 0:
# Get reflection indices - exclude boundary
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
# Width padding (axis 3)
if pad_w > 0:
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
x = mx.concatenate([left_pad, x, right_pad], axis=3)
return x
def make_conv_nd(
dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
if dims == 2:
return CausalConv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
causal=causal,
spatial_padding_mode=spatial_padding_mode,
)
elif dims == 3:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
causal=causal,
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unsupported number of dimensions: {dims}")
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.causal = causal
self.spatial_padding_mode = spatial_padding_mode
# Normalize kernel_size and stride to tuples
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
self.kernel_size = kernel_size
self.stride = stride
self.time_kernel_size = kernel_size[0]
# Calculate spatial padding (temporal is handled separately via frame replication)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
self.spatial_padding = (height_pad, width_pad)
# Create the base convolution (without padding, we'll handle it manually)
self.conv = nn.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0, # We handle padding manually
bias=True,
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
use_causal = causal if causal is not None else self.causal
# Apply temporal padding via frame replication
# Only apply if kernel_size > 1
if self.time_kernel_size > 1:
if use_causal:
# Causal: replicate first frame kernel_size-1 times at the beginning
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
x = mx.concatenate([first_frame_pad, x], axis=2)
else:
# Non-causal: replicate first frame at start, last frame at end
pad_size = (self.time_kernel_size - 1) // 2
if pad_size > 0:
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
x = mx.transpose(x, (0, 2, 3, 4, 1))
# Apply spatial padding
pad_h, pad_w = self.spatial_padding
if pad_h > 0 or pad_w > 0:
if self.spatial_padding_mode == PaddingModeType.REFLECT:
# Use reflect padding for spatial dimensions
x = reflect_pad_2d(x, pad_h, pad_w)
else:
# Use zero padding for spatial dimensions
pad_width = [
(0, 0), # Batch
(0, 0), # D (temporal - already padded)
(pad_h, pad_h), # H
(pad_w, pad_w), # W
(0, 0), # C
]
x = mx.pad(x, pad_width)
# Apply convolution with chunking for large tensors
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
x = self._chunked_conv3d(x)
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
x = mx.transpose(x, (0, 4, 1, 2, 3))
return x
def _chunked_conv3d(self, x: mx.array) -> mx.array:
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
Args:
x: Input tensor of shape (B, D, H, W, C) in channels-last format
Returns:
Output tensor after conv3d
"""
b, d, h, w, c = x.shape
total_elements = d * h * w * c
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
if total_elements <= max_safe_elements:
return self.conv(x)
elements_per_frame = h * w * c
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
kernel_t = self.time_kernel_size
overlap = kernel_t - 1
expected_output_frames = d - overlap
outputs = []
out_idx = 0
# Process chunks
in_start = 0
while out_idx < expected_output_frames:
remaining = expected_output_frames - out_idx
out_frames_this_chunk = min(chunk_size, remaining)
in_frames_needed = out_frames_this_chunk + overlap
in_end = min(in_start + in_frames_needed, d)
chunk = x[:, in_start:in_end, :, :, :]
chunk_out = self.conv(chunk)
mx.eval(chunk_out)
outputs.append(chunk_out)
out_idx += chunk_out.shape[1]
in_start += chunk_out.shape[1]
# Concatenate all chunks
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=1)
class CausalConv2d(nn.Module):
"""2D convolution with optional causal padding."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize CausalConv2d."""
super().__init__()
self.causal = causal
self.spatial_padding_mode = spatial_padding_mode
# Normalize kernel_size and stride
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride)
self.kernel_size = kernel_size
self.stride = stride
# Calculate padding
if isinstance(padding, str) and padding == "same":
self.padding = (
(kernel_size[0] - 1) // 2,
(kernel_size[1] - 1) // 2,
)
elif isinstance(padding, int):
self.padding = (padding, padding)
else:
self.padding = padding
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
bias=True,
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
"""Forward pass."""
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
x = mx.transpose(x, (0, 2, 3, 1))
# Apply padding
pad_h, pad_w = self.padding
if pad_h != 0 or pad_w != 0:
pad_width = [
(0, 0), # Batch
(pad_h, pad_h), # H
(pad_w, pad_w), # W
(0, 0), # C
]
x = mx.pad(x, pad_width)
x = self.conv(x)
# Transpose back: (B, H, W, C) -> (B, C, H, W)
x = mx.transpose(x, (0, 3, 1, 2))
return x

View File

@@ -0,0 +1,692 @@
"""Video VAE Decoder for LTX-2 with timestep conditioning.
Architecture (from PyTorch weights):
- conv_in: 128 -> 1024
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
- pixel_norm + timestep modulation (last_scale_shift_table)
- conv_out: 128 -> 48
- unpatchify: 48 -> 3 with patch_size=4
"""
import math
from typing import Optional, Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
def get_timestep_embedding(
timesteps: mx.array,
embedding_dim: int,
flip_sin_to_cos: bool = True,
downscale_freq_shift: float = 0,
scale: float = 1,
max_period: int = 10000,
) -> mx.array:
"""Create sinusoidal timestep embeddings."""
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = mx.exp(exponent)
emb = timesteps[:, None].astype(mx.float32) * emb[None, :]
emb = scale * emb
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
if flip_sin_to_cos:
emb = mx.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
if embedding_dim % 2 == 1:
emb = mx.pad(emb, [(0, 0), (0, 1)])
return emb
class TimestepEmbedding(nn.Module):
"""MLP for timestep embedding."""
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
self.act = nn.SiLU()
def __call__(self, sample: mx.array) -> mx.array:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class PixArtAlphaTimestepEmbedder(nn.Module):
"""Combined timestep embedding (sinusoidal + MLP)."""
def __init__(self, embedding_dim: int):
super().__init__()
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim
)
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
timesteps_proj = get_timestep_embedding(
timestep,
embedding_dim=256,
flip_sin_to_cos=True,
downscale_freq_shift=0
)
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
return timesteps_emb
class ResnetBlock3DSimple(nn.Module):
"""ResNet block with optional timestep conditioning.
Weight keys: conv1.conv, conv2.conv, scale_shift_table
"""
def __init__(
self,
channels: int,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = False,
):
super().__init__()
self.timestep_conditioning = timestep_conditioning
# Nested conv structure to match PyTorch naming: conv1.conv.weight
self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
self.act = nn.SiLU()
# Scale-shift table for timestep conditioning: [shift1, scale1, shift2, scale2]
if timestep_conditioning:
self.scale_shift_table = mx.zeros((4, channels))
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_ch,
out_channels=out_ch,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
def __call__(
self,
x: mx.array,
causal: bool = False,
timestep_embed: Optional[mx.array] = None,
) -> mx.array:
residual = x
batch_size = x.shape[0]
# Block 1 with optional timestep conditioning
x = self.pixel_norm(x)
if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
ada_values = ada_values + ts_reshaped
shift1 = ada_values[:, 0] # (B, C, 1, 1, 1)
scale1 = ada_values[:, 1]
shift2 = ada_values[:, 2]
scale2 = ada_values[:, 3]
x = x * (1 + scale1) + shift1
x = self.act(x)
x = self.conv1(x, causal=causal)
# Block 2 with optional timestep conditioning
x = self.pixel_norm(x)
if self.timestep_conditioning and timestep_embed is not None:
x = x * (1 + scale2) + shift2
x = self.act(x)
x = self.conv2(x, causal=causal)
return x + residual
class ResBlockGroup(nn.Module):
"""Group of ResNet blocks with shared timestep embedding.
PyTorch naming: res_blocks.0, res_blocks.1, etc.
"""
def __init__(
self,
channels: int,
num_layers: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = False,
):
super().__init__()
self.timestep_conditioning = timestep_conditioning
# Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=channels * 4
)
# Use dict with int keys for MLX to track parameters properly
self.res_blocks = {
i: ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
)
for i in range(num_layers)
}
def __call__(
self,
x: mx.array,
causal: bool = False,
timestep: Optional[mx.array] = None,
) -> mx.array:
timestep_embed = None
if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0]
timestep_embed = self.time_embedder(
timestep.flatten(),
hidden_dtype=x.dtype
)
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
for res_block in self.res_blocks.values():
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
return x
class LTX2VideoDecoder(nn.Module):
"""LTX-2 Video VAE Decoder with timestep conditioning.
Architecture:
- conv_in: 128 -> 1024
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
- up_blocks.1: Upsampler 1024 -> 512
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
- up_blocks.3: Upsampler 512 -> 256
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
- up_blocks.5: Upsampler 256 -> 128
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
"""
# Block definitions: ("res", channels, num_layers) or ("d2s", in_channels, reduction, stride)
# stride is (D, H, W) tuple
DEFAULT_BLOCKS = [
("res", 1024, 5),
("d2s", 1024, 2, (2, 2, 2)),
("res", 512, 5),
("d2s", 512, 2, (2, 2, 2)),
("res", 256, 5),
("d2s", 256, 2, (2, 2, 2)),
("res", 128, 5),
]
def __init__(
self,
in_channels: int = 128,
out_channels: int = 3,
patch_size: int = 4,
num_layers_per_block: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = True,
decoder_blocks: list = None,
):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.timestep_conditioning = timestep_conditioning
# Decode parameters (configurable via constructor)
self.decode_noise_scale = 0.025 # Set to 0.0 to disable noise
self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights)
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
blocks = decoder_blocks or self.DEFAULT_BLOCKS
first_ch = blocks[0][1]
last_ch = blocks[-1][1]
# Initial conv: in_channels -> first block channels
class ConvInWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_channels,
out_channels=first_ch,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Build up blocks from config
self.up_blocks = {}
for idx, block_def in enumerate(blocks):
block_type = block_def[0]
ch = block_def[1]
if block_type == "res":
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
residual = block_def[4] if len(block_def) > 4 else True
self.up_blocks[idx] = DepthToSpaceUpsample(
dims=3,
in_channels=ch,
stride=stride,
residual=residual,
out_channels_reduction_factor=reduction,
spatial_padding_mode=spatial_padding_mode,
)
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=last_ch,
out_channels=final_out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper()
self.act = nn.SiLU()
if timestep_conditioning:
self.timestep_scale_multiplier = mx.array(1000.0)
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=last_ch * 2
)
self.last_scale_shift_table = mx.zeros((2, last_ch))
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
sanitized = {}
if "per_channel_statistics.mean" in weights:
return weights
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
if key.startswith("vae.per_channel_statistics."):
# Map per-channel statistics (use exact key matching)
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue # Skip other statistics keys
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
"""Load a pretrained decoder from a directory with config.json and weights.
Args:
model_path: Path to directory containing config.json and safetensors files,
or path to a single safetensors file.
strict: Whether to require all weight keys to match.
Returns:
Loaded LTX2VideoDecoder instance
"""
import json
model_path = Path(model_path)
config_dict = {}
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Load weights from directory
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
# Determine spatial padding mode from config
spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect")
spatial_padding_mode = PaddingModeType(spatial_padding_mode_str)
model = cls(
timestep_conditioning=config_dict.get("timestep_conditioning", False),
decoder_blocks=decoder_blocks,
spatial_padding_mode=spatial_padding_mode,
)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
@staticmethod
def _infer_blocks(weights: dict) -> list:
"""Infer decoder block structure from weight keys."""
block_indices = set()
for k in weights:
if "up_blocks." in k:
idx_str = k.split("up_blocks.")[1].split(".")[0]
if idx_str.isdigit():
block_indices.add(int(idx_str))
if not block_indices:
return None
# First pass: collect block info
raw_blocks = []
for idx in sorted(block_indices):
has_conv = any(f"up_blocks.{idx}.conv." in k for k in weights)
res_indices = set()
for k in weights:
prefix = f"up_blocks.{idx}.res_blocks."
if prefix in k:
res_idx = k.split(prefix)[1].split(".")[0]
if res_idx.isdigit():
res_indices.add(int(res_idx))
if has_conv and not res_indices:
# D2S block - get conv shape
for k, v in weights.items():
if f"up_blocks.{idx}.conv." in k and "weight" in k:
in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1]
conv_out_ch = v.shape[0]
raw_blocks.append(("d2s", in_ch, conv_out_ch))
break
elif res_indices:
num_res = max(res_indices) + 1
for k, v in weights.items():
if f"up_blocks.{idx}.res_blocks.0.conv1" in k and "weight" in k:
ch = v.shape[0]
raw_blocks.append(("res", ch, num_res))
break
# Second pass: determine d2s strides using the channel progression
# For each d2s block, the next res block tells us the expected output channels
blocks = []
d2s_strides = []
for i, block in enumerate(raw_blocks):
if block[0] == "res":
blocks.append(block)
elif block[0] == "d2s":
in_ch, conv_out_ch = block[1], block[2]
# Find next res block's channels
next_ch = None
for j in range(i + 1, len(raw_blocks)):
if raw_blocks[j][0] == "res":
next_ch = raw_blocks[j][1]
break
if next_ch is None:
next_ch = in_ch // 2 # fallback
# out_ch = in_ch // reduction
reduction = in_ch // next_ch if next_ch > 0 else 2
# conv_out = next_ch * multiplier → multiplier = conv_out / next_ch
multiplier = conv_out_ch // next_ch if next_ch > 0 else 8
# Determine stride from multiplier
if multiplier == 8:
stride = (2, 2, 2)
elif multiplier == 4:
stride = (1, 2, 2)
elif multiplier == 2:
stride = (2, 1, 1)
else:
stride = (2, 2, 2)
d2s_strides.append(stride)
blocks.append(("d2s", in_ch, reduction, stride))
if not blocks:
return None
# Determine residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 → residual=True
# LTX-2.3 has mixed strides or reduction=1 → residual=False
has_mixed_strides = len(set(d2s_strides)) > 1
has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s")
use_residual = not has_mixed_strides and not has_non_standard_reduction
# Apply residual flag to all d2s blocks
final_blocks = []
for block in blocks:
if block[0] == "d2s":
final_blocks.append(("d2s", block[1], block[2], block[3], use_residual))
else:
final_blocks.append(block)
return final_blocks
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
def __call__(
self,
sample: mx.array,
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
chunked_conv: bool = False,
) -> mx.array:
batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
scaled_timestep = None
if self.timestep_conditioning and timestep is not None:
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep)
elif isinstance(block, DepthToSpaceUpsample):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
x = self.pixel_norm(x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(),
hidden_dtype=x.dtype
)
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
ada_values = ada_values + ts_reshaped
shift = ada_values[:, 0] # (B, 128, 1, 1, 1)
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
x = self.act(x)
x = self.conv_out(x, causal=causal)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
return x
def decode_tiled(
self,
sample: mx.array,
tiling_config: Optional[TilingConfig] = None,
tiling_mode: str = "auto",
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
on_frames_ready: Optional[callable] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
This method is useful for decoding large videos that would otherwise
cause out-of-memory errors. It divides the latents into tiles,
decodes each tile separately, and blends them together.
Args:
sample: Input latents of shape (B, C, F, H, W).
tiling_config: Tiling configuration. If None, uses TilingConfig.default().
causal: Whether to use causal convolutions.
timestep: Optional timestep for conditioning.
debug: Whether to print debug info.
Returns:
Decoded video of shape (B, 3, F*8, H*8, W*8).
"""
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
_, _, f, h, w = sample.shape
needs_spatial_tiling = False
needs_temporal_tiling = False
# Spatial scale is 32 (8x VAE upsample + 4x unpatchify)
# Temporal scale is 8
spatial_scale = 32
temporal_scale = 8
if tiling_config.spatial_config is not None:
s_cfg = tiling_config.spatial_config
tile_size_latent = s_cfg.tile_size_in_pixels // spatial_scale
if h > tile_size_latent or w > tile_size_latent:
needs_spatial_tiling = True
if tiling_config.temporal_config is not None:
t_cfg = tiling_config.temporal_config
tile_size_latent = t_cfg.tile_size_in_frames // temporal_scale
if f > tile_size_latent:
needs_temporal_tiling = True
# Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
return decode_with_tiling(
decoder_fn=self,
latents=sample,
tiling_config=tiling_config,
spatial_scale=32, # VAE spatial: 8x upsampling + 4x unpatchify = 32x
temporal_scale=8, # VAE temporal upsampling factor
causal=causal,
timestep=timestep,
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)
# Backward-compatible alias
VideoDecoder = LTX2VideoDecoder

View File

@@ -0,0 +1,44 @@
"""Video VAE Encoder for LTX-2 Image-to-Video.
The encoder compresses input images/videos to latent representations.
Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
import mlx.core as mx
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
def encode_image(
image: mx.array,
encoder: VideoEncoder,
) -> mx.array:
"""Encode a single image to latent space.
Args:
image: Image tensor of shape (H, W, 3) in range [0, 1] or (B, H, W, 3)
encoder: Loaded VAE encoder
Returns:
Latent tensor of shape (1, 128, 1, H//32, W//32)
"""
# Add batch dimension if needed
if image.ndim == 3:
image = mx.expand_dims(image, axis=0) # (1, H, W, 3)
# Convert from (B, H, W, C) to (B, C, H, W)
image = mx.transpose(image, (0, 3, 1, 2)) # (B, 3, H, W)
# Normalize to [-1, 1]
if image.max() > 1.0:
image = image / 255.0
image = image * 2.0 - 1.0
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
image = mx.expand_dims(image, axis=2) # (B, 3, 1, H, W)
# Encode
latent = encoder(image)
return latent

View File

@@ -0,0 +1,125 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert video to patches.
Moves spatial pixels from H, W dimensions to channel dimension.
Args:
x: Input tensor of shape (B, C, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
"""
b, c, f, h, w = x.shape
# Check dimensions are divisible
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
assert f % patch_size_t == 0
# New dimensions
new_h = h // patch_size_hw
new_w = w // patch_size_hw
new_f = f // patch_size_t
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6))
# Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W')
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
return x
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert patches back to video.
Inverse of patchify - moves pixels from channel dimension back to spatial.
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
Args:
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
"""
b, c_packed, f, h, w = x.shape
# Calculate original channel count
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
# Permute to interleave patches with spatial dims:
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
return x
class PerChannelStatistics(nn.Module):
def __init__(self, latent_channels: int = 128):
super().__init__()
self.latent_channels = latent_channels
# Learnable per-channel mean and std
self.mean = mx.zeros((latent_channels,))
self.std = mx.ones((latent_channels,))
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latents using per-channel statistics.
Args:
x: Input tensor of shape (B, C, ...)
Returns:
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return ((x - mean) / std).astype(dtype)
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics.
Args:
x: Normalized tensor of shape (B, C, ...)
Returns:
Denormalized tensor
"""
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return (x * std + mean).astype(dtype)

View File

@@ -0,0 +1,172 @@
"""ResNet blocks for Video VAE."""
from enum import Enum
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.utils import PixelNorm
class NormLayerType(Enum):
GROUP_NORM = "group_norm"
PIXEL_NORM = "pixel_norm"
def get_norm_layer(
norm_type: NormLayerType,
num_channels: int,
num_groups: int = 32,
eps: float = 1e-6,
) -> nn.Module:
if norm_type == NormLayerType.GROUP_NORM:
return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
elif norm_type == NormLayerType.PIXEL_NORM:
return PixelNorm(eps=eps)
else:
raise ValueError(f"Unknown norm type: {norm_type}")
class ResnetBlock3D(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
out_channels: Optional[int] = None,
eps: float = 1e-6,
groups: int = 32,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.inject_noise = inject_noise
# First normalization and convolution
self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
self.conv1 = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
# Second normalization and convolution
self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
self.conv2 = CausalConv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
# Shortcut connection if channels change
if in_channels != out_channels:
self.shortcut = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
spatial_padding_mode=spatial_padding_mode,
)
else:
self.shortcut = None
# Activation
self.act = nn.SiLU()
def __call__(
self,
x: mx.array,
causal: bool = True,
generator: Optional[int] = None,
) -> mx.array:
residual = x
# First block
x = self.norm1(x)
x = self.act(x)
x = self.conv1(x, causal=causal)
# Inject noise if enabled
if self.inject_noise and generator is not None:
noise = mx.random.normal(x.shape)
x = x + noise * 0.01
# Second block
x = self.norm2(x)
x = self.act(x)
x = self.conv2(x, causal=causal)
# Shortcut
if self.shortcut is not None:
residual = self.shortcut(residual, causal=causal)
return x + residual
class UNetMidBlock3D(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
inject_noise: bool = False,
timestep_conditioning: bool = False,
attention_head_dim: Optional[int] = None,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.num_layers = num_layers
# Create ResNet blocks - use dict for MLX parameter tracking
# Named res_blocks to match PyTorch weight keys
self.res_blocks = {
i: ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for i in range(num_layers)
}
def __call__(
self,
x: mx.array,
causal: bool = True,
timestep: Optional[mx.array] = None,
generator: Optional[int] = None,
) -> mx.array:
for resnet in self.res_blocks.values():
x = resnet(x, causal=causal, generator=generator)
return x

View File

@@ -0,0 +1,275 @@
"""Sampling operations for Video VAE (upsampling/downsampling)."""
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
class SpaceToDepthDownsample(nn.Module):
"""Space-to-depth downsampling with 3x3 conv and skip connection.
PyTorch-compatible implementation:
1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
2. Space-to-depth on conv output: channels * prod(stride)
3. Space-to-depth on input with group averaging for skip connection
4. Add skip connection
"""
def __init__(
self,
dims: int,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]],
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
stride = (stride, stride, stride)
self.stride = stride
self.dims = dims
self.out_channels = out_channels
# Calculate channels
multiplier = stride[0] * stride[1] * stride[2]
self.group_size = in_channels * multiplier // out_channels
conv_out_channels = out_channels // multiplier
# 3x3 convolution (not 1x1)
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def _space_to_depth(self, x: mx.array) -> mx.array:
"""Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Reshape to group spatial elements
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Permute: (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Temporal padding for causal mode
if st == 2:
# Duplicate first frame for padding
x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
d = d + 1
# Pad if necessary to make dimensions divisible by stride
pad_d = (st - d % st) % st
pad_h = (sh - h % sh) % sh
pad_w = (sw - w % sw) % sw
if pad_d > 0 or pad_h > 0 or pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
# Skip connection: space-to-depth on input, then group mean
x_in = self._space_to_depth(x)
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
b2, c2, d2, h2, w2 = x_in.shape
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)
# Conv branch: apply conv then space-to-depth
x_conv = self.conv(x, causal=causal)
x_conv = self._space_to_depth(x_conv)
# Add skip connection
return x_conv + x_in
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
stride: Union[int, Tuple[int, int, int]],
residual: bool = False,
out_channels_reduction_factor: int = 1,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
stride = (stride, stride, stride)
self.stride = stride
self.dims = dims
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
# Calculate output channels
multiplier = stride[0] * stride[1] * stride[2]
out_channels = in_channels // out_channels_reduction_factor
self.out_channels = out_channels
# 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels * multiplier,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def _depth_to_space(self, x: mx.array) -> mx.array:
b, c_packed, d, h, w = x.shape
st, sh, sw = self.stride
c = c_packed // (st * sh * sw)
# (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
x = mx.reshape(x, (b, c, st, sh, sw, d, h, w))
# (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
# (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
x = mx.reshape(x, (b, c, d * st, h * sh, w * sw))
return x
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Compute residual path if enabled
x_residual = None
if self.residual:
# Reshape input: treat channels as spatial factors
# "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
x_residual = self._depth_to_space(x)
# Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
# num_repeat = prod(stride) / out_channels_reduction_factor
num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
# Remove first temporal frame if temporal upsampling
if st > 1:
x_residual = x_residual[:, :, 1:, :, :]
# Use chunked mode for large tensors to reduce peak memory
if chunked_conv and d > 4:
x = self._chunked_conv_depth_to_space(x, causal)
else:
# Apply conv
x = self.conv(x, causal=causal)
# Depth to space rearrangement
x = self._depth_to_space(x)
# Remove first frame for causal temporal upsampling
if st > 1:
x = x[:, :, 1:, :, :]
# Add residual
if self.residual and x_residual is not None:
x = x + x_residual
return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor.
Instead of materializing (B, 4096, D, H, W), we process temporal chunks and
immediately apply depth_to_space.
Args:
x: Input tensor of shape (B, C, D, H, W)
causal: Whether to use causal convolutions
Returns:
Output tensor after conv + depth_to_space
"""
b, c, d, h, w = x.shape
st, sh, sw = self.stride
out_c = self.out_channels
# Output dimensions
out_d = d * st
out_h = h * sh
out_w = w * sw
# Chunk size in temporal dimension (process 4 frames at a time)
chunk_size = 4
kernel_t = 3 # Temporal kernel size
# For causal conv, we need (kernel_t - 1) frames of padding at the start
# For non-causal, we need (kernel_t - 1) // 2 on each side
if causal:
# Pad start with first frame repeated
pad_start = kernel_t - 1
pad_end = 0
else:
pad_start = (kernel_t - 1) // 2
pad_end = (kernel_t - 1) // 2
# Allocate output
outputs = []
# Process in chunks with overlap for conv kernel
t_pos = 0
while t_pos < d:
t_end = min(t_pos + chunk_size, d)
# Calculate input range with padding for kernel
in_start = max(0, t_pos - pad_start)
in_end = min(d, t_end + pad_end)
# Extract chunk
chunk = x[:, :, in_start:in_end, :, :]
# Apply conv to chunk
chunk_conv = self.conv(chunk, causal=causal)
# Apply depth_to_space
chunk_out = self._depth_to_space(chunk_conv)
# Calculate valid output range (excluding padding effects)
# Each input frame produces st output frames
out_start = (t_pos - in_start) * st
out_end = out_start + (t_end - t_pos) * st
# Extract valid portion
chunk_out = chunk_out[:, :, out_start:out_end, :, :]
outputs.append(chunk_out)
# Evaluate to free intermediate memory
mx.eval(outputs[-1])
t_pos = t_end
# Concatenate all chunks
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=2)

View File

@@ -0,0 +1,492 @@
"""VAE Tiling Configuration for decoding large videos.
Implements spatial and temporal tiling with trapezoidal blending masks
to decode large videos without running out of memory.
Default configuration (from PyTorch):
- Spatial: 512px tiles with 64px overlap
- Temporal: 64 frames with 24 frame overlap
"""
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import mlx.core as mx
def compute_trapezoidal_mask_1d(
length: int,
ramp_left: int,
ramp_right: int,
left_starts_from_0: bool = False,
) -> mx.array:
"""Generate a 1D trapezoidal blending mask with linear ramps.
Args:
length: Output length of the mask.
ramp_left: Fade-in length on the left.
ramp_right: Fade-out length on the right.
left_starts_from_0: Whether the ramp starts from 0 or first non-zero value.
Useful for temporal tiles where the first tile is causal.
Returns:
A 1D array of shape (length,) with values in [0, 1].
"""
if length <= 0:
raise ValueError("Mask length must be positive.")
ramp_left = max(0, min(ramp_left, length))
ramp_right = max(0, min(ramp_right, length))
# Start with ones
mask = [1.0] * length
# Apply left ramp (fade in)
if ramp_left > 0:
interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2
# Create fade_in values using linspace logic
fade_in_full = [i / (interval_length - 1) for i in range(interval_length)]
fade_in = fade_in_full[:-1] # Remove last element
if not left_starts_from_0:
fade_in = fade_in[1:] # Remove first element too
for i in range(min(ramp_left, len(fade_in))):
mask[i] *= fade_in[i]
# Apply right ramp (fade out)
if ramp_right > 0:
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
for i in range(ramp_right):
mask[length - ramp_right + i] *= fade_out[i]
return mx.clip(mx.array(mask), 0, 1)
@dataclass(frozen=True)
class SpatialTilingConfig:
"""Configuration for dividing each frame into spatial tiles with optional overlap."""
tile_size_in_pixels: int
tile_overlap_in_pixels: int = 0
def __post_init__(self) -> None:
if self.tile_size_in_pixels < 64:
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
if self.tile_size_in_pixels % 32 != 0:
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
if self.tile_overlap_in_pixels % 32 != 0:
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
)
@dataclass(frozen=True)
class TemporalTilingConfig:
"""Configuration for dividing a video into temporal tiles."""
tile_size_in_frames: int
tile_overlap_in_frames: int = 0
def __post_init__(self) -> None:
if self.tile_size_in_frames < 16:
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
if self.tile_size_in_frames % 8 != 0:
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
if self.tile_overlap_in_frames % 8 != 0:
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
)
@dataclass(frozen=True)
class TilingConfig:
"""Configuration for splitting video into tiles with optional overlap."""
spatial_config: Optional[SpatialTilingConfig] = None
temporal_config: Optional[TemporalTilingConfig] = None
@classmethod
def default(cls) -> "TilingConfig":
"""Default tiling: 512px spatial, 64 frame temporal."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
)
@classmethod
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
"""Spatial tiling only (for short videos with large resolution)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
temporal_config=None,
)
@classmethod
def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig":
"""Temporal tiling only (for long videos with small resolution)."""
return cls(
spatial_config=None,
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
)
@classmethod
def aggressive(cls) -> "TilingConfig":
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
)
@classmethod
def conservative(cls) -> "TilingConfig":
"""Conservative tiling (larger tiles, less memory savings but faster)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
)
@classmethod
def auto(
cls,
height: int,
width: int,
num_frames: int,
spatial_threshold: int = 512,
temporal_threshold: int = 65,
) -> Optional["TilingConfig"]:
"""Automatically determine tiling config based on video dimensions.
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
enough context for CausalConv3d and sufficient overlap for clean blending.
Args:
height: Video height in pixels
width: Video width in pixels
num_frames: Number of video frames
spatial_threshold: Enable spatial tiling if either dimension exceeds this
temporal_threshold: Enable temporal tiling if frames exceed this
Returns:
TilingConfig if tiling is needed, None otherwise
"""
needs_spatial = height > spatial_threshold or width > spatial_threshold
needs_temporal = num_frames > temporal_threshold
if not needs_spatial and not needs_temporal:
return None
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
# Smaller tiles cause quality degradation because CausalConv3d needs
# sufficient temporal context and overlap for clean blending.
spatial_config = None
temporal_config = None
if needs_spatial:
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
if needs_temporal:
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
@dataclass
class DimensionIntervals:
"""Intervals for splitting a single dimension."""
starts: List[int]
ends: List[int]
left_ramps: List[int]
right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
"""Split a spatial dimension into intervals."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
starts = [i * (size - overlap) for i in range(amount)]
ends = [start + size for start in starts]
ends[-1] = dimension_size
left_ramps = [0] + [overlap] * (amount - 1)
right_ramps = [overlap] * (amount - 1) + [0]
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
"""Split a temporal dimension into intervals with causal adjustment."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
# Start with spatial split
intervals = split_in_spatial(size, overlap, dimension_size)
# Adjust for temporal: starts[1:] -= 1, left_ramps[1:] += 1
starts = intervals.starts.copy()
left_ramps = intervals.left_ramps.copy()
for i in range(1, len(starts)):
starts[i] = starts[i] - 1
left_ramps[i] = left_ramps[i] + 1
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
"""Map temporal latent interval to output coordinates and mask."""
start = begin * scale
stop = 1 + (end - 1) * scale
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
return slice(start, stop), mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
"""Map spatial latent interval to output coordinates and mask."""
start = begin * scale
stop = end * scale
left_ramp_scaled = left_ramp * scale
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
return slice(start, stop), mask
def decode_with_tiling(
decoder_fn,
latents: mx.array,
tiling_config: TilingConfig,
spatial_scale: int = 32,
temporal_scale: int = 8,
causal: bool = False,
timestep: Optional[mx.array] = None,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
Args:
decoder_fn: Decoder function to call for each tile.
latents: Input latents of shape (B, C, F, H, W).
tiling_config: Tiling configuration.
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
temporal_scale: Temporal scale factor (8 for LTX VAE).
causal: Whether to use causal convolutions.
timestep: Optional timestep for conditioning.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
start_idx: Starting frame index in the full video.
Returns:
Decoded video.
"""
import gc
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = 1 + (f_latent - 1) * temporal_scale
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
# Get tile size and overlap in latent space
if tiling_config.spatial_config is not None:
s_cfg = tiling_config.spatial_config
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
else:
spatial_tile_size = max(h_latent, w_latent)
spatial_overlap = 0
if tiling_config.temporal_config is not None:
t_cfg = tiling_config.temporal_config
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
else:
temporal_tile_size = f_latent
temporal_overlap = 0
# Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
num_t_tiles = len(temporal_intervals.starts)
num_h_tiles = len(height_intervals.starts)
num_w_tiles = len(width_intervals.starts)
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
# Initialize output and weight accumulator
# Use float32 for accumulation to avoid precision issues
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
mx.eval(output, weights)
tile_idx = 0
for t_idx in range(num_t_tiles):
t_start = temporal_intervals.starts[t_idx]
t_end = temporal_intervals.ends[t_idx]
t_left = temporal_intervals.left_ramps[t_idx]
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
h_end = height_intervals.ends[h_idx]
h_left = height_intervals.left_ramps[h_idx]
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
w_end = width_intervals.ends[w_idx]
w_left = width_intervals.left_ramps[w_idx]
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
mx.eval(tile_output)
# Clear tile_latents reference
del tile_latents
# Get actual decoded dimensions
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
expected_t = out_t_slice.stop - out_t_slice.start
expected_h = out_h_slice.stop - out_h_slice.start
expected_w = out_w_slice.stop - out_w_slice.start
# Handle potential size mismatches (use minimum)
actual_t = min(decoded_t, expected_t)
actual_h = min(decoded_h, expected_h)
actual_w = min(decoded_w, expected_w)
# Build blend mask
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
# Clear full tile_output
del tile_output
# Compute output coordinates
t_out_start = out_t_slice.start
t_out_end = t_out_start + actual_t
h_out_start = out_h_slice.start
h_out_end = h_out_start + actual_h
w_out_start = out_w_slice.start
w_out_end = w_out_start + actual_w
# Use direct slice assignment (MLX supports this)
# Weighted accumulation
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
)
# Force evaluation to free memory
mx.eval(output, weights)
# Clean up tile-specific arrays
del tile_output_slice, weighted_tile, blend_mask
del t_mask_slice, h_mask_slice, w_mask_slice
tile_idx += 1
# Periodic garbage collection and cache clearing
if tile_idx % 4 == 0:
gc.collect()
try:
mx.clear_cache()
except Exception:
pass # May not be available on all platforms
# After completing all spatial tiles for this temporal tile,
# check if any frames are now finalized (no future tiles will contribute)
if on_frames_ready is not None and num_t_tiles > 1:
# Determine the finalized frame boundary
# Frames before the start of the next tile's output region are finalized
if t_idx < num_t_tiles - 1:
# Next tile starts at temporal_intervals.starts[t_idx + 1]
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
if next_tile_start_out > emitted:
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
on_frames_ready(finalized_output, emitted)
decode_with_tiling._emitted_frames = next_tile_start_out
del finalized_output, finalized_weights
gc.collect()
# Normalize by weights
weights = mx.maximum(weights, 1e-8)
output = output / weights
mx.eval(output)
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
on_frames_ready(remaining_output, emitted)
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
del decode_with_tiling._emitted_frames
# Clean up weights
del weights
gc.collect()
# Convert back to original dtype if needed
return output.astype(latents.dtype)

View File

@@ -0,0 +1,597 @@
"""Video VAE Encoder and Decoder for LTX-2."""
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample,
SpaceToDepthDownsample,
)
from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
"""Log variance mode for VAE."""
PER_CHANNEL = "per_channel"
UNIFORM = "uniform"
CONSTANT = "constant"
NONE = "none"
def _make_encoder_block(
block_name: str,
block_config: Dict[str, Any],
in_channels: int,
convolution_dimensions: int,
norm_layer: NormLayerType,
norm_num_groups: int,
spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
"""Create an encoder block.
Args:
block_name: Type of block
block_config: Block configuration
in_channels: Input channels
convolution_dimensions: Number of dimensions
norm_layer: Normalization layer type
norm_num_groups: Number of groups for group norm
spatial_padding_mode: Padding mode
Returns:
Tuple of (block, output_channels)
"""
out_channels = in_channels
if block_name == "res_x":
block = UNetMidBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
num_layers=block_config["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
out_channels = in_channels * block_config.get("multiplier", 2)
block = ResnetBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 1, 1),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(1, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
out_channels = in_channels * block_config.get("multiplier", 2)
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unknown encoder block: {block_name}")
return block, out_channels
def _make_decoder_block(
block_name: str,
block_config: Dict[str, Any],
in_channels: int,
convolution_dimensions: int,
norm_layer: NormLayerType,
timestep_conditioning: bool,
norm_num_groups: int,
spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
"""Create a decoder block."""
out_channels = in_channels
if block_name == "res_x":
block = UNetMidBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
num_layers=block_config["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_config.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
out_channels = in_channels // block_config.get("multiplier", 2)
block = ResnetBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_config.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
out_channels = in_channels // block_config.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 2, 2),
residual=block_config.get("residual", False),
out_channels_reduction_factor=block_config.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unknown decoder block: {block_name}")
return block, out_channels
class VideoEncoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(self, config: "VideoEncoderModelConfig"):
"""Initialize VideoEncoder from config.
Args:
config: VideoEncoderModelConfig with encoder parameters
"""
super().__init__()
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
self.patch_size = config.patch_size
self.norm_layer = config.norm_layer
self.latent_channels = config.out_channels
self.latent_log_var = config.latent_log_var
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
# After patchify, channels increase by patch_size^2
in_channels = config.in_channels * config.patch_size ** 2
feature_channels = config.out_channels
# Initial convolution
self.conv_in = CausalConv3d(
in_channels=in_channels,
out_channels=feature_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {}
for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=config.convolution_dimensions,
norm_layer=config.norm_layer,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks[idx] = block
# Output normalization and convolution
if config.norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif config.norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
# Calculate output convolution channels
conv_out_channels = config.out_channels
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
in_channels=feature_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=encoder_spatial_padding_mode,
)
def __call__(self, sample: mx.array) -> mx.array:
"""Encode video to latent representation.
Args:
sample: Input video of shape (B, C, F, H, W).
F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)
Returns:
Normalized latent means of shape (B, 128, F', H', W')
"""
# Validate frame count
frames_count = sample.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError(
"Invalid number of frames: Encode input must have 1 + 8 * x frames "
f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
)
# Initial patchify
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for i in range(len(self.down_blocks)):
down_block = self.down_blocks[i]
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else:
sample = down_block(sample, causal=True)
# Output processing
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=True)
# Handle log variance modes
if self.latent_log_var == LogVarianceType.UNIFORM:
means = sample[:, :-1, ...]
logvar = sample[:, -1:, ...]
num_channels = means.shape[1]
repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
sample = mx.concatenate([means, repeated_logvar], axis=1)
elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...]
approx_ln_0 = -30
sample = mx.concatenate([
sample,
mx.full_like(sample, approx_ln_0),
], axis=1)
# Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize VAE encoder weights from PyTorch format to MLX format."""
sanitized = {}
if "per_channel_statistics.mean" in weights:
return weights
for key, value in weights.items():
new_key = key
if "position_ids" in key:
continue
# Only process VAE encoder weights
if not key.startswith("vae."):
continue
# Handle per-channel statistics
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue
elif key.startswith("vae.encoder."):
new_key = key.replace("vae.encoder.", "")
else:
continue
# Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
"""Load a pretrained VideoEncoder from a directory with weights and config.
Args:
model_path: Path to directory containing safetensors weights and config.json
Returns:
Loaded VideoEncoder instance
"""
import json
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
config = VideoEncoderModelConfig(**config_dict)
else:
config = VideoEncoderModelConfig()
# Load weights
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
if model_path.is_file():
weights = mx.load(str(model_path))
else:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
else:
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Create model, sanitize and load weights
model = cls(config)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=False)
return model
class VideoDecoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 128,
out_channels: int = 3,
decoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
causal: bool = False,
timestep_conditioning: bool = False,
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
):
"""Initialize VideoDecoder.
Args:
convolution_dimensions: Number of dimensions
in_channels: Input latent channels
out_channels: Output channels (3 for RGB)
decoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
causal: Whether to use causal convolutions
timestep_conditioning: Whether to use timestep conditioning
decoder_spatial_padding_mode: Padding mode
"""
super().__init__()
if decoder_blocks is None:
decoder_blocks = []
self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
# Per-channel statistics for denormalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Noise and timestep parameters
self.decode_noise_scale = 0.025
self.decode_timestep = 0.05
# Compute initial feature channels
feature_channels = in_channels
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
feature_channels = feature_channels * block_config.get("multiplier", 2)
if block_name == "compress_all":
feature_channels = feature_channels * block_config.get("multiplier", 1)
# Initial convolution
self.conv_in = CausalConv3d(
in_channels=in_channels,
out_channels=feature_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
# Build decoder blocks (reversed order)
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_decoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
timestep_conditioning=timestep_conditioning,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks[idx] = block
# Output normalization
if norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(
in_channels=feature_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
def __call__(
self,
sample: mx.array,
timestep: Optional[mx.array] = None,
) -> mx.array:
"""Decode latent to video.
Args:
sample: Latent tensor of shape (B, 128, F', H', W')
timestep: Optional timestep for conditioning
Returns:
Decoded video of shape (B, 3, F, H, W)
"""
batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
# Denormalize latents
sample = self.per_channel_statistics.un_normalize(sample)
# Use default timestep if not provided
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
# Initial convolution
sample = self.conv_in(sample, causal=self.causal)
# Process through decoder blocks
for i in range(len(self.up_blocks)):
up_block = self.up_blocks[i]
if isinstance(up_block, UNetMidBlock3D):
sample = up_block(sample, causal=self.causal)
elif isinstance(up_block, ResnetBlock3D):
sample = up_block(sample, causal=self.causal)
else:
sample = up_block(sample, causal=self.causal)
# Output processing
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
# Unpatchify to restore spatial resolution
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample