Refactor LTX-2 model structure
This commit is contained in:
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
)
|
||||
294
mlx_video/models/ltx_2/video_vae/convolution.py
Normal file
294
mlx_video/models/ltx_2/video_vae/convolution.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PaddingModeType(Enum):
    """Spatial padding strategies accepted by the convolution wrappers in this module."""

    # Pad the spatial border with zeros.
    ZEROS = "zeros"
    # Pad the spatial border by mirroring interior values.
    REFLECT = "reflect"
|
||||
|
||||
|
||||
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
    """Apply reflect padding to the spatial dimensions of a channels-last tensor.

    The border value itself is not repeated: for a row ``[a, b, c, d]`` and a
    padding of 2 the result is ``[c, b, a, b, c, d, c, b]``.

    Args:
        x: Input tensor of shape (B, D, H, W, C) - channels last.
        pad_h: Number of rows added on each side of the height axis (axis 2).
        pad_w: Number of columns added on each side of the width axis (axis 3).

    Returns:
        The padded tensor; ``x`` itself when neither padding is positive.

    Raises:
        ValueError: If a positive padding is not strictly smaller than the
            size of the axis it applies to. Reflection needs ``pad`` interior
            values next to the border; without this check the slices would
            silently come back short and the output would be under-padded.
    """

    def _along(axis: int, window: slice) -> tuple:
        # Index tuple selecting `window` on `axis` and everything elsewhere.
        return tuple(window if i == axis else slice(None) for i in range(x.ndim))

    # Height is padded first so the width reflection also covers the new rows.
    for axis, pad in ((2, pad_h), (3, pad_w)):
        if pad <= 0:
            continue

        size = x.shape[axis]
        if pad >= size:
            raise ValueError(
                f"Reflect padding ({pad}) must be smaller than the size of "
                f"axis {axis} ({size})"
            )

        flip = _along(axis, slice(None, None, -1))
        # Leading side: elements 1..pad, reversed (excludes the border element).
        leading = x[_along(axis, slice(1, pad + 1))][flip]
        # Trailing side: the `pad` elements before the last one, reversed.
        trailing = x[_along(axis, slice(-pad - 1, -1))][flip]
        x = mx.concatenate([leading, x, trailing], axis=axis)

    return x
|
||||
|
||||
|
||||
def make_conv_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, Tuple[int, ...], str] = 0,
    causal: bool = False,
    spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
    """Build the causal convolution wrapper matching a dimensionality.

    Args:
        dims: 2 selects ``CausalConv2d``, 3 selects ``CausalConv3d``.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, forwarded unchanged.
        stride: Stride, forwarded unchanged.
        padding: Padding argument, forwarded unchanged.
        causal: Default causal flag stored on the created module.
        spatial_padding_mode: Spatial padding mode, forwarded unchanged.

    Returns:
        The constructed convolution module.

    Raises:
        ValueError: If ``dims`` is neither 2 nor 3.
    """
    conv_by_dims = {2: CausalConv2d, 3: CausalConv3d}

    conv_cls = conv_by_dims.get(dims)
    if conv_cls is None:
        raise ValueError(f"Unsupported number of dimensions: {dims}")

    return conv_cls(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        causal=causal,
        spatial_padding_mode=spatial_padding_mode,
    )
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
    """3D convolution that pads manually before an unpadded ``nn.Conv3d``.

    Temporal padding replicates edge frames (only the first frame in causal
    mode); spatial padding is zero or reflect padding of ``kernel // 2`` on
    each side. Inputs and outputs of ``__call__`` are channels-first
    (B, C, D, H, W); the wrapped ``self.conv`` runs on channels-last data.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int], str] = 0,
        causal: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Create the convolution.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Int or (T, H, W) kernel size.
            stride: Int or (T, H, W) stride.
            padding: Accepted for signature compatibility; not read here.
                Padding is always derived from ``kernel_size``.
            causal: Default temporal padding mode used by ``__call__``.
            spatial_padding_mode: Zero or reflect padding for H and W.
        """
        super().__init__()

        self.causal = causal
        self.spatial_padding_mode = spatial_padding_mode

        # Normalize kernel_size and stride to (T, H, W) tuples.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)

        self.kernel_size = kernel_size
        self.stride = stride
        self.time_kernel_size = kernel_size[0]

        # Spatial padding only; temporal padding is done by frame replication.
        self.spatial_padding = (kernel_size[1] // 2, kernel_size[2] // 2)

        # The base convolution never pads: all padding is applied manually.
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            bias=True,
        )

    def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
        """Pad and convolve a channels-first tensor.

        Args:
            x: Input of shape (B, C, D, H, W).
            causal: Overrides the constructor's ``causal`` flag when not None.

        Returns:
            Convolved tensor, channels first (B, C', D', H', W').
        """
        use_causal = self.causal if causal is None else causal

        # Temporal padding via frame replication (axis 2 while channels-first).
        if self.time_kernel_size > 1:
            if use_causal:
                # Causal: repeat the first frame kernel_size - 1 times up front.
                head = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
                x = mx.concatenate([head, x], axis=2)
            else:
                # Non-causal: repeat the first frame before and the last after.
                pad_size = (self.time_kernel_size - 1) // 2
                if pad_size > 0:
                    head = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
                    tail = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
                    x = mx.concatenate([head, x, tail], axis=2)

        # Channels last for MLX conv: (B, C, D, H, W) -> (B, D, H, W, C)
        x = mx.transpose(x, (0, 2, 3, 4, 1))

        pad_h, pad_w = self.spatial_padding
        if pad_h > 0 or pad_w > 0:
            if self.spatial_padding_mode == PaddingModeType.REFLECT:
                x = reflect_pad_2d(x, pad_h, pad_w)
            else:
                x = mx.pad(
                    x,
                    [
                        (0, 0),          # Batch
                        (0, 0),          # D (temporal - already padded)
                        (pad_h, pad_h),  # H
                        (pad_w, pad_w),  # W
                        (0, 0),          # C
                    ],
                )

        # Chunked because MLX conv3d fails around 33 frames at 192x192 spatial.
        x = self._chunked_conv3d(x)

        # Back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
        return mx.transpose(x, (0, 4, 1, 2, 3))

    def _chunked_conv3d(self, x: mx.array) -> mx.array:
        """Apply conv3d in temporal chunks to work around MLX bug with large tensors.

        Small inputs go through ``self.conv`` in one call. Larger inputs are
        split along the frame axis into windows that each produce up to 24
        output frames. Consecutive windows overlap by ``kernel_t - stride_t``
        frames, so the concatenated result matches the unchunked one for any
        temporal stride.

        Args:
            x: Input tensor of shape (B, D, H, W, C) in channels-last format,
                already padded.

        Returns:
            Output tensor after conv3d, channels last.
        """
        _, d, h, w, c = x.shape

        max_safe_elements = 30 * 192 * 192 * 128  # ~140M elements per chunk
        if d * h * w * c <= max_safe_elements:
            return self.conv(x)

        kernel_t = self.time_kernel_size
        stride_t = self.stride[0]

        if d < kernel_t:
            # No full temporal window fits; chunking has nothing to split, so
            # hand the input to the convolution unchanged.
            return self.conv(x)

        # Standard unpadded-convolution output length along the frame axis.
        total_out_frames = (d - kernel_t) // stride_t + 1

        elements_per_frame = h * w * c
        max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
        chunk_size = min(max_frames_per_chunk, 24)  # Cap at 24 frames per chunk

        outputs = []
        for out_start in range(0, total_out_frames, chunk_size):
            out_frames = min(chunk_size, total_out_frames - out_start)

            # Output frame i reads input frames [i * stride_t, i * stride_t + kernel_t).
            in_start = out_start * stride_t
            in_end = in_start + (out_frames - 1) * stride_t + kernel_t

            chunk_out = self.conv(x[:, in_start:in_end, :, :, :])
            # Evaluate each chunk before building the next one.
            mx.eval(chunk_out)
            outputs.append(chunk_out)

        if len(outputs) == 1:
            return outputs[0]
        return mx.concatenate(outputs, axis=1)
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
    """2D convolution wrapper that zero-pads manually around an unpadded ``nn.Conv2d``.

    ``causal`` and ``spatial_padding_mode`` are stored for parity with
    ``CausalConv3d`` but are not consulted by ``__call__``: padding here is
    always zero padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int], str] = 0,
        causal: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Initialize CausalConv2d.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Int or (H, W) kernel size.
            stride: Int or (H, W) stride.
            padding: Int, (H, W) pair, ``"same"`` (``(k - 1) // 2`` per axis)
                or ``"valid"`` (no padding).
            causal: Stored only; see class docstring.
            spatial_padding_mode: Stored only; see class docstring.

        Raises:
            ValueError: If ``padding`` is a string other than ``"same"`` or
                ``"valid"``. Such a value used to be stored as-is and only
                failed later, when ``__call__`` tried to unpack it.
        """
        super().__init__()

        self.causal = causal
        self.spatial_padding_mode = spatial_padding_mode

        # Normalize kernel_size and stride to (H, W) tuples.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)

        self.kernel_size = kernel_size
        self.stride = stride

        # Resolve the padding argument to a (pad_h, pad_w) pair.
        if isinstance(padding, str):
            if padding == "same":
                self.padding = (
                    (kernel_size[0] - 1) // 2,
                    (kernel_size[1] - 1) // 2,
                )
            elif padding == "valid":
                self.padding = (0, 0)
            else:
                raise ValueError(
                    f"Unsupported padding mode: {padding!r} (expected 'same' or 'valid')"
                )
        elif isinstance(padding, int):
            self.padding = (padding, padding)
        else:
            self.padding = padding

        # The base convolution never pads: padding is applied in __call__.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            bias=True,
        )

    def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
        """Forward pass.

        Args:
            x: Input of shape (B, C, H, W).
            causal: Accepted for interface parity; not used.

        Returns:
            Convolved tensor, channels first (B, C', H', W').
        """
        # Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
        x = mx.transpose(x, (0, 2, 3, 1))

        pad_h, pad_w = self.padding
        if pad_h != 0 or pad_w != 0:
            x = mx.pad(
                x,
                [
                    (0, 0),          # Batch
                    (pad_h, pad_h),  # H
                    (pad_w, pad_w),  # W
                    (0, 0),          # C
                ],
            )

        x = self.conv(x)

        # Transpose back: (B, H, W, C) -> (B, C, H, W)
        return mx.transpose(x, (0, 3, 1, 2))
|
||||
692
mlx_video/models/ltx_2/video_vae/decoder.py
Normal file
692
mlx_video/models/ltx_2/video_vae/decoder.py
Normal file
@@ -0,0 +1,692 @@
|
||||
"""Video VAE Decoder for LTX-2 with timestep conditioning.
|
||||
|
||||
Architecture (from PyTorch weights):
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- pixel_norm + timestep modulation (last_scale_shift_table)
|
||||
- conv_out: 128 -> 48
|
||||
- unpatchify: 48 -> 3 with patch_size=4
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
|
||||
def get_timestep_embedding(
    timesteps: mx.array,
    embedding_dim: int,
    flip_sin_to_cos: bool = True,
    downscale_freq_shift: float = 0,
    scale: float = 1,
    max_period: int = 10000,
) -> mx.array:
    """Create sinusoidal timestep embeddings.

    Args:
        timesteps: 1D array of timestep values.
        embedding_dim: Width of the returned embedding.
        flip_sin_to_cos: Put the cosine half before the sine half when True.
        downscale_freq_shift: Subtracted from the half width in the frequency
            denominator.
        scale: Multiplier applied to the phase before sin/cos.
        max_period: Base of the geometric frequency progression.

    Returns:
        Array of shape (len(timesteps), embedding_dim); an odd
        ``embedding_dim`` gets one trailing zero column.
    """
    half_dim = embedding_dim // 2

    # Geometric frequencies: exp(-log(max_period) * i / (half_dim - shift)).
    positions = mx.arange(0, half_dim, dtype=mx.float32)
    freqs = mx.exp((-math.log(max_period) * positions) / (half_dim - downscale_freq_shift))

    phase = scale * (timesteps[:, None].astype(mx.float32) * freqs[None, :])
    sin_half = mx.sin(phase)
    cos_half = mx.cos(phase)

    halves = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    emb = mx.concatenate(halves, axis=-1)

    if embedding_dim % 2 == 1:
        # Odd width: append one zero column to reach embedding_dim.
        emb = mx.pad(emb, [(0, 0), (0, 1)])

    return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a timestep projection."""

    def __init__(self, in_channels: int, time_embed_dim: int):
        """Create the layers.

        Args:
            in_channels: Width of the incoming projection.
            time_embed_dim: Width of the hidden and output features.
        """
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
        self.act = nn.SiLU()

    def __call__(self, sample: mx.array) -> mx.array:
        """Return ``linear_2(silu(linear_1(sample)))``."""
        hidden = self.act(self.linear_1(sample))
        return self.linear_2(hidden)
|
||||
|
||||
|
||||
class PixArtAlphaTimestepEmbedder(nn.Module):
    """Timestep embedder: 256-wide sinusoidal projection followed by an MLP."""

    def __init__(self, embedding_dim: int):
        """Create the MLP mapping the 256-wide projection to ``embedding_dim``."""
        super().__init__()
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256,
            time_embed_dim=embedding_dim,
        )

    def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
        """Embed timesteps.

        Args:
            timestep: 1D array of timestep values.
            hidden_dtype: Dtype the sinusoidal projection is cast to before
                the MLP.

        Returns:
            Array of shape (len(timestep), embedding_dim).
        """
        projection = get_timestep_embedding(
            timestep,
            embedding_dim=256,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        return self.timestep_embedder(projection.astype(hidden_dtype))
|
||||
|
||||
|
||||
class ResnetBlock3DSimple(nn.Module):
    """Residual block: two (pixel-norm -> [modulate] -> SiLU -> conv) stages plus a skip.

    Weight keys: conv1.conv, conv2.conv, scale_shift_table
    """

    def __init__(
        self,
        channels: int,
        spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
        timestep_conditioning: bool = False,
    ):
        """Create the block.

        Args:
            channels: Channel count, identical for input and output.
            spatial_padding_mode: Padding mode given to both convolutions.
            timestep_conditioning: When True, allocate the (4, channels)
                scale/shift table used for timestep modulation.
        """
        super().__init__()
        self.timestep_conditioning = timestep_conditioning

        # Each conv sits behind a wrapper so parameters are named conv1.conv.*
        self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
        self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)

        self.act = nn.SiLU()

        # Rows are [shift1, scale1, shift2, scale2].
        if timestep_conditioning:
            self.scale_shift_table = mx.zeros((4, channels))

    def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
        """Create a wrapper object with a 'conv' attribute to match PyTorch naming."""

        class ConvWrapper(nn.Module):
            def __init__(wrapper):
                super().__init__()
                wrapper.conv = CausalConv3d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    spatial_padding_mode=padding_mode,
                )

            def __call__(wrapper, x, causal=False):
                return wrapper.conv(x, causal=causal)

        return ConvWrapper()

    def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
        """Divide by the root-mean-square over the channel axis (axis 1)."""
        mean_square = mx.mean(x ** 2, axis=1, keepdims=True)
        return x / mx.sqrt(mean_square + eps)

    def __call__(
        self,
        x: mx.array,
        causal: bool = False,
        timestep_embed: Optional[mx.array] = None,
    ) -> mx.array:
        """Run the block.

        Args:
            x: Input of shape (B, C, D, H, W).
            causal: Forwarded to both convolutions.
            timestep_embed: Optional embedding reshapeable to
                (B, 4, C, 1, 1, 1); ignored unless the block was built with
                ``timestep_conditioning=True``.

        Returns:
            ``x`` plus the residual branch output.
        """
        skip = x
        modulate = self.timestep_conditioning and timestep_embed is not None

        if modulate:
            # Learned table (1, 4, C, 1, 1, 1) + per-sample embedding (B, 4, C, 1, 1, 1).
            n_channels = self.scale_shift_table.shape[1]
            table = self.scale_shift_table[None, :, :, None, None, None]
            per_sample = timestep_embed.reshape(x.shape[0], 4, n_channels, 1, 1, 1)
            ada_values = table + per_sample

            # Each slice is (B, C, 1, 1, 1).
            shift1 = ada_values[:, 0]
            scale1 = ada_values[:, 1]
            shift2 = ada_values[:, 2]
            scale2 = ada_values[:, 3]

        # Stage 1
        x = self.pixel_norm(x)
        if modulate:
            x = x * (1 + scale1) + shift1
        x = self.conv1(self.act(x), causal=causal)

        # Stage 2
        x = self.pixel_norm(x)
        if modulate:
            x = x * (1 + scale2) + shift2
        x = self.conv2(self.act(x), causal=causal)

        return x + skip
|
||||
|
||||
|
||||
class ResBlockGroup(nn.Module):
    """Sequence of ``ResnetBlock3DSimple`` blocks sharing one timestep embedding.

    PyTorch naming: res_blocks.0, res_blocks.1, etc.
    """

    def __init__(
        self,
        channels: int,
        num_layers: int = 5,
        spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
        timestep_conditioning: bool = False,
    ):
        """Create the group.

        Args:
            channels: Channel count of every block in the group.
            num_layers: Number of residual blocks.
            spatial_padding_mode: Padding mode given to every block.
            timestep_conditioning: When True, create a time embedder of width
                ``4 * channels`` and enable modulation in every block.
        """
        super().__init__()
        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            # One embedder per group; its output feeds every block's 4 rows.
            self.time_embedder = PixArtAlphaTimestepEmbedder(
                embedding_dim=channels * 4
            )

        # Use dict with int keys for MLX to track parameters properly
        res_blocks = {}
        for layer_idx in range(num_layers):
            res_blocks[layer_idx] = ResnetBlock3DSimple(
                channels,
                spatial_padding_mode,
                timestep_conditioning=timestep_conditioning,
            )
        self.res_blocks = res_blocks

    def __call__(
        self,
        x: mx.array,
        causal: bool = False,
        timestep: Optional[mx.array] = None,
    ) -> mx.array:
        """Apply every block in order.

        Args:
            x: Input of shape (B, C, D, H, W).
            causal: Forwarded to every block.
            timestep: Optional timestep values; embedded once and shared when
                timestep conditioning is enabled.

        Returns:
            Output of the last block.
        """
        timestep_embed = None
        if self.timestep_conditioning and timestep is not None:
            embedded = self.time_embedder(timestep.flatten(), hidden_dtype=x.dtype)
            # Reshape to (B, 4*C, 1, 1, 1) for broadcasting
            timestep_embed = embedded.reshape(x.shape[0], -1, 1, 1, 1)

        for block in self.res_blocks.values():
            x = block(x, causal=causal, timestep_embed=timestep_embed)
        return x
|
||||
|
||||
|
||||
class LTX2VideoDecoder(nn.Module):
|
||||
"""LTX-2 Video VAE Decoder with timestep conditioning.
|
||||
|
||||
Architecture:
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Upsampler 1024 -> 512
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Upsampler 512 -> 256
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Upsampler 256 -> 128
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
|
||||
"""
|
||||
|
||||
# Block definitions consumed by __init__, in execution order:
#   ("res", channels, num_layers)
#   ("d2s", in_channels, reduction, stride[, residual])
# stride is a (D, H, W) tuple; __init__ treats a missing fifth "d2s" element
# (residual) as True.
DEFAULT_BLOCKS = [
    ("res", 1024, 5),
    ("d2s", 1024, 2, (2, 2, 2)),
    ("res", 512, 5),
    ("d2s", 512, 2, (2, 2, 2)),
    ("res", 256, 5),
    ("d2s", 256, 2, (2, 2, 2)),
    ("res", 128, 5),
]
|
||||
|
||||
def __init__(
    self,
    in_channels: int = 128,
    out_channels: int = 3,
    patch_size: int = 4,
    num_layers_per_block: int = 5,
    spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
    timestep_conditioning: bool = True,
    decoder_blocks: list = None,
):
    """Build the decoder from a block specification.

    Args:
        in_channels: Latent channel count fed to ``conv_in``.
        out_channels: Image channels produced after unpatchify.
        patch_size: Spatial patch size; ``conv_out`` emits
            ``out_channels * patch_size ** 2`` channels.
        num_layers_per_block: Layer count for "res" entries that omit one.
        spatial_padding_mode: Padding mode for every convolution.
        timestep_conditioning: Enable timestep modulation and noise injection.
        decoder_blocks: Block definitions; ``DEFAULT_BLOCKS`` when falsy.
    """
    super().__init__()

    self.patch_size = patch_size
    self.in_channels = in_channels
    self.timestep_conditioning = timestep_conditioning

    # Decode parameters: fixed defaults, read by __call__.
    self.decode_noise_scale = 0.025  # Set to 0.0 to disable noise
    self.decode_timestep = 0.05

    # Per-channel statistics for denormalization (loaded from weights)
    self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)

    blocks = decoder_blocks or self.DEFAULT_BLOCKS
    first_ch = blocks[0][1]
    last_ch = blocks[-1][1]

    def _optional(block_def, position, default):
        # Trailing block fields are optional; fall back when absent.
        return block_def[position] if len(block_def) > position else default

    # Initial conv: in_channels -> first block channels.
    # Wrapped so the parameters are named conv_in.conv.*
    class ConvInWrapper(nn.Module):
        def __init__(wrapper):
            super().__init__()
            wrapper.conv = CausalConv3d(
                in_channels=in_channels,
                out_channels=first_ch,
                kernel_size=3,
                stride=1,
                padding=1,
                spatial_padding_mode=spatial_padding_mode,
            )

        def __call__(wrapper, x, causal=False):
            return wrapper.conv(x, causal=causal)

    self.conv_in = ConvInWrapper()

    # Build up blocks from config, keyed by position to mirror up_blocks.N naming.
    self.up_blocks = {}
    for idx, block_def in enumerate(blocks):
        kind, ch = block_def[0], block_def[1]
        if kind == "res":
            self.up_blocks[idx] = ResBlockGroup(
                ch,
                _optional(block_def, 2, num_layers_per_block),
                spatial_padding_mode,
                timestep_conditioning,
            )
        elif kind == "d2s":
            self.up_blocks[idx] = DepthToSpaceUpsample(
                dims=3,
                in_channels=ch,
                stride=_optional(block_def, 3, (2, 2, 2)),
                residual=_optional(block_def, 4, True),
                out_channels_reduction_factor=_optional(block_def, 2, 2),
                spatial_padding_mode=spatial_padding_mode,
            )

    final_out_channels = out_channels * patch_size * patch_size

    # Final conv: last block channels -> patchified image channels.
    class ConvOutWrapper(nn.Module):
        def __init__(wrapper):
            super().__init__()
            wrapper.conv = CausalConv3d(
                in_channels=last_ch,
                out_channels=final_out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                spatial_padding_mode=spatial_padding_mode,
            )

        def __call__(wrapper, x, causal=False):
            return wrapper.conv(x, causal=causal)

    self.conv_out = ConvOutWrapper()

    self.act = nn.SiLU()

    if timestep_conditioning:
        self.timestep_scale_multiplier = mx.array(1000.0)
        # Final modulation uses two rows: [shift, scale].
        self.last_time_embedder = PixArtAlphaTimestepEmbedder(
            embedding_dim=last_ch * 2
        )
        self.last_scale_shift_table = mx.zeros((2, last_ch))
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Remap checkpoint keys and layouts to this module's parameter tree.

    Weights already containing ``per_channel_statistics.mean`` are returned
    untouched. Otherwise only keys starting with ``vae.`` (except
    ``vae.encoder.``) are kept:

    - the two ``vae.per_channel_statistics`` entries handled below are
      renamed; every other statistics key is dropped;
    - the ``vae.decoder.`` prefix is removed;
    - 5D ``.conv.weight`` tensors go from (O, I, D, H, W) to (O, D, H, W, I);
    - ``.conv.weight`` / ``.conv.bias`` gain an extra ``.conv`` level unless
      the key already has one.

    Args:
        weights: Mapping from checkpoint key to array.

    Returns:
        The remapped mapping, or ``weights`` itself when already converted.
    """
    if "per_channel_statistics.mean" in weights:
        return weights

    stats_prefix = "vae.per_channel_statistics."
    stats_renames = {
        "vae.per_channel_statistics.mean-of-means": "per_channel_statistics.mean",
        "vae.per_channel_statistics.std-of-means": "per_channel_statistics.std",
    }

    remapped = {}
    for key, value in weights.items():
        if not key.startswith("vae.") or key.startswith("vae.encoder."):
            continue

        target = key
        if key.startswith(stats_prefix):
            if key not in stats_renames:
                continue  # Skip other statistics keys
            target = stats_renames[key]

        if key.startswith("vae.decoder."):
            target = key.replace("vae.decoder.", "")

        # Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I).
        # Biases are left as they are.
        if ".conv.weight" in key and value.ndim == 5:
            value = mx.transpose(value, (0, 2, 3, 4, 1))

        # Insert the extra ".conv" level unless it is already there; the
        # replaces are no-ops for keys without ".conv.weight" / ".conv.bias".
        if ".conv.conv.weight" not in target and ".conv.conv.bias" not in target:
            target = target.replace(".conv.weight", ".conv.conv.weight")
            target = target.replace(".conv.bias", ".conv.conv.bias")

        remapped[target] = value

    return remapped
|
||||
|
||||
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
    """Load a pretrained decoder from a directory with config.json and weights.

    Args:
        model_path: Path to directory containing config.json and safetensors files,
            or path to a single safetensors file. For a single file,
            config.json is looked up next to it.
        strict: Whether to require all weight keys to match.

    Returns:
        Loaded LTX2VideoDecoder instance

    Raises:
        FileNotFoundError: If ``model_path`` is a directory without any
            ``*.safetensors`` file.
    """
    import json

    model_path = Path(model_path)

    # The single-file form is documented above; without this branch a file
    # path fell through to the directory glob and raised FileNotFoundError.
    if model_path.is_file():
        weight_files = [model_path]
        config_path = model_path.parent / "config.json"
    else:
        weight_files = sorted(model_path.glob("*.safetensors"))
        config_path = model_path / "config.json"

    if not weight_files:
        raise FileNotFoundError(f"No safetensors files found in {model_path}")

    # The config is optional; missing keys fall back to the defaults below.
    config_dict = {}
    if config_path.exists():
        with open(config_path) as f:
            config_dict = json.load(f)

    # Later files override earlier ones on duplicate keys.
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(str(wf)))

    # Infer block structure from weights
    decoder_blocks = cls._infer_blocks(weights)

    # Determine spatial padding mode from config
    spatial_padding_mode = PaddingModeType(
        config_dict.get("spatial_padding_mode", "reflect")
    )

    # NOTE(review): timestep_conditioning falls back to False here, while the
    # constructor default is True - confirm this asymmetry is intended.
    model = cls(
        timestep_conditioning=config_dict.get("timestep_conditioning", False),
        decoder_blocks=decoder_blocks,
        spatial_padding_mode=spatial_padding_mode,
    )
    weights = model.sanitize(weights)
    model.load_weights(list(weights.items()), strict=strict)
    return model
|
||||
|
||||
@staticmethod
def _infer_blocks(weights: dict) -> list:
    """Infer decoder block structure from weight keys.

    Every ``up_blocks.N`` index found in the keys is classified as a "res"
    group (it has ``res_blocks.M`` children) or a "d2s" upsampler (it has a
    direct ``conv`` child and no ``res_blocks``). Channel counts are read
    from the weight shapes, and each d2s stride is derived from the ratio
    between its conv output channels and the next res group's channels.

    Args:
        weights: Mapping from parameter key to an object exposing ``shape``
            and ``ndim``.

    Returns:
        A list of ``("res", channels, num_layers)`` and
        ``("d2s", in_channels, reduction, stride, residual)`` tuples, or
        None when no block could be inferred.
    """
    # Collect every numeric index that follows "up_blocks.".
    up_indices = set()
    for key in weights:
        if "up_blocks." not in key:
            continue
        token = key.split("up_blocks.")[1].split(".")[0]
        if token.isdigit():
            up_indices.add(int(token))

    if not up_indices:
        return None

    def _first_weight(marker):
        # First value, in mapping order, whose key holds marker and "weight".
        for name, tensor in weights.items():
            if marker in name and "weight" in name:
                return tensor
        return None

    # First pass: classify each index and read its channel counts.
    raw_blocks = []
    for idx in sorted(up_indices):
        res_prefix = f"up_blocks.{idx}.res_blocks."
        conv_prefix = f"up_blocks.{idx}.conv."

        res_ids = set()
        for key in weights:
            if res_prefix in key:
                token = key.split(res_prefix)[1].split(".")[0]
                if token.isdigit():
                    res_ids.add(int(token))

        if res_ids:
            tensor = _first_weight(f"up_blocks.{idx}.res_blocks.0.conv1")
            if tensor is not None:
                raw_blocks.append(("res", tensor.shape[0], max(res_ids) + 1))
        elif any(conv_prefix in key for key in weights):
            # D2S block - get conv shape
            tensor = _first_weight(conv_prefix)
            if tensor is not None:
                in_ch = tensor.shape[-1] if tensor.ndim == 5 else tensor.shape[1]
                raw_blocks.append(("d2s", in_ch, tensor.shape[0]))

    # Second pass: the res group after a d2s block gives its output channels,
    # which fixes the reduction factor and the stride.
    stride_for_multiplier = {8: (2, 2, 2), 4: (1, 2, 2), 2: (2, 1, 1)}
    blocks = []
    d2s_strides = []
    for pos, entry in enumerate(raw_blocks):
        if entry[0] == "res":
            blocks.append(entry)
            continue

        _, in_ch, conv_out_ch = entry
        next_ch = next(
            (later[1] for later in raw_blocks[pos + 1:] if later[0] == "res"),
            None,
        )
        if next_ch is None:
            next_ch = in_ch // 2  # fallback

        if next_ch > 0:
            reduction = in_ch // next_ch        # out_ch = in_ch // reduction
            multiplier = conv_out_ch // next_ch  # conv_out = next_ch * multiplier
        else:
            reduction, multiplier = 2, 8

        # Unknown multipliers use the (2, 2, 2) stride.
        stride = stride_for_multiplier.get(multiplier, (2, 2, 2))
        d2s_strides.append(stride)
        blocks.append(("d2s", in_ch, reduction, stride))

    if not blocks:
        return None

    # Residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 -> True.
    # LTX-2.3 has mixed strides or reduction=1 -> False.
    uniform_strides = len(set(d2s_strides)) <= 1
    standard_reduction = all(b[2] == 2 for b in blocks if b[0] == "d2s")
    use_residual = uniform_strides and standard_reduction

    # Every d2s entry gets the same residual flag appended.
    return [b + (use_residual,) if b[0] == "d2s" else b for b in blocks]
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
    """Divide ``x`` by its root-mean-square over the channel axis (axis 1).

    Args:
        x: Input tensor with channels on axis 1.
        eps: Added to the mean square before the square root.

    Returns:
        The normalized tensor, same shape as ``x``.
    """
    mean_square = mx.mean(x ** 2, axis=1, keepdims=True)
    return x / mx.sqrt(mean_square + eps)
|
||||
|
||||
def __call__(
    self,
    sample: mx.array,
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    debug: bool = False,
    chunked_conv: bool = False,
) -> mx.array:
    """Decode latents to video.

    Args:
        sample: Latents of shape (B, C, F, H, W).
        causal: Forwarded to every convolution.
        timestep: Optional per-sample timestep; defaults to
            ``self.decode_timestep`` when timestep conditioning is enabled.
        debug: Accepted for interface compatibility; not used here.
        chunked_conv: Forwarded to the depth-to-space upsamplers.

    Returns:
        The output of ``unpatchify`` applied to the ``conv_out`` result.
    """
    batch_size = sample.shape[0]

    # Add noise if timestep conditioning is enabled
    if self.timestep_conditioning:
        noise = mx.random.normal(sample.shape) * self.decode_noise_scale
        sample = noise + (1.0 - self.decode_noise_scale) * sample

    sample = self.per_channel_statistics.un_normalize(sample)

    if timestep is None and self.timestep_conditioning:
        timestep = mx.full((batch_size,), self.decode_timestep)

    scaled_timestep = None
    if self.timestep_conditioning and timestep is not None:
        scaled_timestep = timestep * self.timestep_scale_multiplier

    x = self.conv_in(sample, causal=causal)

    for block in self.up_blocks.values():
        if isinstance(block, ResBlockGroup):
            x = block(x, causal=causal, timestep=scaled_timestep)
        elif isinstance(block, DepthToSpaceUpsample):
            x = block(x, causal=causal, chunked_conv=chunked_conv)
        else:
            x = block(x, causal=causal)

    x = self.pixel_norm(x)

    if self.timestep_conditioning and scaled_timestep is not None:
        embedded_timestep = self.last_time_embedder(
            scaled_timestep.flatten(),
            hidden_dtype=x.dtype,
        )

        # Read the channel count from the table instead of assuming 128, so
        # block configurations whose last stage has another width work too.
        last_channels = self.last_scale_shift_table.shape[1]
        ada_values = self.last_scale_shift_table[None, :, :, None, None, None]  # (1, 2, C, 1, 1, 1)
        ada_values = ada_values + embedded_timestep.reshape(
            batch_size, 2, last_channels, 1, 1, 1
        )

        shift = ada_values[:, 0]  # (B, C, 1, 1, 1)
        scale = ada_values[:, 1]

        x = x * (1 + scale) + shift

    x = self.act(x)
    x = self.conv_out(x, causal=causal)

    # Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
    x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)

    return x
|
||||
|
||||
def decode_tiled(
    self,
    sample: mx.array,
    tiling_config: Optional[TilingConfig] = None,
    tiling_mode: str = "auto",
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    debug: bool = False,
    on_frames_ready: Optional[callable] = None,
) -> mx.array:
    """Decode latents, splitting the work into tiles when they are too large.

    The latent extents are compared against the tile sizes of
    ``tiling_config`` (converted from pixels/frames to latent units). If
    neither the spatial nor the temporal limit is exceeded, the decoder is
    simply called once on the whole input; otherwise the work is delegated
    to ``decode_with_tiling``.

    Args:
        sample: Input latents with five dimensions, unpacked here as
            ``(B, C, F, H, W)``.
        tiling_config: Tiling configuration. ``TilingConfig.default()`` is
            used when this is ``None``.
        tiling_mode: Name of the tiling preset in use. In this method it
            only controls whether chunked convolution is requested: it is
            enabled for ``"conservative"``, ``"none"``, ``"auto"``,
            ``"default"`` and ``"spatial"``.
        causal: Whether to use causal convolutions (forwarded unchanged).
        timestep: Optional timestep for conditioning (forwarded unchanged).
        debug: Debug flag. It is forwarded only on the untiled path.
        on_frames_ready: Optional callback forwarded to
            ``decode_with_tiling``. It is NOT passed along on the untiled
            path, so this method never hands it to the plain decoder call.

    Returns:
        The decoded video as returned by ``self(...)`` or
        ``decode_with_tiling(...)``.
    """
    # Latent -> output scale factors. Spatial: 8x VAE upsampling combined
    # with 4x unpatchify gives 32x. Temporal: 8x.
    spatial_scale = 32
    temporal_scale = 8

    config = TilingConfig.default() if tiling_config is None else tiling_config

    _, _, latent_frames, latent_height, latent_width = sample.shape

    # Spatial tiling is needed when either side exceeds one tile (in latent units).
    split_spatially = False
    spatial_cfg = config.spatial_config
    if spatial_cfg is not None:
        max_side = spatial_cfg.tile_size_in_pixels // spatial_scale
        split_spatially = latent_height > max_side or latent_width > max_side

    # Temporal tiling is needed when the clip is longer than one tile.
    split_temporally = False
    temporal_cfg = config.temporal_config
    if temporal_cfg is not None:
        max_frames = temporal_cfg.tile_size_in_frames // temporal_scale
        split_temporally = latent_frames > max_frames

    # Chunked conv processes conv + depth_to_space in temporal chunks to
    # reduce memory; it is auto-enabled for the modes that use larger tiles.
    use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")

    if split_spatially or split_temporally:
        return decode_with_tiling(
            decoder_fn=self,
            latents=sample,
            tiling_config=config,
            spatial_scale=spatial_scale,
            temporal_scale=temporal_scale,
            causal=causal,
            timestep=timestep,
            chunked_conv=use_chunked_conv,
            on_frames_ready=on_frames_ready,
        )

    # Everything fits in a single tile: run the regular decode.
    return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
|
||||
|
||||
# Backward-compatible alias
|
||||
VideoDecoder = LTX2VideoDecoder  # second name bound to the same class, kept for backward compatibility
|
||||
44
mlx_video/models/ltx_2/video_vae/encoder.py
Normal file
44
mlx_video/models/ltx_2/video_vae/encoder.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Video VAE Encoder for LTX-2 Image-to-Video.
|
||||
|
||||
The encoder compresses input images/videos to latent representations.
|
||||
Used for I2V (image-to-video) conditioning by encoding the input image
|
||||
to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
|
||||
|
||||
def encode_image(
    image: mx.array,
    encoder: VideoEncoder,
) -> mx.array:
    """Run a still image through the VAE encoder as a one-frame video.

    Args:
        image: Channels-last image tensor, either ``(H, W, 3)`` or
            ``(B, H, W, 3)``. Values are rescaled by 1/255 first when the
            maximum exceeds 1.0; otherwise they are taken to be in [0, 1].
        encoder: Loaded VAE encoder; called with a ``(B, 3, 1, H, W)`` tensor
            whose values have been mapped to [-1, 1].

    Returns:
        Whatever ``encoder`` returns for that input (the latent tensor;
        presumably ``(B, 128, 1, H//32, W//32)`` -- confirm against the
        encoder).
    """
    # A 3-D input is a single image: give it a leading batch axis.
    batched = mx.expand_dims(image, axis=0) if image.ndim == 3 else image

    # Channels-last -> channels-first: (B, H, W, C) -> (B, C, H, W).
    pixels = mx.transpose(batched, (0, 3, 1, 2))

    # Bring 0..255 data down to 0..1, then map 0..1 onto -1..1.
    if pixels.max() > 1.0:
        pixels = pixels / 255.0
    pixels = pixels * 2.0 - 1.0

    # Insert a singleton temporal axis: (B, C, H, W) -> (B, C, 1, H, W).
    single_frame_clip = mx.expand_dims(pixels, axis=2)

    return encoder(single_frame_clip)
|
||||
125
mlx_video/models/ltx_2/video_vae/ops.py
Normal file
125
mlx_video/models/ltx_2/video_vae/ops.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Operations for Video VAE."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
    """Convert video to patches.

    Moves pixels from the F, H and W dimensions into the channel dimension.

    Args:
        x: Input tensor of shape (B, C, F, H, W)
        patch_size_hw: Spatial patch size (applied to both H and W)
        patch_size_t: Temporal patch size

    Returns:
        Patched tensor of shape
        (B, C * patch_size_t * patch_size_hw^2, F/patch_size_t,
        H/patch_size_hw, W/patch_size_hw)

    Raises:
        ValueError: If H or W is not divisible by ``patch_size_hw`` or F is
            not divisible by ``patch_size_t``.
    """
    b, c, f, h, w = x.shape

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if h % patch_size_hw != 0 or w % patch_size_hw != 0:
        raise ValueError(
            f"Spatial size ({h}, {w}) must be divisible by patch_size_hw={patch_size_hw}"
        )
    if f % patch_size_t != 0:
        raise ValueError(
            f"Frame count {f} must be divisible by patch_size_t={patch_size_t}"
        )

    # New dimensions
    new_h = h // patch_size_hw
    new_w = w // patch_size_hw
    new_f = f // patch_size_t
    new_c = c * patch_size_hw * patch_size_hw * patch_size_t

    # Reshape: (B, C, F, H, W) -> (B, C, F', pt, H', ph, W', pw)
    x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))

    # Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
    # PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so pw goes before ph
    x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6))

    # Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W')
    x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))

    return x
|
||||
|
||||
|
||||
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
    """Unfold packed channels back into the frame/height/width axes.

    Inverse of patchify. Matches the PyTorch einops pattern
    "b (c p r q) f h w -> b c (f p) (h q) (w r)"
    with p=patch_size_t, r=patch_size_hw (width) and q=patch_size_hw (height).

    Args:
        x: Patched tensor of shape (B, C * patch_size_t * patch_size_hw^2, F, H, W)
        patch_size_hw: Spatial patch size
        patch_size_t: Temporal patch size

    Returns:
        Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
    """
    batch, packed_channels, frames, height, width = x.shape
    pt, ps = patch_size_t, patch_size_hw

    # Channels left once the patch factors are divided out.
    channels = packed_channels // (ps * ps * pt)

    # Split the channel axis following the PyTorch layout
    # (c, p, r, q) = (c, temporal, width_patch, height_patch):
    # (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
    unpacked = mx.reshape(x, (batch, channels, pt, ps, ps, frames, height, width))

    # Put every patch axis right after the axis it expands:
    # (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
    interleaved = mx.transpose(unpacked, (0, 1, 5, 2, 6, 4, 7, 3))

    # Merge each (axis, patch) pair: -> (B, C, F*pt, H*pq, W*pr)
    return mx.reshape(interleaved, (batch, channels, frames * pt, height * ps, width * ps))
|
||||
|
||||
|
||||
class PerChannelStatistics(nn.Module):
    """Per-channel mean/std used to normalize and denormalize latents."""

    def __init__(self, latent_channels: int = 128):
        """Create identity statistics (mean 0, std 1) for ``latent_channels`` channels."""
        super().__init__()
        self.latent_channels = latent_channels

        # Per-channel mean and std, stored as array attributes. They start as
        # the identity transform; presumably replaced when weights are loaded.
        self.mean = mx.zeros((latent_channels,))
        self.std = mx.ones((latent_channels,))

    def _broadcast_stats(self, x: mx.array):
        """Return float32 (mean, std) shaped (1, C, 1, ..., 1) to broadcast against ``x``."""
        # One singleton axis per trailing dimension of x, so any (B, C, ...)
        # input works; for 5-D input this is exactly (1, -1, 1, 1, 1).
        shape = (1, -1) + (1,) * (x.ndim - 2)
        # Cast to float32 for precision
        mean = mx.reshape(self.mean.astype(mx.float32), shape)
        std = mx.reshape(self.std.astype(mx.float32), shape)
        return mean, std

    def normalize(self, x: mx.array) -> mx.array:
        """Normalize latents using per-channel statistics.

        Args:
            x: Input tensor of shape (B, C, ...)

        Returns:
            ``(x - mean) / std`` cast back to the dtype of ``x``
        """
        dtype = x.dtype
        mean, std = self._broadcast_stats(x)
        return ((x - mean) / std).astype(dtype)

    def un_normalize(self, x: mx.array) -> mx.array:
        """Denormalize latents using per-channel statistics.

        Args:
            x: Normalized tensor of shape (B, C, ...)

        Returns:
            ``x * std + mean`` cast back to the dtype of ``x``
        """
        dtype = x.dtype
        mean, std = self._broadcast_stats(x)
        return (x * std + mean).astype(dtype)
|
||||
172
mlx_video/models/ltx_2/video_vae/resnet.py
Normal file
172
mlx_video/models/ltx_2/video_vae/resnet.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""ResNet blocks for Video VAE."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class NormLayerType(Enum):
    """Normalization layers selectable via ``get_norm_layer``."""

    # Built as nn.GroupNorm.
    GROUP_NORM = "group_norm"
    # Built as PixelNorm.
    PIXEL_NORM = "pixel_norm"
|
||||
|
||||
|
||||
def get_norm_layer(
    norm_type: NormLayerType,
    num_channels: int,
    num_groups: int = 32,
    eps: float = 1e-6,
) -> nn.Module:
    """Build the normalization module selected by ``norm_type``.

    Args:
        norm_type: Which normalization to create.
        num_channels: Channel count; passed to ``nn.GroupNorm`` as ``dims``
            and not used for pixel norm.
        num_groups: Group count for ``nn.GroupNorm``; not used for pixel norm.
        eps: Epsilon passed to the created layer.

    Returns:
        An ``nn.GroupNorm`` or ``PixelNorm`` instance.

    Raises:
        ValueError: If ``norm_type`` is neither of the two known members.
    """
    if norm_type == NormLayerType.PIXEL_NORM:
        return PixelNorm(eps=eps)
    if norm_type == NormLayerType.GROUP_NORM:
        return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
    raise ValueError(f"Unknown norm type: {norm_type}")
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
    """Residual block: two (norm -> SiLU -> 3x3x3 causal conv) stages plus a skip path."""

    def __init__(
        self,
        dims: int,
        in_channels: int,
        out_channels: Optional[int] = None,
        eps: float = 1e-6,
        groups: int = 32,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Build the block.

        Args:
            dims: Accepted for interface compatibility; not used by this class.
            in_channels: Number of input channels.
            out_channels: Number of output channels; falls back to
                ``in_channels`` when ``None`` (or otherwise falsy).
            eps: Epsilon for the normalization layers.
            groups: Group count, used only for group norm.
            norm_layer: Which normalization layer to use.
            inject_noise: Enables noise injection after the first conv
                (only applied when a ``generator`` is passed at call time).
            timestep_conditioning: Accepted for interface compatibility; not
                used by this class.
            spatial_padding_mode: Spatial padding mode for all convolutions.
        """
        super().__init__()

        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        # First normalization and convolution
        self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
        self.conv1 = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            spatial_padding_mode=spatial_padding_mode,
        )

        # Second normalization and convolution
        self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
        self.conv2 = CausalConv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            spatial_padding_mode=spatial_padding_mode,
        )

        # A 1x1x1 conv on the skip path is only needed when the channel count changes.
        if in_channels != out_channels:
            self.shortcut = CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                spatial_padding_mode=spatial_padding_mode,
            )
        else:
            self.shortcut = None

        # Activation
        self.act = nn.SiLU()

    def __call__(
        self,
        x: mx.array,
        causal: bool = True,
        generator: Optional[int] = None,
    ) -> mx.array:
        """Apply the residual block.

        Args:
            x: Input tensor.
            causal: Forwarded to every convolution.
            generator: When not ``None`` (and ``inject_noise`` is set), noise
                is added after the first convolution. The value itself is not
                used as a seed.

        Returns:
            ``conv path + skip path``.
        """
        residual = x

        # First block
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv1(x, causal=causal)

        # Inject noise if enabled
        if self.inject_noise and generator is not None:
            # mx.random.normal yields float32 by default; cast so that
            # half-precision activations are not silently promoted.
            noise = mx.random.normal(x.shape).astype(x.dtype)
            x = x + noise * 0.01

        # Second block
        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x, causal=causal)

        # Shortcut
        if self.shortcut is not None:
            residual = self.shortcut(residual, causal=causal)

        return x + residual
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
    """A stack of ``num_layers`` channel-preserving ``ResnetBlock3D`` blocks."""

    def __init__(
        self,
        dims: int,
        in_channels: int,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        attention_head_dim: Optional[int] = None,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Create the stack.

        Every argument except ``num_layers`` and ``attention_head_dim`` is
        passed on to each ``ResnetBlock3D``; ``attention_head_dim`` is
        accepted but not used here.
        """
        super().__init__()

        self.num_layers = num_layers

        # Stored in a dict keyed by layer index so MLX tracks the parameters;
        # the attribute is named res_blocks to match the PyTorch weight keys.
        blocks = {}
        for layer_idx in range(num_layers):
            blocks[layer_idx] = ResnetBlock3D(
                dims=dims,
                in_channels=in_channels,
                out_channels=in_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                norm_layer=norm_layer,
                inject_noise=inject_noise,
                timestep_conditioning=timestep_conditioning,
                spatial_padding_mode=spatial_padding_mode,
            )
        self.res_blocks = blocks

    def __call__(
        self,
        x: mx.array,
        causal: bool = True,
        timestep: Optional[mx.array] = None,
        generator: Optional[int] = None,
    ) -> mx.array:
        """Run ``x`` through every block in order.

        ``timestep`` is accepted but not forwarded to the blocks.
        """
        out = x
        for block in self.res_blocks.values():
            out = block(out, causal=causal, generator=generator)
        return out
|
||||
275
mlx_video/models/ltx_2/video_vae/sampling.py
Normal file
275
mlx_video/models/ltx_2/video_vae/sampling.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Sampling operations for Video VAE (upsampling/downsampling)."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
    """Space-to-depth downsampling with 3x3 conv and skip connection.

    PyTorch-compatible implementation:
    1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
    2. Space-to-depth on conv output: channels * prod(stride)
    3. Space-to-depth on input with group averaging for skip connection
    4. Add skip connection
    """

    def __init__(
        self,
        dims: int,
        in_channels: int,
        out_channels: int,
        stride: Union[int, Tuple[int, int, int]],
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Set up the convolution and the channel bookkeeping.

        Args:
            dims: Stored on the instance; not otherwise used in this class.
            in_channels: Channels of the input tensor.
            out_channels: Channels of the output tensor.
            stride: Downsampling factor per (time, height, width); an int is
                applied to all three axes.
            spatial_padding_mode: Spatial padding mode for the convolution.
        """
        super().__init__()

        self.stride = (stride, stride, stride) if isinstance(stride, int) else stride
        self.dims = dims
        self.out_channels = out_channels

        # Space-to-depth multiplies the channel count by prod(stride).
        factor = self.stride[0] * self.stride[1] * self.stride[2]
        # Number of rearranged input channels averaged into one output channel.
        self.group_size = in_channels * factor // out_channels

        # 3x3 convolution (not 1x1); its output is expanded by `factor` afterwards.
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels // factor,
            kernel_size=3,
            stride=1,
            padding=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def _space_to_depth(self, x: mx.array) -> mx.array:
        """Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
        batch, channels, depth, height, width = x.shape
        st, sh, sw = self.stride
        out_d, out_h, out_w = depth // st, height // sh, width // sw

        # Split every downsampled axis into (coarse, stride) pairs.
        split = mx.reshape(x, (batch, channels, out_d, st, out_h, sh, out_w, sw))

        # (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
        grouped = mx.transpose(split, (0, 1, 3, 5, 7, 2, 4, 6))

        # Fold the stride axes into the channel axis.
        return mx.reshape(grouped, (batch, channels * st * sh * sw, out_d, out_h, out_w))

    def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
        """Downsample ``x`` of shape (B, C, D, H, W).

        The first frame is duplicated whenever the temporal stride is 2; this
        happens independently of ``causal``, which is only forwarded to the conv.
        """
        _, _, depth, height, width = x.shape
        st, sh, sw = self.stride

        if st == 2:
            # Duplicate first frame for padding
            x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
            depth += 1

        # Zero-pad the trailing side of each axis up to a multiple of its stride.
        extra_d = (st - depth % st) % st
        extra_h = (sh - height % sh) % sh
        extra_w = (sw - width % sw) % sw
        if extra_d > 0 or extra_h > 0 or extra_w > 0:
            x = mx.pad(x, [(0, 0), (0, 0), (0, extra_d), (0, extra_h), (0, extra_w)])

        # Skip path: rearrange the raw input, then average groups of channels:
        # (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w) -> mean over group
        skip = self._space_to_depth(x)
        batch, _, d2, h2, w2 = skip.shape
        skip = mx.reshape(skip, (batch, self.out_channels, self.group_size, d2, h2, w2))
        skip = mx.mean(skip, axis=2)  # (b, out_channels, d, h, w)

        # Conv path: convolve at full resolution, then rearrange.
        main = self._space_to_depth(self.conv(x, causal=causal))

        return main + skip
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
    """Upsampling via a 3x3x3 conv followed by a depth-to-space rearrangement.

    Optionally adds a residual path that rearranges the input directly and
    tiles its channels to match the conv output.
    """

    def __init__(
        self,
        dims: int,
        in_channels: int,
        stride: Union[int, Tuple[int, int, int]],
        residual: bool = False,
        out_channels_reduction_factor: int = 1,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Set up the convolution.

        Args:
            dims: Stored on the instance; not otherwise used in this class.
            in_channels: Channels of the input tensor.
            stride: Upsampling factor per (time, height, width); an int is
                applied to all three axes.
            residual: Whether to add the rearranged input as a residual.
            out_channels_reduction_factor: Output channels are
                ``in_channels // out_channels_reduction_factor``.
            spatial_padding_mode: Spatial padding mode for the convolution.
        """
        super().__init__()

        self.stride = (stride, stride, stride) if isinstance(stride, int) else stride
        self.dims = dims
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

        # Channels after depth-to-space.
        self.out_channels = in_channels // out_channels_reduction_factor
        factor = self.stride[0] * self.stride[1] * self.stride[2]

        # 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=self.out_channels * factor,
            kernel_size=3,
            stride=1,
            padding=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def _depth_to_space(self, x: mx.array) -> mx.array:
        """Rearrange: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"""
        batch, packed, depth, height, width = x.shape
        st, sh, sw = self.stride
        channels = packed // (st * sh * sw)

        # (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
        split = mx.reshape(x, (batch, channels, st, sh, sw, depth, height, width))

        # (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
        interleaved = mx.transpose(split, (0, 1, 5, 2, 6, 3, 7, 4))

        # (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
        return mx.reshape(interleaved, (batch, channels, depth * st, height * sh, width * sw))

    def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
        """Upsample ``x`` of shape (B, C, D, H, W).

        Args:
            x: Input tensor.
            causal: Forwarded to the convolution.
            chunked_conv: Use the temporally chunked conv path when the input
                has more than 4 frames.

        Returns:
            The upsampled tensor; with temporal upsampling (stride[0] > 1) the
            first output frame is dropped.
        """
        _, _, depth, _, _ = x.shape
        st, sh, sw = self.stride
        drop_first_frame = st > 1

        skip = None
        if self.residual:
            # Treat channels as spatial factors:
            # "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
            skip = self._depth_to_space(x)

            # Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
            # repeats = prod(stride) / out_channels_reduction_factor
            repeats = (st * sh * sw) // self.out_channels_reduction_factor
            skip = mx.tile(skip, (1, repeats, 1, 1, 1))

            # Remove first temporal frame if temporal upsampling
            if drop_first_frame:
                skip = skip[:, :, 1:, :, :]

        if chunked_conv and depth > 4:
            # Chunked mode keeps the peak memory of large tensors down.
            x = self._chunked_conv_depth_to_space(x, causal)
        else:
            x = self._depth_to_space(self.conv(x, causal=causal))

        # Remove first frame for causal temporal upsampling
        if drop_first_frame:
            x = x[:, :, 1:, :, :]

        if self.residual and skip is not None:
            x = x + skip

        return x

    def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
        """Chunked conv + depth_to_space that processes in temporal chunks.

        This reduces peak memory by avoiding the full high-channel intermediate
        tensor: each temporal chunk is convolved, rearranged and evaluated
        before the next one is processed.

        Args:
            x: Input tensor of shape (B, C, D, H, W)
            causal: Whether to use causal convolutions

        Returns:
            Output tensor after conv + depth_to_space
        """
        _, _, depth, _, _ = x.shape
        st, _, _ = self.stride

        frames_per_chunk = 4  # frames processed per step
        kernel_t = 3  # temporal kernel size of self.conv

        # Extra context frames each chunk needs so that its kept outputs see
        # only real neighbours: a causal conv looks back (kernel_t - 1) frames,
        # a non-causal one looks (kernel_t - 1) // 2 frames to either side.
        if causal:
            lead, trail = kernel_t - 1, 0
        else:
            lead = trail = (kernel_t - 1) // 2

        pieces = []
        for first in range(0, depth, frames_per_chunk):
            last = min(first + frames_per_chunk, depth)

            # Input window including the context, clipped to the tensor.
            window_start = max(0, first - lead)
            window_end = min(depth, last + trail)

            window_out = self._depth_to_space(
                self.conv(x[:, :, window_start:window_end, :, :], causal=causal)
            )

            # Keep only the outputs of this chunk's own frames
            # (each input frame produces st output frames).
            keep_from = (first - window_start) * st
            keep_to = keep_from + (last - first) * st
            pieces.append(window_out[:, :, keep_from:keep_to, :, :])

            # Evaluate to free intermediate memory
            mx.eval(pieces[-1])

        if len(pieces) == 1:
            return pieces[0]
        return mx.concatenate(pieces, axis=2)
|
||||
492
mlx_video/models/ltx_2/video_vae/tiling.py
Normal file
492
mlx_video/models/ltx_2/video_vae/tiling.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""VAE Tiling Configuration for decoding large videos.
|
||||
|
||||
Implements spatial and temporal tiling with trapezoidal blending masks
|
||||
to decode large videos without running out of memory.
|
||||
|
||||
Default configuration (from PyTorch):
|
||||
- Spatial: 512px tiles with 64px overlap
|
||||
- Temporal: 64 frames with 24 frame overlap
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def compute_trapezoidal_mask_1d(
    length: int,
    ramp_left: int,
    ramp_right: int,
    left_starts_from_0: bool = False,
) -> mx.array:
    """Build a 1D blending mask: linear fade-in, flat top, linear fade-out.

    Args:
        length: Output length of the mask.
        ramp_left: Fade-in length on the left (clamped to ``[0, length]``).
        ramp_right: Fade-out length on the right (clamped to ``[0, length]``).
        left_starts_from_0: If True the fade-in begins at exactly 0; otherwise
            it begins at the first non-zero step. Useful for temporal tiles
            where the first tile is causal.

    Returns:
        A 1D array of shape (length,) with values in [0, 1].

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    ramp_left = max(0, min(ramp_left, length))
    ramp_right = max(0, min(ramp_right, length))

    weights = [1.0] * length

    if ramp_left > 0:
        # Samples of linspace(0, 1, steps) without the final 1.0, and also
        # without the leading 0.0 unless the ramp is meant to start from zero.
        if left_starts_from_0:
            steps, first = ramp_left + 1, 0
        else:
            steps, first = ramp_left + 2, 1
        fade_in = [k / (steps - 1) for k in range(first, steps - 1)]
        for pos, gain in enumerate(fade_in[:ramp_left]):
            weights[pos] *= gain

    if ramp_right > 0:
        # Interior samples of linspace(1, 0, ramp_right + 2).
        tail_start = length - ramp_right
        for offset in range(ramp_right):
            weights[tail_start + offset] *= (ramp_right - offset) / (ramp_right + 1)

    return mx.clip(mx.array(weights), 0, 1)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SpatialTilingConfig:
    """Configuration for dividing each frame into spatial tiles with optional overlap.

    Raises:
        ValueError: On construction, if the tile size is below 64 or not a
            multiple of 32, or if the overlap is not a multiple of 32, is
            negative, or is not smaller than the tile size.
    """

    # Edge length of a square tile, in output pixels.
    tile_size_in_pixels: int
    # Number of pixels shared between neighbouring tiles.
    tile_overlap_in_pixels: int = 0

    def __post_init__(self) -> None:
        """Validate the field values."""
        if self.tile_size_in_pixels < 64:
            raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
        if self.tile_size_in_pixels % 32 != 0:
            raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
        if self.tile_overlap_in_pixels % 32 != 0:
            raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
        # Negative multiples of 32 pass the modulo check above but would make
        # the tile stride larger than the tile, leaving gaps between tiles.
        if self.tile_overlap_in_pixels < 0:
            raise ValueError(f"tile_overlap_in_pixels must be non-negative, got {self.tile_overlap_in_pixels}")
        if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
            )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TemporalTilingConfig:
    """Configuration for dividing a video into temporal tiles.

    Raises:
        ValueError: On construction, if the tile size is below 16 or not a
            multiple of 8, or if the overlap is not a multiple of 8, is
            negative, or is not smaller than the tile size.
    """

    # Length of one temporal tile, in output frames.
    tile_size_in_frames: int
    # Number of frames shared between consecutive tiles.
    tile_overlap_in_frames: int = 0

    def __post_init__(self) -> None:
        """Validate the field values."""
        if self.tile_size_in_frames < 16:
            raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
        if self.tile_size_in_frames % 8 != 0:
            raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
        if self.tile_overlap_in_frames % 8 != 0:
            raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
        # Negative multiples of 8 pass the modulo check above but would make
        # the tile stride larger than the tile, leaving gaps between tiles.
        if self.tile_overlap_in_frames < 0:
            raise ValueError(f"tile_overlap_in_frames must be non-negative, got {self.tile_overlap_in_frames}")
        if self.tile_overlap_in_frames >= self.tile_size_in_frames:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
            )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TilingConfig:
    """Combined spatial and temporal tiling settings; either part may be absent."""

    # None disables spatial tiling.
    spatial_config: Optional[SpatialTilingConfig] = None
    # None disables temporal tiling.
    temporal_config: Optional[TemporalTilingConfig] = None

    @classmethod
    def default(cls) -> "TilingConfig":
        """Default tiling: 512px spatial, 64 frame temporal."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
        """Spatial tiling only (for short videos with large resolution)."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap)
        return cls(spatial_config=spatial, temporal_config=None)

    @classmethod
    def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig":
        """Temporal tiling only (for long videos with small resolution)."""
        temporal = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
        return cls(spatial_config=None, temporal_config=temporal)

    @classmethod
    def aggressive(cls) -> "TilingConfig":
        """Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def conservative(cls) -> "TilingConfig":
        """Conservative tiling (larger tiles, less memory savings but faster)."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def auto(
        cls,
        height: int,
        width: int,
        num_frames: int,
        spatial_threshold: int = 512,
        temporal_threshold: int = 65,
    ) -> Optional["TilingConfig"]:
        """Pick a tiling config from the video dimensions.

        Each axis that exceeds its threshold gets the PyTorch default tiling
        (512px tiles with 64px overlap, 64-frame tiles with 24-frame overlap).
        Smaller tiles cause quality degradation because CausalConv3d needs
        sufficient temporal context and overlap for clean blending.

        Args:
            height: Video height in pixels
            width: Video width in pixels
            num_frames: Number of video frames
            spatial_threshold: Enable spatial tiling if either dimension exceeds this
            temporal_threshold: Enable temporal tiling if frames exceed this

        Returns:
            TilingConfig if tiling is needed, None otherwise
        """
        spatial = None
        if height > spatial_threshold or width > spatial_threshold:
            spatial = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)

        temporal = None
        if num_frames > temporal_threshold:
            temporal = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)

        # Nothing exceeded its threshold: no tiling required.
        if spatial is None and temporal is None:
            return None

        return cls(spatial_config=spatial, temporal_config=temporal)
|
||||
|
||||
|
||||
@dataclass
class DimensionIntervals:
    """Tile intervals along one dimension, stored as four parallel lists."""

    # Start index of each tile.
    starts: List[int]
    # End index of each tile.
    ends: List[int]
    # Blend-ramp length on the leading side of each tile.
    left_ramps: List[int]
    # Blend-ramp length on the trailing side of each tile.
    right_ramps: List[int]
|
||||
|
||||
|
||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
    """Split a spatial dimension into overlapping tiles.

    Args:
        size: Tile length along the dimension.
        overlap: Number of positions shared by two neighbouring tiles.
        dimension_size: Total length of the dimension being split.

    Returns:
        DimensionIntervals with one entry per tile. Tiles advance by
        ``size - overlap``; the last tile is clipped to ``dimension_size``.
        The first tile has no left ramp and the last tile has no right ramp;
        every other edge ramps over ``overlap`` positions.

    Raises:
        ValueError: If more than one tile is needed and ``overlap`` is
            negative or not strictly smaller than ``size`` (the tiles would
            never advance, or would leave gaps).
    """
    # A dimension that fits in one tile needs no blending at all.
    if dimension_size <= size:
        return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])

    stride = size - overlap
    if overlap < 0 or stride <= 0:
        raise ValueError(
            f"Tile overlap must satisfy 0 <= overlap < size, got overlap={overlap}, size={size}"
        )

    # ceil((dimension_size - overlap) / stride): number of tiles needed so the
    # final tile reaches the end of the dimension.
    amount = (dimension_size - overlap + stride - 1) // stride
    starts = [i * stride for i in range(amount)]
    ends = [start + size for start in starts]
    ends[-1] = dimension_size  # clip the last tile to the real extent
    left_ramps = [0] + [overlap] * (amount - 1)
    right_ramps = [overlap] * (amount - 1) + [0]

    return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
||||
|
||||
|
||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
    """Split the temporal dimension into tiles, adjusted for causality.

    The split is the same as ``split_in_spatial`` except that every tile
    after the first starts one position earlier and has its left ramp
    lengthened by one. Ends and right ramps are passed through unchanged.

    Args:
        size: Tile length along the temporal dimension.
        overlap: Number of positions shared by neighbouring tiles.
        dimension_size: Total temporal length.

    Returns:
        DimensionIntervals describing the temporal tiles.
    """
    if dimension_size <= size:
        return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])

    base = split_in_spatial(size, overlap, dimension_size)

    # Tile 0 is untouched; later tiles are shifted back by one position and
    # blend over one extra position on their leading edge.
    shifted_starts = [
        begin if idx == 0 else begin - 1 for idx, begin in enumerate(base.starts)
    ]
    widened_ramps = [
        ramp if idx == 0 else ramp + 1 for idx, ramp in enumerate(base.left_ramps)
    ]

    return DimensionIntervals(
        starts=shifted_starts,
        ends=base.ends,
        left_ramps=widened_ramps,
        right_ramps=base.right_ramps,
    )
|
||||
|
||||
|
||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Convert a temporal latent interval into an output slice and blend mask.

    Args:
        begin: First latent index of the tile.
        end: Exclusive latent end index of the tile.
        left_ramp: Left blend-ramp length in latent units (0 means no ramp).
        right_ramp: Right blend-ramp length in latent units.
        scale: Temporal upsampling factor.

    Returns:
        ``(slice(start, stop), mask)`` where ``start = begin * scale``,
        ``stop = 1 + (end - 1) * scale`` and ``mask`` is the 1-D trapezoidal
        mask of length ``stop - start``.
    """
    out_begin = begin * scale
    out_end = (end - 1) * scale + 1

    # A non-zero left ramp follows the same "1 + (n - 1) * scale" mapping as
    # the stop index; a zero ramp stays zero.
    if left_ramp > 0:
        ramp_in = (left_ramp - 1) * scale + 1
    else:
        ramp_in = 0
    ramp_out = right_ramp * scale

    # The final True is the temporal variant flag of the mask helper
    # (map_spatial_slice passes False); its exact meaning is defined by
    # compute_trapezoidal_mask_1d.
    blend = compute_trapezoidal_mask_1d(out_end - out_begin, ramp_in, ramp_out, True)
    return slice(out_begin, out_end), blend
|
||||
|
||||
|
||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Convert a spatial latent interval into an output slice and blend mask.

    Every quantity is simply multiplied by ``scale``.

    Args:
        begin: First latent index of the tile.
        end: Exclusive latent end index of the tile.
        left_ramp: Left blend-ramp length in latent units.
        right_ramp: Right blend-ramp length in latent units.
        scale: Spatial upsampling factor.

    Returns:
        ``(slice(begin * scale, end * scale), mask)`` where ``mask`` is the
        1-D trapezoidal mask covering the slice.
    """
    out_begin, out_end = begin * scale, end * scale
    # The final False selects the spatial variant of the mask helper
    # (map_temporal_slice passes True).
    blend = compute_trapezoidal_mask_1d(
        out_end - out_begin,
        left_ramp * scale,
        right_ramp * scale,
        False,
    )
    return slice(out_begin, out_end), blend
|
||||
|
||||
|
||||
def decode_with_tiling(
    decoder_fn,
    latents: mx.array,
    tiling_config: TilingConfig,
    spatial_scale: int = 32,
    temporal_scale: int = 8,
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    chunked_conv: bool = False,
    on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
    """Decode latents tile by tile and blend the tiles to reduce memory usage.

    Each tile is decoded independently, multiplied by a separable
    trapezoidal blend mask and accumulated into a float32 buffer; the
    buffer is finally divided by the accumulated mask weights.

    Args:
        decoder_fn: Callable invoked as ``decoder_fn(tile_latents,
            causal=..., timestep=..., debug=False, chunked_conv=...)`` and
            returning the decoded tile as a 5-D array.
        latents: Input latents of shape (B, C, F, H, W).
        tiling_config: Tiling configuration. A missing spatial or temporal
            sub-config disables tiling along that axis.
        spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
        temporal_scale: Temporal scale factor (8 for LTX VAE).
        causal: Forwarded to ``decoder_fn``.
        timestep: Optional timestep forwarded to ``decoder_fn``.
        chunked_conv: Forwarded to ``decoder_fn`` (chunked conv mode for upsampling).
        on_frames_ready: Optional callback called as ``(frames, start_idx)``
            as soon as frames can no longer change.
            frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
            start_idx: Index of the first of those frames in the full video.
            Every output frame is delivered exactly once, in order.

    Returns:
        Decoded video of shape (B, 3, 1 + (F - 1) * temporal_scale,
        H * spatial_scale, W * spatial_scale) in the dtype of ``latents``.
    """
    import gc

    b, _, f_latent, h_latent, w_latent = latents.shape

    # Output shape: the first latent frame maps to a single output frame.
    out_f = 1 + (f_latent - 1) * temporal_scale
    out_h = h_latent * spatial_scale
    out_w = w_latent * spatial_scale

    # Tile size and overlap converted to latent units.
    if tiling_config.spatial_config is not None:
        s_cfg = tiling_config.spatial_config
        spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
        spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
    else:
        spatial_tile_size = max(h_latent, w_latent)
        spatial_overlap = 0

    if tiling_config.temporal_config is not None:
        t_cfg = tiling_config.temporal_config
        temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
        temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
    else:
        temporal_tile_size = f_latent
        temporal_overlap = 0

    # Intervals for each dimension.
    temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
    height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
    width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)

    num_t_tiles = len(temporal_intervals.starts)
    num_h_tiles = len(height_intervals.starts)
    num_w_tiles = len(width_intervals.starts)

    # Accumulate in float32 to avoid precision issues when blending.
    output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
    weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
    mx.eval(output, weights)

    # Number of output frames already handed to on_frames_ready. Kept as a
    # local (not as state on the function object) so that an exception in a
    # previous call, or a nested call, cannot leak a stale counter.
    emitted_frames = 0

    tile_idx = 0
    for t_idx in range(num_t_tiles):
        t_start = temporal_intervals.starts[t_idx]
        t_end = temporal_intervals.ends[t_idx]
        out_t_slice, t_mask = map_temporal_slice(
            t_start,
            t_end,
            temporal_intervals.left_ramps[t_idx],
            temporal_intervals.right_ramps[t_idx],
            temporal_scale,
        )

        for h_idx in range(num_h_tiles):
            h_start = height_intervals.starts[h_idx]
            h_end = height_intervals.ends[h_idx]
            out_h_slice, h_mask = map_spatial_slice(
                h_start,
                h_end,
                height_intervals.left_ramps[h_idx],
                height_intervals.right_ramps[h_idx],
                spatial_scale,
            )

            for w_idx in range(num_w_tiles):
                w_start = width_intervals.starts[w_idx]
                w_end = width_intervals.ends[w_idx]
                out_w_slice, w_mask = map_spatial_slice(
                    w_start,
                    w_end,
                    width_intervals.left_ramps[w_idx],
                    width_intervals.right_ramps[w_idx],
                    spatial_scale,
                )

                # Extract and decode the tile (small slice of the latents).
                tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
                tile_output = decoder_fn(
                    tile_latents,
                    causal=causal,
                    timestep=timestep,
                    debug=False,
                    chunked_conv=chunked_conv,
                )
                mx.eval(tile_output)
                del tile_latents

                # The decoder may return a different size than the mapped
                # slice; use the smaller of the two on every axis.
                _, _, decoded_t, decoded_h, decoded_w = tile_output.shape
                actual_t = min(decoded_t, out_t_slice.stop - out_t_slice.start)
                actual_h = min(decoded_h, out_h_slice.stop - out_h_slice.start)
                actual_w = min(decoded_w, out_w_slice.stop - out_w_slice.start)

                # Separable 3-D blend mask, truncated to the actual tile size.
                t_mask_slice = t_mask[:actual_t]
                h_mask_slice = h_mask[:actual_h]
                w_mask_slice = w_mask[:actual_w]
                blend_mask = (
                    t_mask_slice.reshape(1, 1, -1, 1, 1)
                    * h_mask_slice.reshape(1, 1, 1, -1, 1)
                    * w_mask_slice.reshape(1, 1, 1, 1, -1)
                )

                tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
                del tile_output

                # Destination region of this tile in the output buffers.
                region = (
                    slice(None),
                    slice(None),
                    slice(out_t_slice.start, out_t_slice.start + actual_t),
                    slice(out_h_slice.start, out_h_slice.start + actual_h),
                    slice(out_w_slice.start, out_w_slice.start + actual_w),
                )

                # Weighted accumulation.
                weighted_tile = tile_output_slice * blend_mask
                output[region] = output[region] + weighted_tile
                weights[region] = weights[region] + blend_mask

                # Force evaluation so the tile intermediates can be freed.
                mx.eval(output, weights)

                del tile_output_slice, weighted_tile, blend_mask
                del t_mask_slice, h_mask_slice, w_mask_slice

                tile_idx += 1

                # Periodic garbage collection and cache clearing.
                if tile_idx % 4 == 0:
                    gc.collect()
                    try:
                        mx.clear_cache()
                    except Exception:
                        pass  # Best effort: may not be available on all platforms

        # All spatial tiles of this temporal tile are done. Frames that lie
        # before the next temporal tile can no longer change, so they can be
        # normalized and streamed out now.
        if on_frames_ready is not None and num_t_tiles > 1 and t_idx < num_t_tiles - 1:
            next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
            # This boundary is never later than the next tile's first output
            # frame (next_tile_start_latent * temporal_scale), so it is a
            # safe, slightly conservative cut-off.
            if next_tile_start_latent == 0:
                next_tile_start_out = 0
            else:
                next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale

            if next_tile_start_out > emitted_frames:
                # Normalize and emit frames [emitted_frames, next_tile_start_out).
                finalized_weights = weights[:, :, emitted_frames:next_tile_start_out, :, :]
                finalized_weights = mx.maximum(finalized_weights, 1e-8)
                finalized_output = output[:, :, emitted_frames:next_tile_start_out, :, :] / finalized_weights
                finalized_output = finalized_output.astype(latents.dtype)
                mx.eval(finalized_output)

                on_frames_ready(finalized_output, emitted_frames)
                emitted_frames = next_tile_start_out

                del finalized_output, finalized_weights
                gc.collect()

    # Normalize by the accumulated weights (clamped to avoid division by zero).
    weights = mx.maximum(weights, 1e-8)
    output = output / weights
    mx.eval(output)

    # Emit whatever has not been streamed yet.
    if on_frames_ready is not None and emitted_frames < out_f:
        remaining_output = output[:, :, emitted_frames:, :, :].astype(latents.dtype)
        mx.eval(remaining_output)
        on_frames_ready(remaining_output, emitted_frames)
        del remaining_output

    del weights
    gc.collect()

    # Convert back to the original dtype.
    return output.astype(latents.dtype)
|
||||
597
mlx_video/models/ltx_2/video_vae/video_vae.py
Normal file
597
mlx_video/models/ltx_2/video_vae/video_vae.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""Video VAE Encoder and Decoder for LTX-2."""
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
SpaceToDepthDownsample,
|
||||
)
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class LogVarianceType(Enum):
    """Log variance mode for VAE.

    VideoEncoder uses the mode to size its output convolution and to decide
    how the log-variance channels are laid out after encoding.
    """
    # One log-variance channel per latent channel: the encoder's conv_out emits 2 * out_channels.
    PER_CHANNEL = "per_channel"
    # A single predicted log-variance channel (out_channels + 1), repeated across all latent channels.
    UNIFORM = "uniform"
    # One extra channel is emitted (out_channels + 1) but discarded; log-variance is filled with -30.
    CONSTANT = "constant"
    # No extra channels: conv_out emits exactly out_channels.
    NONE = "none"
|
||||
|
||||
|
||||
def _make_encoder_block(
    block_name: str,
    block_config: Dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Build one encoder stage and report its output channel count.

    Supported ``block_name`` values:
        * ``res_x``: UNetMidBlock3D with ``block_config["num_layers"]`` layers.
        * ``res_x_y``: ResnetBlock3D that widens the channels.
        * ``compress_time`` / ``compress_space`` / ``compress_all`` /
          ``compress_all_x_y``: strided CausalConv3d (the ``_x_y`` variant
          also widens the channels).
        * ``compress_all_res`` / ``compress_space_res`` /
          ``compress_time_res``: SpaceToDepthDownsample that widens the channels.

    Widening blocks multiply the channels by ``block_config["multiplier"]``
    (default 2); all other blocks keep the channel count.

    Args:
        block_name: Type of block.
        block_config: Block configuration.
        in_channels: Input channels.
        convolution_dimensions: Number of convolution dimensions.
        norm_layer: Normalization layer type (residual blocks only).
        norm_num_groups: Number of groups for group norm (residual blocks only).
        spatial_padding_mode: Spatial padding mode.

    Returns:
        Tuple of (block, output_channels).

    Raises:
        ValueError: If ``block_name`` is not one of the supported names.
    """
    # Stride per plain strided-convolution block.
    conv_strides = {
        "compress_time": (2, 1, 1),
        "compress_space": (1, 2, 2),
        "compress_all": (2, 2, 2),
        "compress_all_x_y": (2, 2, 2),
    }
    # Stride per space-to-depth (residual) downsampling block.
    residual_strides = {
        "compress_all_res": (2, 2, 2),
        "compress_space_res": (1, 2, 2),
        "compress_time_res": (2, 1, 1),
    }
    widening_blocks = {"res_x_y", "compress_all_x_y", *residual_strides}

    if block_name == "res_x":
        mid_block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
        return mid_block, in_channels

    if block_name in widening_blocks:
        width = in_channels * block_config.get("multiplier", 2)
    else:
        width = in_channels

    if block_name == "res_x_y":
        stage = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=width,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name in conv_strides:
        stage = CausalConv3d(
            in_channels=in_channels,
            out_channels=width,
            kernel_size=3,
            stride=conv_strides[block_name],
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
    elif block_name in residual_strides:
        stage = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=width,
            stride=residual_strides[block_name],
            spatial_padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"Unknown encoder block: {block_name}")

    return stage, width
|
||||
|
||||
|
||||
def _make_decoder_block(
    block_name: str,
    block_config: Dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    timestep_conditioning: bool,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Build one decoder stage and report its output channel count.

    Supported ``block_name`` values:
        * ``res_x``: UNetMidBlock3D; channel count unchanged.
        * ``res_x_y``: ResnetBlock3D that divides the channels by
          ``block_config["multiplier"]`` (default 2). Timestep conditioning
          is always disabled for this block.
        * ``compress_time`` / ``compress_space``: DepthToSpaceUpsample with
          stride (2, 1, 1) / (1, 2, 2); channel count unchanged.
        * ``compress_all``: DepthToSpaceUpsample with stride (2, 2, 2) that
          divides the channels by ``block_config["multiplier"]`` (default 1)
          and honours ``block_config["residual"]``.

    Args:
        block_name: Type of block.
        block_config: Block configuration.
        in_channels: Input channels.
        convolution_dimensions: Number of convolution dimensions.
        norm_layer: Normalization layer type (residual blocks only).
        timestep_conditioning: Enables timestep conditioning in ``res_x``.
        norm_num_groups: Number of groups for group norm.
        spatial_padding_mode: Spatial padding mode.

    Returns:
        Tuple of (block, output_channels).

    Raises:
        ValueError: If ``block_name`` is not one of the supported names.
    """
    if block_name == "res_x":
        mid_block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=timestep_conditioning,
            spatial_padding_mode=spatial_padding_mode,
        )
        return mid_block, in_channels

    if block_name == "res_x_y":
        narrowed = in_channels // block_config.get("multiplier", 2)
        res_block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=narrowed,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=False,
            spatial_padding_mode=spatial_padding_mode,
        )
        return res_block, narrowed

    if block_name in ("compress_time", "compress_space"):
        stride = (2, 1, 1) if block_name == "compress_time" else (1, 2, 2)
        upsample = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=stride,
            spatial_padding_mode=spatial_padding_mode,
        )
        return upsample, in_channels

    if block_name == "compress_all":
        narrowed = in_channels // block_config.get("multiplier", 1)
        upsample = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(2, 2, 2),
            residual=block_config.get("residual", False),
            out_channels_reduction_factor=block_config.get("multiplier", 1),
            spatial_padding_mode=spatial_padding_mode,
        )
        return upsample, narrowed

    raise ValueError(f"Unknown decoder block: {block_name}")
|
||||
|
||||
|
||||
class VideoEncoder(nn.Module):
    """Causal video VAE encoder that maps videos to normalized latent means.

    Pipeline: patchify -> conv_in -> configurable down blocks -> norm ->
    SiLU -> conv_out -> per-channel normalization of the latent means.
    """

    _DEFAULT_NORM_NUM_GROUPS = 32

    def __init__(self, config: "VideoEncoderModelConfig"):
        """Initialize VideoEncoder from config.

        Args:
            config: VideoEncoderModelConfig with encoder parameters. The
                fields read here are ``patch_size``, ``norm_layer``,
                ``in_channels``, ``out_channels``, ``latent_log_var``,
                ``encoder_blocks``, ``encoder_spatial_padding_mode`` and
                ``convolution_dimensions``.

        Raises:
            ValueError: If ``config.norm_layer`` is not a supported
                normalization type or an encoder block name is unknown.
        """
        super().__init__()

        self.patch_size = config.patch_size
        self.norm_layer = config.norm_layer
        self.latent_channels = config.out_channels
        self.latent_log_var = config.latent_log_var
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
        encoder_spatial_padding_mode = config.encoder_spatial_padding_mode

        # Per-channel statistics for normalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)

        # After patchify, channels increase by patch_size^2
        in_channels = config.in_channels * config.patch_size ** 2
        feature_channels = config.out_channels

        # Initial convolution
        self.conv_in = CausalConv3d(
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

        # Encoder blocks are kept in a dict keyed by block index (the
        # original author chose this so MLX tracks the parameters).
        self.down_blocks = {}
        for idx, (block_name, block_params) in enumerate(encoder_blocks):
            # A bare int is shorthand for {"num_layers": <int>}.
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params

            block, feature_channels = _make_encoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=config.convolution_dimensions,
                norm_layer=config.norm_layer,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=encoder_spatial_padding_mode,
            )
            self.down_blocks[idx] = block

        # Output normalization
        if config.norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(
                num_groups=self._norm_num_groups,
                dims=feature_channels,
                eps=1e-6,
            )
        elif config.norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()
        else:
            # Fail here rather than with a missing-attribute error on the
            # first forward pass.
            raise ValueError(f"Unsupported norm layer: {config.norm_layer}")

        self.conv_act = nn.SiLU()

        # Output channels depend on how the log variance is represented.
        conv_out_channels = config.out_channels
        if config.latent_log_var == LogVarianceType.PER_CHANNEL:
            conv_out_channels *= 2
        elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
            conv_out_channels += 1

        self.conv_out = CausalConv3d(
            in_channels=feature_channels,
            out_channels=conv_out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

    def __call__(self, sample: mx.array) -> mx.array:
        """Encode video to latent representation.

        Args:
            sample: Input video of shape (B, C, F, H, W).
                F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)

        Returns:
            Normalized latent means of shape (B, latent_channels, F', H', W').

        Raises:
            ValueError: If the frame count is not of the form 1 + 8*k.
        """
        # Validate frame count
        frames_count = sample.shape[2]
        if ((frames_count - 1) % 8) != 0:
            raise ValueError(
                "Invalid number of frames: Encode input must have 1 + 8 * x frames "
                f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
            )

        # Initial patchify
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample, causal=True)

        # Every block type takes the same call signature, so no per-type
        # dispatch is needed.
        for i in range(len(self.down_blocks)):
            sample = self.down_blocks[i](sample, causal=True)

        # Output processing
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=True)

        # Only the means are returned. In every log-variance mode they are
        # the first ``latent_channels`` channels of conv_out's output, so the
        # log-variance channels (whatever their layout) are simply dropped.
        means = sample[:, :self.latent_channels, ...]
        return self.per_channel_statistics.normalize(means)

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize VAE encoder weights from PyTorch format to MLX format.

        Keeps only ``vae.encoder.*`` entries (prefix stripped) and the
        ``mean-of-means`` / ``std-of-means`` per-channel statistics, and
        moves the input-channel axis of 4-D/5-D convolution weights last.

        Args:
            weights: Checkpoint mapping of parameter name to array.

        Returns:
            The converted mapping. If ``weights`` already contains
            ``per_channel_statistics.mean`` it is treated as converted and
            returned unchanged.
        """
        if "per_channel_statistics.mean" in weights:
            return weights

        sanitized = {}
        for key, value in weights.items():
            if "position_ids" in key:
                continue

            # Only process VAE encoder weights
            if not key.startswith("vae."):
                continue

            # Handle per-channel statistics
            if "vae.per_channel_statistics" in key:
                if key == "vae.per_channel_statistics.mean-of-means":
                    new_key = "per_channel_statistics.mean"
                elif key == "vae.per_channel_statistics.std-of-means":
                    new_key = "per_channel_statistics.std"
                else:
                    continue
            elif key.startswith("vae.encoder."):
                new_key = key.replace("vae.encoder.", "")
            else:
                continue

            if "conv" in new_key.lower() and "weight" in new_key:
                if value.ndim == 5:
                    # Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
                    value = mx.transpose(value, (0, 2, 3, 4, 1))
                elif value.ndim == 4:
                    # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
                    value = mx.transpose(value, (0, 2, 3, 1))

            sanitized[new_key] = value
        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
        """Load a pretrained VideoEncoder from a directory with weights and config.

        Args:
            model_path: Path (or path string) to a directory containing
                safetensors weights and an optional config.json, or to a
                single weights file. Without a config.json the default
                VideoEncoderModelConfig is used.

        Returns:
            Loaded VideoEncoder instance

        Raises:
            FileNotFoundError: If no weights can be found at ``model_path``.
        """
        import json
        from mlx_video.models.ltx_2.config import VideoEncoderModelConfig

        # Accept plain strings as well as Path objects.
        model_path = Path(model_path)

        # Load config
        config_path = model_path / "config.json"
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                config_dict = json.load(f)
            config = VideoEncoderModelConfig(**config_dict)
        else:
            config = VideoEncoderModelConfig()

        # Load weights: every *.safetensors in the directory, or the path
        # itself when it points at a single file.
        weight_files = sorted(model_path.glob("*.safetensors"))
        if weight_files:
            weights = {}
            for wf in weight_files:
                weights.update(mx.load(str(wf)))
        elif model_path.is_file():
            weights = mx.load(str(model_path))
        else:
            raise FileNotFoundError(f"No safetensors files found in {model_path}")

        # Create model, sanitize and load weights
        model = cls(config)
        weights = model.sanitize(weights)
        model.load_weights(list(weights.items()), strict=False)
        return model
|
||||
|
||||
|
||||
class VideoDecoder(nn.Module):
    """Video VAE decoder that maps normalized latents back to video frames.

    Pipeline: optional noise injection -> denormalize -> conv_in ->
    configurable up blocks -> norm -> SiLU -> conv_out -> unpatchify.
    """

    _DEFAULT_NORM_NUM_GROUPS = 32

    def __init__(
        self,
        convolution_dimensions: int = 3,
        in_channels: int = 128,
        out_channels: int = 3,
        decoder_blocks: Optional[List[Tuple[str, Any]]] = None,
        patch_size: int = 4,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        causal: bool = False,
        timestep_conditioning: bool = False,
        decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
    ):
        """Initialize VideoDecoder.

        Args:
            convolution_dimensions: Number of convolution dimensions
            in_channels: Input latent channels
            out_channels: Output channels (3 for RGB)
            decoder_blocks: List of (block_name, config) tuples; ``None``
                means no blocks. Blocks are instantiated in reversed order.
            patch_size: Spatial patch size
            norm_layer: Normalization layer type
            causal: Whether convolutions are called in causal mode
            timestep_conditioning: Whether to use timestep conditioning
            decoder_spatial_padding_mode: Spatial padding mode

        Raises:
            ValueError: If ``norm_layer`` is not a supported normalization
                type or a decoder block name is unknown.
        """
        super().__init__()

        if decoder_blocks is None:
            decoder_blocks = []

        self.patch_size = patch_size
        # conv_out emits patchified pixels; unpatchify restores the resolution.
        out_channels = out_channels * patch_size ** 2
        self.causal = causal
        self.timestep_conditioning = timestep_conditioning
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        # Per-channel statistics for denormalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)

        # Noise and timestep parameters
        self.decode_noise_scale = 0.025
        self.decode_timestep = 0.05

        # Width of the first feature map: the channel-reducing blocks divide
        # the width on the way up, so multiply their factors back in here.
        feature_channels = in_channels
        for block_name, block_params in reversed(decoder_blocks):
            block_config = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                feature_channels = feature_channels * block_config.get("multiplier", 2)
            if block_name == "compress_all":
                feature_channels = feature_channels * block_config.get("multiplier", 1)

        # Initial convolution
        self.conv_in = CausalConv3d(
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=decoder_spatial_padding_mode,
        )

        # Decoder blocks (reversed order), kept in a dict keyed by block
        # index (the original author chose this so MLX tracks the parameters).
        self.up_blocks = {}
        for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
            # A bare int is shorthand for {"num_layers": <int>}.
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params

            block, feature_channels = _make_decoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=convolution_dimensions,
                norm_layer=norm_layer,
                timestep_conditioning=timestep_conditioning,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=decoder_spatial_padding_mode,
            )
            self.up_blocks[idx] = block

        # Output normalization
        if norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(
                num_groups=self._norm_num_groups,
                dims=feature_channels,
                eps=1e-6,
            )
        elif norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()
        else:
            # Fail here rather than with a missing-attribute error on the
            # first forward pass.
            raise ValueError(f"Unsupported norm layer: {norm_layer}")

        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(
            in_channels=feature_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=decoder_spatial_padding_mode,
        )

    def __call__(
        self,
        sample: mx.array,
        timestep: Optional[mx.array] = None,
    ) -> mx.array:
        """Decode latent to video.

        Args:
            sample: Latent tensor of shape (B, C, F', H', W') with
                C equal to the decoder's ``in_channels``.
            timestep: Optional timestep for conditioning

        Returns:
            Decoded video of shape (B, out_channels, F, H, W)
        """
        batch_size = sample.shape[0]

        # Blend in a small amount of noise if timestep conditioning is enabled
        if self.timestep_conditioning:
            noise = mx.random.normal(sample.shape) * self.decode_noise_scale
            sample = noise + (1.0 - self.decode_noise_scale) * sample

        # Denormalize latents
        sample = self.per_channel_statistics.un_normalize(sample)

        # Use default timestep if not provided
        # NOTE(review): ``timestep`` is prepared here but is not forwarded to
        # any block below, so it currently has no effect on the output.
        # Confirm whether the blocks should receive it.
        if timestep is None and self.timestep_conditioning:
            timestep = mx.full((batch_size,), self.decode_timestep)

        # Initial convolution
        sample = self.conv_in(sample, causal=self.causal)

        # Every block type takes the same call signature, so no per-type
        # dispatch is needed.
        for i in range(len(self.up_blocks)):
            sample = self.up_blocks[i](sample, causal=self.causal)

        # Output processing
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        # Unpatchify to restore spatial resolution
        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample
|
||||
Reference in New Issue
Block a user