This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,8 +1,8 @@
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
TilingConfig,
)
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
# Height padding (axis 2)
if pad_h > 0:
# Get reflection indices - exclude boundary
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
:, :, ::-1, :, :
] # Flip bottom portion
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
# Width padding (axis 3)
if pad_w > 0:
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w - 1 : -1, :][
:, :, :, ::-1, :
] # Flip right portion
x = mx.concatenate([left_pad, x, right_pad], axis=3)
return x
@@ -50,7 +54,7 @@ def make_conv_nd(
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
if dims == 2:
return CausalConv2d(
in_channels=in_channels,
@@ -118,15 +122,17 @@ class CausalConv3d(nn.Module):
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
use_causal = causal if causal is not None else self.causal
# Apply temporal padding via frame replication
# Apply temporal padding via frame replication
# Only apply if kernel_size > 1
if self.time_kernel_size > 1:
if use_causal:
# Causal: replicate first frame kernel_size-1 times at the beginning
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
first_frame_pad = mx.repeat(
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
)
x = mx.concatenate([first_frame_pad, x], axis=2)
else:
# Non-causal: replicate first frame at start, last frame at end
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
"""
b, d, h, w, c = x.shape
total_elements = d * h * w * c
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
@@ -191,11 +196,10 @@ class CausalConv3d(nn.Module):
overlap = kernel_t - 1
expected_output_frames = d - overlap
outputs = []
out_idx = 0
out_idx = 0
# Process chunks
in_start = 0

View File

@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
"""
import math
from typing import Optional, Dict
from pathlib import Path
from typing import Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim
in_channels=256, time_embed_dim=embedding_dim
)
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
def __call__(
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
) -> mx.array:
timesteps_proj = get_timestep_embedding(
timestep,
embedding_dim=256,
flip_sin_to_cos=True,
downscale_freq_shift=0
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
)
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
return timesteps_emb
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
padding=1,
spatial_padding_mode=padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
ada_values = self.scale_shift_table[
None, :, :, None, None, None
] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
# Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=channels * 4
)
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
# Use dict with int keys for MLX to track parameters properly
self.res_blocks = {
i: ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
timestep_conditioning=timestep_conditioning,
)
for i in range(num_layers)
}
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0]
timestep_embed = self.time_embedder(
timestep.flatten(),
hidden_dtype=x.dtype
timestep.flatten(), hidden_dtype=x.dtype
)
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Build up blocks from config
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
block_type = block_def[0]
ch = block_def[1]
if block_type == "res":
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
num_layers = (
block_def[2] if len(block_def) > 2 else num_layers_per_block
)
self.up_blocks[idx] = ResBlockGroup(
ch, num_layers, spatial_padding_mode, timestep_conditioning
)
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
)
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper()
self.act = nn.SiLU()
@@ -358,7 +367,7 @@ class LTX2VideoDecoder(nn.Module):
return weights
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
if (
".conv.conv.weight" not in new_key
and ".conv.conv.bias" not in new_key
):
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
def from_pretrained(
cls, model_path: Path, strict: bool = True
) -> "LTX2VideoDecoder":
"""Load a pretrained decoder from a directory with config.json and weights.
Args:
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
return final_blocks
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -551,20 +561,15 @@ class LTX2VideoDecoder(nn.Module):
debug: bool = False,
chunked_conv: bool = False,
) -> mx.array:
batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
@@ -574,7 +579,6 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
@@ -583,19 +587,18 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
x = self.pixel_norm(x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(),
hidden_dtype=x.dtype
scaled_timestep.flatten(), hidden_dtype=x.dtype
)
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
ada_values = self.last_scale_shift_table[
None, :, :, None, None, None
] # (1, 2, 128, 1, 1, 1)
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
ada_values = ada_values + ts_reshaped
@@ -603,16 +606,13 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
x = self.act(x)
x = self.conv_out(x, causal=causal)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
return x
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
# Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
use_chunked_conv = tiling_mode in (
"conservative",
"none",
"auto",
"default",
"spatial",
)
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
return self(
sample,
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
)
return decode_with_tiling(
decoder_fn=self,

View File

@@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation.
"""
import mlx.core as mx
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
def encode_image(

View File

@@ -1,6 +1,5 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
x = mx.reshape(
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
)
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
@@ -101,7 +102,7 @@ class PerChannelStatistics(nn.Module):
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
dtype = x.dtype
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
@@ -117,7 +118,7 @@ class PerChannelStatistics(nn.Module):
Returns:
Denormalized tensor
"""
dtype = x.dtype
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)

View File

@@ -44,7 +44,7 @@ class ResnetBlock3D(nn.Module):
timestep_conditioning: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
out_channels = out_channels or in_channels
@@ -96,7 +96,7 @@ class ResnetBlock3D(nn.Module):
causal: bool = True,
generator: Optional[int] = None,
) -> mx.array:
residual = x
# First block
@@ -136,7 +136,7 @@ class UNetMidBlock3D(nn.Module):
attention_head_dim: Optional[int] = None,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.num_layers = num_layers

View File

@@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module):
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims: int,
@@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module):
out_channels_reduction_factor: int = 1,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
def __call__(
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
def _chunked_conv_depth_to_space(
self, x: mx.array, causal: bool = True
) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor.

View File

@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
# Apply right ramp (fade out)
if ramp_right > 0:
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
fade_out = [
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
]
for i in range(ramp_right):
mask[length - ramp_right + i] *= fade_out[i]
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_pixels < 64:
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
)
if self.tile_size_in_pixels % 32 != 0:
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
)
if self.tile_overlap_in_pixels % 32 != 0:
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
raise ValueError(
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
)
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_frames < 16:
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
)
if self.tile_size_in_frames % 8 != 0:
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
)
if self.tile_overlap_in_frames % 8 != 0:
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
raise ValueError(
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
)
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
@@ -113,15 +127,21 @@ class TilingConfig:
def default(cls) -> "TilingConfig":
"""Default tiling: 512px spatial, 64 frame temporal."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
),
)
@classmethod
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
"""Spatial tiling only (for short videos with large resolution)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
),
temporal_config=None,
)
@@ -130,23 +150,33 @@ class TilingConfig:
"""Temporal tiling only (for long videos with small resolution)."""
return cls(
spatial_config=None,
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
),
)
@classmethod
def aggressive(cls) -> "TilingConfig":
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=256, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=32, tile_overlap_in_frames=8
),
)
@classmethod
def conservative(cls) -> "TilingConfig":
"""Conservative tiling (larger tiles, less memory savings but faster)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=768, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=96, tile_overlap_in_frames=24
),
)
@classmethod
@@ -186,10 +216,14 @@ class TilingConfig:
temporal_config = None
if needs_spatial:
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
spatial_config = SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
)
if needs_temporal:
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
temporal_config = TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
@@ -197,16 +231,21 @@ class TilingConfig:
@dataclass
class DimensionIntervals:
"""Intervals for splitting a single dimension."""
starts: List[int]
ends: List[int]
left_ramps: List[int]
right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_spatial(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a spatial dimension into intervals."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
starts = [i * (size - overlap) for i in range(amount)]
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
left_ramps = [0] + [overlap] * (amount - 1)
right_ramps = [overlap] * (amount - 1) + [0]
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
return DimensionIntervals(
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_temporal(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a temporal dimension into intervals with causal adjustment."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
# Start with spatial split
intervals = split_in_spatial(size, overlap, dimension_size)
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
starts[i] = starts[i] - 1
left_ramps[i] = left_ramps[i] + 1
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
return DimensionIntervals(
starts=starts,
ends=intervals.ends,
left_ramps=left_ramps,
right_ramps=intervals.right_ramps,
)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_temporal_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map temporal latent interval to output coordinates and mask."""
start = begin * scale
stop = 1 + (end - 1) * scale
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, True
)
return slice(start, stop), mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_spatial_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map spatial latent interval to output coordinates and mask."""
start = begin * scale
stop = end * scale
left_ramp_scaled = left_ramp * scale
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, False
)
return slice(start, stop), mask
@@ -315,7 +373,9 @@ def decode_with_tiling(
temporal_overlap = 0
# Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -338,7 +398,9 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -347,7 +409,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
@@ -356,13 +420,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -385,13 +459,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
t_mask_slice.reshape(1, 1, -1, 1, 1)
* h_mask_slice.reshape(1, 1, 1, -1, 1)
* w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output
del tile_output
@@ -409,11 +485,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
)
# Force evaluation to free memory
@@ -445,10 +547,12 @@ def decode_with_tiling(
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
@@ -456,7 +560,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
@@ -473,7 +580,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
@@ -481,7 +588,7 @@ def decode_with_tiling(
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames
# Clean up weights

View File

@@ -8,12 +8,15 @@ import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx_2.video_vae.ops import (
PerChannelStatistics,
patchify,
unpatchify,
)
from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample,
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
"""Log variance mode for VAE."""
PER_CHANNEL = "per_channel"
UNIFORM = "uniform"
CONSTANT = "constant"
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
config: VideoEncoderModelConfig with encoder parameters
"""
super().__init__()
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
self.patch_size = config.patch_size
self.norm_layer = config.norm_layer
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
self.per_channel_statistics = PerChannelStatistics(
latent_channels=config.out_channels
)
# After patchify, channels increase by patch_size^2
in_channels = config.in_channels * config.patch_size ** 2
in_channels = config.in_channels * config.patch_size**2
feature_channels = config.out_channels
# Initial convolution
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {}
for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_encoder_block(
block_name=block_name,
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
conv_out_channels = config.out_channels
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
elif config.latent_log_var in {
LogVarianceType.UNIFORM,
LogVarianceType.CONSTANT,
}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...]
approx_ln_0 = -30
sample = mx.concatenate([
sample,
mx.full_like(sample, approx_ln_0),
], axis=1)
sample = mx.concatenate(
[
sample,
mx.full_like(sample, approx_ln_0),
],
axis=1,
)
# Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...]
means = sample[:, : self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
Loaded VideoEncoder instance
"""
import json
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
decoder_blocks = []
self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2
out_channels = out_channels * patch_size**2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_decoder_block(
block_name=block_name,