initial commit (LTX-2)
This commit is contained in:
1
mlx_video/models/ltx/video_vae/__init__.py
Normal file
1
mlx_video/models/ltx/video_vae/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PaddingModeType(Enum):
|
||||
ZEROS = "zeros"
|
||||
REFLECT = "reflect"
|
||||
|
||||
|
||||
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||
"""Apply reflect padding to spatial dimensions of a 5D tensor.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) - channels last
|
||||
pad_h: Padding for height dimension
|
||||
pad_w: Padding for width dimension
|
||||
|
||||
Returns:
|
||||
Padded tensor
|
||||
"""
|
||||
if pad_h == 0 and pad_w == 0:
|
||||
return x
|
||||
|
||||
# Height padding (axis 2)
|
||||
if pad_h > 0:
|
||||
# Get reflection indices - exclude boundary
|
||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||
|
||||
# Width padding (axis 3)
|
||||
if pad_w > 0:
|
||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def make_conv_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, ...]],
|
||||
stride: Union[int, Tuple[int, ...]] = 1,
|
||||
padding: Union[int, Tuple[int, ...], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
) -> nn.Module:
|
||||
|
||||
if dims == 2:
|
||||
return CausalConv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == 3:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported number of dimensions: {dims}")
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride to tuples
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
# Calculate spatial padding (temporal is handled separately via frame replication)
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
self.spatial_padding = (height_pad, width_pad)
|
||||
|
||||
# Create the base convolution (without padding, we'll handle it manually)
|
||||
self.conv = nn.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0, # We handle padding manually
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
|
||||
use_causal = causal if causal is not None else self.causal
|
||||
|
||||
# Apply temporal padding via frame replication
|
||||
# Only apply if kernel_size > 1
|
||||
if self.time_kernel_size > 1:
|
||||
if use_causal:
|
||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||
else:
|
||||
# Non-causal: replicate first frame at start, last frame at end
|
||||
pad_size = (self.time_kernel_size - 1) // 2
|
||||
if pad_size > 0:
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
|
||||
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
|
||||
|
||||
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 4, 1))
|
||||
|
||||
# Apply spatial padding
|
||||
pad_h, pad_w = self.spatial_padding
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
if self.spatial_padding_mode == PaddingModeType.REFLECT:
|
||||
# Use reflect padding for spatial dimensions
|
||||
x = reflect_pad_2d(x, pad_h, pad_w)
|
||||
else:
|
||||
# Use zero padding for spatial dimensions
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(0, 0), # D (temporal - already padded)
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
# Apply convolution with chunking for large tensors
|
||||
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
|
||||
x = self._chunked_conv3d(x)
|
||||
|
||||
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
|
||||
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv3d(self, x: mx.array) -> mx.array:
|
||||
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) in channels-last format
|
||||
|
||||
Returns:
|
||||
Output tensor after conv3d
|
||||
"""
|
||||
b, d, h, w, c = x.shape
|
||||
|
||||
|
||||
total_elements = d * h * w * c
|
||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||
|
||||
if total_elements <= max_safe_elements:
|
||||
return self.conv(x)
|
||||
|
||||
elements_per_frame = h * w * c
|
||||
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
|
||||
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
|
||||
|
||||
kernel_t = self.time_kernel_size
|
||||
|
||||
overlap = kernel_t - 1
|
||||
|
||||
|
||||
expected_output_frames = d - overlap
|
||||
|
||||
outputs = []
|
||||
out_idx = 0
|
||||
|
||||
# Process chunks
|
||||
in_start = 0
|
||||
while out_idx < expected_output_frames:
|
||||
remaining = expected_output_frames - out_idx
|
||||
out_frames_this_chunk = min(chunk_size, remaining)
|
||||
|
||||
in_frames_needed = out_frames_this_chunk + overlap
|
||||
in_end = min(in_start + in_frames_needed, d)
|
||||
|
||||
chunk = x[:, in_start:in_end, :, :, :]
|
||||
|
||||
chunk_out = self.conv(chunk)
|
||||
mx.eval(chunk_out)
|
||||
|
||||
outputs.append(chunk_out)
|
||||
|
||||
out_idx += chunk_out.shape[1]
|
||||
in_start += chunk_out.shape[1]
|
||||
|
||||
# Concatenate all chunks
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return mx.concatenate(outputs, axis=1)
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""2D convolution with optional causal padding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
"""Initialize CausalConv2d."""
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
|
||||
# Calculate padding
|
||||
if isinstance(padding, str) and padding == "same":
|
||||
self.padding = (
|
||||
(kernel_size[0] - 1) // 2,
|
||||
(kernel_size[1] - 1) // 2,
|
||||
)
|
||||
elif isinstance(padding, int):
|
||||
self.padding = (padding, padding)
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
"""Forward pass."""
|
||||
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 1))
|
||||
|
||||
# Apply padding
|
||||
pad_h, pad_w = self.padding
|
||||
if pad_h != 0 or pad_w != 0:
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
x = self.conv(x)
|
||||
|
||||
# Transpose back: (B, H, W, C) -> (B, C, H, W)
|
||||
x = mx.transpose(x, (0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
524
mlx_video/models/ltx/video_vae/decoder.py
Normal file
524
mlx_video/models/ltx/video_vae/decoder.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""Video VAE Decoder for LTX-2 with timestep conditioning.
|
||||
|
||||
Architecture (from PyTorch weights):
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- pixel_norm + timestep modulation (last_scale_shift_table)
|
||||
- conv_out: 128 -> 48
|
||||
- unpatchify: 48 -> 3 with patch_size=4
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: mx.array,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = True,
|
||||
downscale_freq_shift: float = 0,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
) -> mx.array:
|
||||
"""Create sinusoidal timestep embeddings."""
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = mx.exp(exponent)
|
||||
emb = timesteps[:, None].astype(mx.float32) * emb[None, :]
|
||||
emb = scale * emb
|
||||
|
||||
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
|
||||
|
||||
if flip_sin_to_cos:
|
||||
emb = mx.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
|
||||
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
"""MLP for timestep embedding."""
|
||||
|
||||
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||
"""Combined timestep embedding (sinusoidal + MLP)."""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim
|
||||
)
|
||||
|
||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||
timesteps_proj = get_timestep_embedding(
|
||||
timestep,
|
||||
embedding_dim=256,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0
|
||||
)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class ResnetBlock3DSimple(nn.Module):
|
||||
"""ResNet block with optional timestep conditioning.
|
||||
|
||||
Weight keys: conv1.conv, conv2.conv, scale_shift_table
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Nested conv structure to match PyTorch naming: conv1.conv.weight
|
||||
self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||
self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
# Scale-shift table for timestep conditioning: [shift1, scale1, shift2, scale2]
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = mx.zeros((4, channels))
|
||||
|
||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||
class ConvWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
return ConvWrapper()
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = False,
|
||||
timestep_embed: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# Block 1 with optional timestep conditioning
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||
# Combine table with timestep embedding
|
||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||
channels = self.scale_shift_table.shape[1]
|
||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
shift1 = ada_values[:, 0] # (B, C, 1, 1, 1)
|
||||
scale1 = ada_values[:, 1]
|
||||
shift2 = ada_values[:, 2]
|
||||
scale2 = ada_values[:, 3]
|
||||
|
||||
x = x * (1 + scale1) + shift1
|
||||
|
||||
x = self.act(x)
|
||||
x = self.conv1(x, causal=causal)
|
||||
|
||||
# Block 2 with optional timestep conditioning
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
x = x * (1 + scale2) + shift2
|
||||
|
||||
x = self.act(x)
|
||||
x = self.conv2(x, causal=causal)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
class ResBlockGroup(nn.Module):
|
||||
"""Group of ResNet blocks with shared timestep embedding.
|
||||
|
||||
PyTorch naming: res_blocks.0, res_blocks.1, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_layers: int = 5,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Time embedder for this block group: embed_dim = 4 * channels
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
|
||||
self.res_blocks = [
|
||||
ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
timestep_embed = None
|
||||
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
batch_size = x.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
)
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
for res_block in self.res_blocks:
|
||||
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2VideoDecoder(nn.Module):
|
||||
"""LTX-2 Video VAE Decoder with timestep conditioning.
|
||||
|
||||
Architecture:
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Upsampler 1024 -> 512
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Upsampler 512 -> 256
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Upsampler 256 -> 128
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
patch_size: int = 4,
|
||||
num_layers_per_block: int = 5,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Decode parameters (configurable via constructor)
|
||||
self.decode_noise_scale = 0.025 # Set to 0.0 to disable noise
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Per-channel statistics for denormalization (loaded from weights)
|
||||
self.latents_mean = mx.zeros((in_channels,))
|
||||
self.latents_std = mx.ones((in_channels,))
|
||||
|
||||
# Initial conv: 128 -> 1024
|
||||
class ConvInWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=1024,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
||||
|
||||
self.up_blocks = [
|
||||
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=1024,
|
||||
stride=(2, 2, 2),
|
||||
residual=True, # CRITICAL: Must match PyTorch config!
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=512,
|
||||
stride=(2, 2, 2),
|
||||
residual=True, # CRITICAL: Must match PyTorch config!
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=256,
|
||||
stride=(2, 2, 2),
|
||||
residual=True, # CRITICAL: Must match PyTorch config!
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
]
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=128,
|
||||
out_channels=final_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
self.conv_out = ConvOutWrapper()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = mx.array(1000.0)
|
||||
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=128 * 2 # 256, matches (2, 128) table
|
||||
)
|
||||
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.reshape(1, -1, 1, 1, 1)
|
||||
return x * std + mean
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample: mx.array,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
if debug:
|
||||
debug_stats("Input", sample)
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
if debug:
|
||||
debug_stats("After noise", sample)
|
||||
|
||||
if debug:
|
||||
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
|
||||
sample = self.denormalize(sample)
|
||||
if debug:
|
||||
debug_stats("After denormalize", sample)
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_in", x)
|
||||
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
if isinstance(block, ResBlockGroup):
|
||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
if debug:
|
||||
block_type = type(block).__name__
|
||||
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
if debug:
|
||||
debug_stats("After pixel_norm", x)
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
scaled_timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
)
|
||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
shift = ada_values[:, 0] # (B, 128, 1, 1, 1)
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
if debug:
|
||||
debug_stats("After timestep modulation", x)
|
||||
|
||||
x = self.act(x)
|
||||
if debug:
|
||||
debug_stats("After activation", x)
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_out", x)
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
if debug:
|
||||
debug_stats("After unpatchify", x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
|
||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE decoder from {weights_path}...")
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.decoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
elif has_decoder_prefix:
|
||||
prefix = "decoder."
|
||||
stats_prefix = ""
|
||||
else:
|
||||
prefix = ""
|
||||
stats_prefix = ""
|
||||
|
||||
# Load per-channel statistics for denormalization
|
||||
# Note: use std-of-means (not mean-of-stds) for proper denormalization
|
||||
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
|
||||
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
|
||||
|
||||
if mean_key in weights:
|
||||
decoder.latents_mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
|
||||
if std_key in weights:
|
||||
decoder.latents_std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
|
||||
|
||||
# Build decoder weights dict with key remapping
|
||||
decoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
decoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(decoder_weights)} decoder weights")
|
||||
|
||||
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
|
||||
print(f" Found {len(ts_keys)} timestep conditioning weights")
|
||||
|
||||
# Load weights
|
||||
decoder.load_weights(list(decoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE decoder loaded successfully")
|
||||
return decoder
|
||||
120
mlx_video/models/ltx/video_vae/ops.py
Normal file
120
mlx_video/models/ltx/video_vae/ops.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Operations for Video VAE."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||
"""Convert video to patches.
|
||||
|
||||
Moves spatial pixels from H, W dimensions to channel dimension.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, F, H, W)
|
||||
patch_size_hw: Spatial patch size
|
||||
patch_size_t: Temporal patch size
|
||||
|
||||
Returns:
|
||||
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
|
||||
"""
|
||||
b, c, f, h, w = x.shape
|
||||
|
||||
# Check dimensions are divisible
|
||||
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
|
||||
assert f % patch_size_t == 0
|
||||
|
||||
# New dimensions
|
||||
new_h = h // patch_size_hw
|
||||
new_w = w // patch_size_hw
|
||||
new_f = f // patch_size_t
|
||||
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||
|
||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
|
||||
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
|
||||
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||
"""Convert patches back to video.
|
||||
|
||||
Inverse of patchify - moves pixels from channel dimension back to spatial.
|
||||
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
|
||||
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
|
||||
|
||||
Args:
|
||||
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
|
||||
patch_size_hw: Spatial patch size
|
||||
patch_size_t: Temporal patch size
|
||||
|
||||
Returns:
|
||||
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
|
||||
"""
|
||||
b, c_packed, f, h, w = x.shape
|
||||
|
||||
# Calculate original channel count
|
||||
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
|
||||
|
||||
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
|
||||
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
|
||||
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
|
||||
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
|
||||
|
||||
# Permute to interleave patches with spatial dims:
|
||||
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
|
||||
|
||||
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
|
||||
|
||||
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
|
||||
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PerChannelStatistics(nn.Module):
|
||||
|
||||
def __init__(self, latent_channels: int = 128):
|
||||
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
|
||||
# Learnable per-channel mean and std
|
||||
self.mean = mx.zeros((latent_channels,))
|
||||
self.std = mx.ones((latent_channels,))
|
||||
|
||||
def normalize(self, x: mx.array) -> mx.array:
|
||||
"""Normalize latents using per-channel statistics.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, ...)
|
||||
|
||||
Returns:
|
||||
Normalized tensor
|
||||
"""
|
||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return (x - mean) / std
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics.
|
||||
|
||||
Args:
|
||||
x: Normalized tensor of shape (B, C, ...)
|
||||
|
||||
Returns:
|
||||
Denormalized tensor
|
||||
"""
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return x * std + mean
|
||||
171
mlx_video/models/ltx/video_vae/resnet.py
Normal file
171
mlx_video/models/ltx/video_vae/resnet.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""ResNet blocks for Video VAE."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class NormLayerType(Enum):
|
||||
GROUP_NORM = "group_norm"
|
||||
PIXEL_NORM = "pixel_norm"
|
||||
|
||||
|
||||
def get_norm_layer(
|
||||
norm_type: NormLayerType,
|
||||
num_channels: int,
|
||||
num_groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
) -> nn.Module:
|
||||
|
||||
if norm_type == NormLayerType.GROUP_NORM:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
|
||||
elif norm_type == NormLayerType.PIXEL_NORM:
|
||||
return PixelNorm(eps=eps)
|
||||
else:
|
||||
raise ValueError(f"Unknown norm type: {norm_type}")
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
eps: float = 1e-6,
|
||||
groups: int = 32,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.inject_noise = inject_noise
|
||||
|
||||
# First normalization and convolution
|
||||
self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
|
||||
self.conv1 = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Second normalization and convolution
|
||||
self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
|
||||
self.conv2 = CausalConv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Shortcut connection if channels change
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
# Activation
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = True,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
residual = x
|
||||
|
||||
# First block
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = self.conv1(x, causal=causal)
|
||||
|
||||
# Inject noise if enabled
|
||||
if self.inject_noise and generator is not None:
|
||||
noise = mx.random.normal(x.shape)
|
||||
x = x + noise * 0.01
|
||||
|
||||
# Second block
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = self.conv2(x, causal=causal)
|
||||
|
||||
# Shortcut
|
||||
if self.shortcut is not None:
|
||||
residual = self.shortcut(residual, causal=causal)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Create ResNet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = True,
|
||||
timestep: Optional[mx.array] = None,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x, causal=causal, generator=generator)
|
||||
|
||||
return x
|
||||
173
mlx_video/models/ltx/video_vae/sampling.py
Normal file
173
mlx_video/models/ltx/video_vae/sampling.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Sampling operations for Video VAE (upsampling/downsampling)."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.stride = stride
|
||||
self.dims = dims
|
||||
|
||||
# Calculate the multiplier for channels
|
||||
multiplier = stride[0] * stride[1] * stride[2]
|
||||
intermediate_channels = in_channels * multiplier
|
||||
|
||||
# 1x1x1 convolution to adjust channels
|
||||
self.conv = CausalConv3d(
|
||||
in_channels=intermediate_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Pad if necessary to make dimensions divisible by stride
|
||||
pad_d = (st - d % st) % st
|
||||
pad_h = (sh - h % sh) % sh
|
||||
pad_w = (sw - w % sw) % sw
|
||||
|
||||
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
||||
# For causal, pad at the end of temporal dimension
|
||||
if causal:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
|
||||
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
# Reshape to group spatial elements
|
||||
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
|
||||
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
||||
|
||||
# Permute to move stride elements to channel dim
|
||||
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
|
||||
# Reshape to combine channels
|
||||
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
|
||||
new_c = c * st * sh * sw
|
||||
new_d = d // st
|
||||
new_h = h // sh
|
||||
new_w = w // sw
|
||||
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
||||
|
||||
# Apply 1x1 conv to adjust channels
|
||||
x = self.conv(x, causal=causal)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
residual: bool = False,
|
||||
out_channels_reduction_factor: int = 1,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.stride = stride
|
||||
self.dims = dims
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
|
||||
# Calculate output channels
|
||||
multiplier = stride[0] * stride[1] * stride[2]
|
||||
out_channels = in_channels // out_channels_reduction_factor
|
||||
self.out_channels = out_channels
|
||||
|
||||
# 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
|
||||
self.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * multiplier,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def _depth_to_space(self, x: mx.array) -> mx.array:
|
||||
b, c_packed, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
c = c_packed // (st * sh * sw)
|
||||
|
||||
# (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
|
||||
x = mx.reshape(x, (b, c, st, sh, sw, d, h, w))
|
||||
|
||||
# (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
|
||||
x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
|
||||
|
||||
# (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
|
||||
x = mx.reshape(x, (b, c, d * st, h * sh, w * sw))
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Compute residual path if enabled
|
||||
x_residual = None
|
||||
if self.residual:
|
||||
# Reshape input: treat channels as spatial factors
|
||||
# "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
|
||||
x_residual = self._depth_to_space(x)
|
||||
|
||||
# Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
|
||||
# num_repeat = prod(stride) / out_channels_reduction_factor
|
||||
num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
|
||||
x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
|
||||
|
||||
# Remove first temporal frame if temporal upsampling
|
||||
if st > 1:
|
||||
x_residual = x_residual[:, :, 1:, :, :]
|
||||
|
||||
# Apply conv
|
||||
x = self.conv(x, causal=causal)
|
||||
|
||||
# Depth to space rearrangement
|
||||
x = self._depth_to_space(x)
|
||||
|
||||
# Remove first frame for causal temporal upsampling
|
||||
if st > 1:
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
# Add residual
|
||||
if self.residual and x_residual is not None:
|
||||
x = x + x_residual
|
||||
|
||||
return x
|
||||
528
mlx_video/models/ltx/video_vae/video_vae.py
Normal file
528
mlx_video/models/ltx/video_vae/video_vae.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""Video VAE Encoder and Decoder for LTX-2."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
SpaceToDepthDownsample,
|
||||
)
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class LogVarianceType(Enum):
|
||||
"""Log variance mode for VAE."""
|
||||
PER_CHANNEL = "per_channel"
|
||||
UNIFORM = "uniform"
|
||||
CONSTANT = "constant"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
def _make_encoder_block(
|
||||
block_name: str,
|
||||
block_config: Dict[str, Any],
|
||||
in_channels: int,
|
||||
convolution_dimensions: int,
|
||||
norm_layer: NormLayerType,
|
||||
norm_num_groups: int,
|
||||
spatial_padding_mode: PaddingModeType,
|
||||
) -> Tuple[nn.Module, int]:
|
||||
"""Create an encoder block.
|
||||
|
||||
Args:
|
||||
block_name: Type of block
|
||||
block_config: Block configuration
|
||||
in_channels: Input channels
|
||||
convolution_dimensions: Number of dimensions
|
||||
norm_layer: Normalization layer type
|
||||
norm_num_groups: Number of groups for group norm
|
||||
spatial_padding_mode: Padding mode
|
||||
|
||||
Returns:
|
||||
Tuple of (block, output_channels)
|
||||
"""
|
||||
out_channels = in_channels
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
num_layers=block_config["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder block: {block_name}")
|
||||
|
||||
return block, out_channels
|
||||
|
||||
|
||||
def _make_decoder_block(
|
||||
block_name: str,
|
||||
block_config: Dict[str, Any],
|
||||
in_channels: int,
|
||||
convolution_dimensions: int,
|
||||
norm_layer: NormLayerType,
|
||||
timestep_conditioning: bool,
|
||||
norm_num_groups: int,
|
||||
spatial_padding_mode: PaddingModeType,
|
||||
) -> Tuple[nn.Module, int]:
|
||||
"""Create a decoder block."""
|
||||
out_channels = in_channels
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
num_layers=block_config["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_config.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
out_channels = in_channels // block_config.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_config.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_config.get("residual", False),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown decoder block: {block_name}")
|
||||
|
||||
return block, out_channels
|
||||
|
||||
|
||||
class VideoEncoder(nn.Module):
|
||||
|
||||
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
convolution_dimensions: int = 3,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 128,
|
||||
encoder_blocks: List[Tuple[str, Any]] = None,
|
||||
patch_size: int = 4,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
|
||||
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
"""Initialize VideoEncoder.
|
||||
|
||||
Args:
|
||||
convolution_dimensions: Number of dimensions (3 for video)
|
||||
in_channels: Input channels (3 for RGB)
|
||||
out_channels: Output latent channels
|
||||
encoder_blocks: List of (block_name, config) tuples
|
||||
patch_size: Spatial patch size
|
||||
norm_layer: Normalization layer type
|
||||
latent_log_var: Log variance mode
|
||||
encoder_spatial_padding_mode: Padding mode
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if encoder_blocks is None:
|
||||
encoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.norm_layer = norm_layer
|
||||
self.latent_channels = out_channels
|
||||
self.latent_log_var = latent_log_var
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = in_channels * patch_size ** 2
|
||||
feature_channels = out_channels
|
||||
|
||||
# Initial convolution
|
||||
self.conv_in = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feature_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build encoder blocks
|
||||
self.down_blocks = []
|
||||
for block_name, block_params in encoder_blocks:
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
block_name=block_name,
|
||||
block_config=block_config,
|
||||
in_channels=feature_channels,
|
||||
convolution_dimensions=convolution_dimensions,
|
||||
norm_layer=norm_layer,
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks.append(block)
|
||||
|
||||
# Output normalization and convolution
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_groups=self._norm_num_groups,
|
||||
dims=feature_channels,
|
||||
eps=1e-6,
|
||||
)
|
||||
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||
self.conv_norm_out = PixelNorm()
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
# Calculate output convolution channels
|
||||
conv_out_channels = out_channels
|
||||
if latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
in_channels=feature_channels,
|
||||
out_channels=conv_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""Encode video to latent representation.
|
||||
|
||||
Args:
|
||||
sample: Input video of shape (B, C, F, H, W).
|
||||
F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)
|
||||
|
||||
Returns:
|
||||
Normalized latent means of shape (B, 128, F', H', W')
|
||||
"""
|
||||
# Validate frame count
|
||||
frames_count = sample.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError(
|
||||
"Invalid number of frames: Encode input must have 1 + 8 * x frames "
|
||||
f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
|
||||
)
|
||||
|
||||
# Initial patchify
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample, causal=True)
|
||||
|
||||
# Process through encoder blocks
|
||||
for down_block in self.down_blocks:
|
||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||
sample = down_block(sample, causal=True)
|
||||
else:
|
||||
sample = down_block(sample, causal=True)
|
||||
|
||||
# Output processing
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=True)
|
||||
|
||||
# Handle log variance modes
|
||||
if self.latent_log_var == LogVarianceType.UNIFORM:
|
||||
means = sample[:, :-1, ...]
|
||||
logvar = sample[:, -1:, ...]
|
||||
num_channels = means.shape[1]
|
||||
repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
|
||||
sample = mx.concatenate([means, repeated_logvar], axis=1)
|
||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = -30
|
||||
sample = mx.concatenate([
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
], axis=1)
|
||||
|
||||
# Split into means and logvar, normalize means
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
|
||||
class VideoDecoder(nn.Module):
|
||||
|
||||
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
convolution_dimensions: int = 3,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
decoder_blocks: List[Tuple[str, Any]] = None,
|
||||
patch_size: int = 4,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
causal: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
):
|
||||
"""Initialize VideoDecoder.
|
||||
|
||||
Args:
|
||||
convolution_dimensions: Number of dimensions
|
||||
in_channels: Input latent channels
|
||||
out_channels: Output channels (3 for RGB)
|
||||
decoder_blocks: List of (block_name, config) tuples
|
||||
patch_size: Spatial patch size
|
||||
norm_layer: Normalization layer type
|
||||
causal: Whether to use causal convolutions
|
||||
timestep_conditioning: Whether to use timestep conditioning
|
||||
decoder_spatial_padding_mode: Padding mode
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if decoder_blocks is None:
|
||||
decoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
out_channels = out_channels * patch_size ** 2
|
||||
self.causal = causal
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
# Noise and timestep parameters
|
||||
self.decode_noise_scale = 0.025
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Compute initial feature channels
|
||||
feature_channels = in_channels
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
|
||||
# Initial convolution
|
||||
self.conv_in = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feature_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build decoder blocks (reversed order)
|
||||
self.up_blocks = []
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
block_name=block_name,
|
||||
block_config=block_config,
|
||||
in_channels=feature_channels,
|
||||
convolution_dimensions=convolution_dimensions,
|
||||
norm_layer=norm_layer,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
self.up_blocks.append(block)
|
||||
|
||||
# Output normalization
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_groups=self._norm_num_groups,
|
||||
dims=feature_channels,
|
||||
eps=1e-6,
|
||||
)
|
||||
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||
self.conv_norm_out = PixelNorm()
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(
|
||||
in_channels=feature_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample: mx.array,
|
||||
timestep: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
Args:
|
||||
sample: Latent tensor of shape (B, 128, F', H', W')
|
||||
timestep: Optional timestep for conditioning
|
||||
|
||||
Returns:
|
||||
Decoded video of shape (B, 3, F, H, W)
|
||||
"""
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
# Denormalize latents
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
# Use default timestep if not provided
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
|
||||
# Initial convolution
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
# Process through decoder blocks
|
||||
for up_block in self.up_blocks:
|
||||
if isinstance(up_block, UNetMidBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
elif isinstance(up_block, ResnetBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
else:
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
|
||||
# Output processing
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
# Unpatchify to restore spatial resolution
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
Reference in New Issue
Block a user