initial commit (LTX-2)
This commit is contained in:
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PaddingModeType(Enum):
|
||||
ZEROS = "zeros"
|
||||
REFLECT = "reflect"
|
||||
|
||||
|
||||
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||
"""Apply reflect padding to spatial dimensions of a 5D tensor.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) - channels last
|
||||
pad_h: Padding for height dimension
|
||||
pad_w: Padding for width dimension
|
||||
|
||||
Returns:
|
||||
Padded tensor
|
||||
"""
|
||||
if pad_h == 0 and pad_w == 0:
|
||||
return x
|
||||
|
||||
# Height padding (axis 2)
|
||||
if pad_h > 0:
|
||||
# Get reflection indices - exclude boundary
|
||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||
|
||||
# Width padding (axis 3)
|
||||
if pad_w > 0:
|
||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def make_conv_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, ...]],
|
||||
stride: Union[int, Tuple[int, ...]] = 1,
|
||||
padding: Union[int, Tuple[int, ...], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
) -> nn.Module:
|
||||
|
||||
if dims == 2:
|
||||
return CausalConv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == 3:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported number of dimensions: {dims}")
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride to tuples
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
# Calculate spatial padding (temporal is handled separately via frame replication)
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
self.spatial_padding = (height_pad, width_pad)
|
||||
|
||||
# Create the base convolution (without padding, we'll handle it manually)
|
||||
self.conv = nn.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0, # We handle padding manually
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
|
||||
use_causal = causal if causal is not None else self.causal
|
||||
|
||||
# Apply temporal padding via frame replication
|
||||
# Only apply if kernel_size > 1
|
||||
if self.time_kernel_size > 1:
|
||||
if use_causal:
|
||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||
else:
|
||||
# Non-causal: replicate first frame at start, last frame at end
|
||||
pad_size = (self.time_kernel_size - 1) // 2
|
||||
if pad_size > 0:
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
|
||||
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
|
||||
|
||||
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 4, 1))
|
||||
|
||||
# Apply spatial padding
|
||||
pad_h, pad_w = self.spatial_padding
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
if self.spatial_padding_mode == PaddingModeType.REFLECT:
|
||||
# Use reflect padding for spatial dimensions
|
||||
x = reflect_pad_2d(x, pad_h, pad_w)
|
||||
else:
|
||||
# Use zero padding for spatial dimensions
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(0, 0), # D (temporal - already padded)
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
# Apply convolution with chunking for large tensors
|
||||
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
|
||||
x = self._chunked_conv3d(x)
|
||||
|
||||
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
|
||||
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv3d(self, x: mx.array) -> mx.array:
|
||||
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) in channels-last format
|
||||
|
||||
Returns:
|
||||
Output tensor after conv3d
|
||||
"""
|
||||
b, d, h, w, c = x.shape
|
||||
|
||||
|
||||
total_elements = d * h * w * c
|
||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||
|
||||
if total_elements <= max_safe_elements:
|
||||
return self.conv(x)
|
||||
|
||||
elements_per_frame = h * w * c
|
||||
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
|
||||
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
|
||||
|
||||
kernel_t = self.time_kernel_size
|
||||
|
||||
overlap = kernel_t - 1
|
||||
|
||||
|
||||
expected_output_frames = d - overlap
|
||||
|
||||
outputs = []
|
||||
out_idx = 0
|
||||
|
||||
# Process chunks
|
||||
in_start = 0
|
||||
while out_idx < expected_output_frames:
|
||||
remaining = expected_output_frames - out_idx
|
||||
out_frames_this_chunk = min(chunk_size, remaining)
|
||||
|
||||
in_frames_needed = out_frames_this_chunk + overlap
|
||||
in_end = min(in_start + in_frames_needed, d)
|
||||
|
||||
chunk = x[:, in_start:in_end, :, :, :]
|
||||
|
||||
chunk_out = self.conv(chunk)
|
||||
mx.eval(chunk_out)
|
||||
|
||||
outputs.append(chunk_out)
|
||||
|
||||
out_idx += chunk_out.shape[1]
|
||||
in_start += chunk_out.shape[1]
|
||||
|
||||
# Concatenate all chunks
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return mx.concatenate(outputs, axis=1)
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""2D convolution with optional causal padding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
"""Initialize CausalConv2d."""
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
|
||||
# Calculate padding
|
||||
if isinstance(padding, str) and padding == "same":
|
||||
self.padding = (
|
||||
(kernel_size[0] - 1) // 2,
|
||||
(kernel_size[1] - 1) // 2,
|
||||
)
|
||||
elif isinstance(padding, int):
|
||||
self.padding = (padding, padding)
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
"""Forward pass."""
|
||||
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 1))
|
||||
|
||||
# Apply padding
|
||||
pad_h, pad_w = self.padding
|
||||
if pad_h != 0 or pad_w != 0:
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
x = self.conv(x)
|
||||
|
||||
# Transpose back: (B, H, W, C) -> (B, C, H, W)
|
||||
x = mx.transpose(x, (0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
Reference in New Issue
Block a user