Files
mlx-video/mlx_video/models/ltx/video_vae/tiling.py
Prince Canuma e4cdbb7eab add vae tiling
2026-01-17 07:51:54 +01:00

471 lines
18 KiB
Python

"""VAE Tiling Configuration for decoding large videos.
Implements spatial and temporal tiling with trapezoidal blending masks
to decode large videos without running out of memory.
Default configuration (from PyTorch):
- Spatial: 512px tiles with 64px overlap
- Temporal: 64 frames with 24 frame overlap
"""
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple
import mlx.core as mx
def compute_trapezoidal_mask_1d(
    length: int,
    ramp_left: int,
    ramp_right: int,
    left_starts_from_0: bool = False,
) -> mx.array:
    """Build a 1D blending weight profile: linear fade-in, flat top, linear fade-out.

    Args:
        length: Number of samples in the mask; must be positive.
        ramp_left: Number of leading samples that fade in (clamped to [0, length]).
        ramp_right: Number of trailing samples that fade out (clamped to [0, length]).
        left_starts_from_0: If True, the fade-in starts at exactly 0
            (``0, 1/r, ..., (r-1)/r`` for ``r = ramp_left``). Otherwise the 0
            endpoint is skipped (``1/(r+1), ..., r/(r+1)``). Useful for temporal
            tiles, where the first tile is causal.

    Returns:
        An ``mx.array`` of shape ``(length,)`` with values clipped to [0, 1].

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    left = min(max(ramp_left, 0), length)
    right = min(max(ramp_right, 0), length)

    weights = [1.0 for _ in range(length)]

    if left > 0:
        # Same values as linspace(0, 1, steps)[:-1], optionally without the leading 0.
        steps = left + 1 if left_starts_from_0 else left + 2
        denom = steps - 1
        skip = 0 if left_starts_from_0 else 1
        ramp_up = [k / denom for k in range(skip, steps - 1)]
        for pos, factor in zip(range(left), ramp_up):
            weights[pos] *= factor

    if right > 0:
        # Same values as linspace(1, 0, right + 2)[1:-1], laid over the last `right` samples.
        span = right + 1
        tail = length - right
        for k in range(1, right + 1):
            weights[tail + k - 1] *= (span - k) / span

    return mx.clip(mx.array(weights), 0, 1)
@dataclass(frozen=True)
class SpatialTilingConfig:
    """Configuration for dividing each frame into spatial tiles with optional overlap.

    Attributes:
        tile_size_in_pixels: Tile size in output pixels; must be >= 64 and a
            multiple of 32.
        tile_overlap_in_pixels: Pixels shared by neighbouring tiles; must be a
            non-negative multiple of 32 and smaller than the tile size.
    """

    tile_size_in_pixels: int
    tile_overlap_in_pixels: int = 0

    def __post_init__(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If any of the constraints listed on the class is violated.
        """
        if self.tile_size_in_pixels < 64:
            raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
        if self.tile_size_in_pixels % 32 != 0:
            raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
        if self.tile_overlap_in_pixels % 32 != 0:
            raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
        # A negative overlap (e.g. -32) passes the divisibility check above but
        # would make tiles step by more than their size, leaving uncovered gaps.
        if self.tile_overlap_in_pixels < 0:
            raise ValueError(f"tile_overlap_in_pixels must be non-negative, got {self.tile_overlap_in_pixels}")
        if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
            )
@dataclass(frozen=True)
class TemporalTilingConfig:
    """Configuration for dividing a video into temporal tiles.

    Attributes:
        tile_size_in_frames: Frames per tile; must be >= 16 and a multiple of 8.
        tile_overlap_in_frames: Frames shared by consecutive tiles; must be a
            non-negative multiple of 8 smaller than ``tile_size_in_frames``.
    """

    tile_size_in_frames: int
    tile_overlap_in_frames: int = 0

    def __post_init__(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: If any of the constraints listed on the class is violated.
        """
        if self.tile_size_in_frames < 16:
            raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
        if self.tile_size_in_frames % 8 != 0:
            raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
        if self.tile_overlap_in_frames % 8 != 0:
            raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
        # A negative overlap (e.g. -8) passes the divisibility check above but
        # would make tiles step by more than their size, leaving uncovered gaps.
        if self.tile_overlap_in_frames < 0:
            raise ValueError(f"tile_overlap_in_frames must be non-negative, got {self.tile_overlap_in_frames}")
        if self.tile_overlap_in_frames >= self.tile_size_in_frames:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
            )
@dataclass(frozen=True)
class TilingConfig:
    """Configuration for splitting video into tiles with optional overlap.

    Attributes:
        spatial_config: Height/width tiling, or ``None`` to keep frames whole.
        temporal_config: Frame-axis tiling, or ``None`` to keep time whole.

    The classmethods are ready-made presets; ``auto`` picks a configuration
    from the video's dimensions.
    """

    spatial_config: Optional[SpatialTilingConfig] = None
    temporal_config: Optional[TemporalTilingConfig] = None

    @classmethod
    def default(cls) -> "TilingConfig":
        """Default tiling: 512px spatial, 64 frame temporal."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
        """Tile height/width only; suited to short, high-resolution videos."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap)
        return cls(spatial_config=spatial, temporal_config=None)

    @classmethod
    def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig":
        """Tile the frame axis only; suited to long, low-resolution videos."""
        temporal = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
        return cls(spatial_config=None, temporal_config=temporal)

    @classmethod
    def aggressive(cls) -> "TilingConfig":
        """Small tiles (256px / 64px overlap, 32 frames / 8 overlap) for the lowest memory use."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def conservative(cls) -> "TilingConfig":
        """Large tiles (768px / 64px overlap, 96 frames / 24 overlap): fewer tiles, less memory saved."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def auto(
        cls,
        height: int,
        width: int,
        num_frames: int,
        spatial_threshold: int = 512,
        temporal_threshold: int = 65,
    ) -> Optional["TilingConfig"]:
        """Pick a tiling configuration from the video's dimensions.

        Args:
            height: Video height in pixels.
            width: Video width in pixels.
            num_frames: Number of video frames.
            spatial_threshold: Spatial tiling is enabled when either side exceeds this.
            temporal_threshold: Temporal tiling is enabled when the frame count exceeds this.

        Returns:
            ``None`` when neither threshold is exceeded. Otherwise the
            ``aggressive`` preset for very large outputs, or a config with
            tile sizes chosen from the resolution and frame count.
        """
        needs_spatial = max(height, width) > spatial_threshold
        needs_temporal = num_frames > temporal_threshold
        if not (needs_spatial or needs_temporal):
            return None

        # Rough size of the decoded float32 RGB video, in GiB.
        estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
        very_large = estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100)
        if very_large:
            return cls.aggressive()

        spatial_config = None
        if needs_spatial:
            # 512px tiles only for a longest side in (768, 1024]; 384px otherwise.
            tile_px = 512 if 768 < max(height, width) <= 1024 else 384
            spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_px, tile_overlap_in_pixels=64)

        temporal_config = None
        if needs_temporal:
            # Longer videos get shorter temporal tiles.
            if num_frames > 200:
                frames, frame_overlap = 32, 8
            elif num_frames > 100:
                frames, frame_overlap = 48, 16
            else:
                frames, frame_overlap = 64, 24
            temporal_config = TemporalTilingConfig(tile_size_in_frames=frames, tile_overlap_in_frames=frame_overlap)

        return cls(spatial_config=spatial_config, temporal_config=temporal_config)
@dataclass
class DimensionIntervals:
    """Tile layout along one axis, stored as parallel per-tile lists.

    Entry ``i`` of every list describes tile ``i``: it covers the half-open
    range ``[starts[i], ends[i])``. Its leading and trailing blend ramps are
    ``left_ramps[i]`` and ``right_ramps[i]`` samples long.
    """

    starts: List[int]  # inclusive start index of each tile
    ends: List[int]  # exclusive end index of each tile
    left_ramps: List[int]  # fade-in length at each tile's leading edge
    right_ramps: List[int]  # fade-out length at each tile's trailing edge
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
    """Split an axis of length ``dimension_size`` into tiles of ``size`` that overlap by ``overlap``.

    Consecutive tiles advance by ``size - overlap`` (assumed positive). Every
    tile is ``size`` long except the last, which is cut short so it ends
    exactly at ``dimension_size``. Interior tile edges get an ``overlap``-long
    ramp; the outer edges of the first and last tile get none.

    Returns a single un-ramped interval when the axis fits in one tile.
    """
    if dimension_size <= size:
        return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])

    stride = size - overlap
    # Integer form of ceil((dimension_size - overlap) / stride).
    count = (dimension_size + size - 2 * overlap - 1) // stride

    starts = [index * stride for index in range(count)]
    ends = [begin + size for begin in starts]
    ends[-1] = dimension_size

    inner = count - 1
    left_ramps = [0, *([overlap] * inner)]
    right_ramps = [*([overlap] * inner), 0]

    return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
    """Split the latent frame axis into overlapping tiles, adjusted for causal decoding.

    The layout from ``split_in_spatial`` is reused, with one change for every
    tile after the first: it starts one latent frame earlier and its fade-in
    ramp is one frame longer. ``map_temporal_slice`` maps ``n`` latent frames
    to ``1 + (n - 1) * scale`` output frames. The extra leading frame therefore
    contributes a single output frame that falls inside the blend region.
    """
    if dimension_size <= size:
        return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])

    base = split_in_spatial(size, overlap, dimension_size)
    shifted_starts = base.starts[:1] + [begin - 1 for begin in base.starts[1:]]
    widened_left_ramps = base.left_ramps[:1] + [ramp + 1 for ramp in base.left_ramps[1:]]
    return replace(base, starts=shifted_starts, left_ramps=widened_left_ramps)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Translate a latent frame interval into output frame coordinates and a blend mask.

    ``n = end - begin`` latent frames cover ``1 + (n - 1) * scale`` output
    frames, starting at ``begin * scale``. A positive latent fade-in of ``r``
    frames becomes ``1 + (r - 1) * scale`` output frames and starts from 0.
    The fade-out is scaled by ``scale``.

    Returns:
        ``(slice(start, stop), mask)`` where ``mask`` has ``stop - start`` entries.
    """
    first_frame = begin * scale
    stop_frame = 1 + (end - 1) * scale
    fade_in = 0 if left_ramp <= 0 else 1 + (left_ramp - 1) * scale
    fade_out = right_ramp * scale
    weights = compute_trapezoidal_mask_1d(stop_frame - first_frame, fade_in, fade_out, left_starts_from_0=True)
    return slice(first_frame, stop_frame), weights
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Translate a latent spatial interval into output pixel coordinates and a blend mask.

    Interval bounds and both ramp lengths are multiplied by ``scale``. The
    fade-in skips its 0 endpoint.

    Returns:
        ``(slice(begin * scale, end * scale), mask)`` with one mask entry per output pixel.
    """
    pixel_range = slice(begin * scale, end * scale)
    weights = compute_trapezoidal_mask_1d(
        pixel_range.stop - pixel_range.start,
        left_ramp * scale,
        right_ramp * scale,
        left_starts_from_0=False,
    )
    return pixel_range, weights
def _clear_mlx_cache() -> None:
    """Best-effort release of MLX's cached buffers between tiles.

    Newer MLX releases expose ``mx.clear_cache``; older ones only provide
    ``mx.metal.clear_cache``. If neither exists this is a no-op. Cache clearing
    is purely a memory optimisation, so a runtime failure is ignored instead
    of aborting the decode.
    """
    clear_cache = getattr(mx, "clear_cache", None)
    if clear_cache is None:
        clear_cache = getattr(getattr(mx, "metal", None), "clear_cache", None)
    if clear_cache is None:
        return
    try:
        clear_cache()
    except RuntimeError:
        pass


def _latent_tile_params(tile_size: int, tile_overlap: int, scale: int) -> Tuple[int, int]:
    """Convert a tile size/overlap from output units (pixels or frames) to latent units.

    The tile is clamped to at least one latent element and the overlap to
    ``[0, tile - 1]``. ``split_in_spatial`` divides by ``tile - overlap``, so a
    non-default ``scale`` that rounds both to the same value would otherwise
    raise ``ZeroDivisionError``. For validated configs with a non-negative
    overlap and the default scales (32 spatial, 8 temporal), the clamps leave
    the values unchanged.
    """
    latent_tile = max(1, tile_size // scale)
    latent_overlap = max(0, min(tile_overlap // scale, latent_tile - 1))
    return latent_tile, latent_overlap


def _axis_tiles(intervals: "DimensionIntervals", mapper, scale: int) -> List[Tuple[int, int, slice, mx.array]]:
    """Pair each latent interval of one axis with its output slice and 1D blend mask.

    Returns one ``(latent_begin, latent_end, output_slice, mask)`` tuple per
    interval, in the order the intervals were produced.
    """
    tiles = []
    for begin, end, left, right in zip(
        intervals.starts, intervals.ends, intervals.left_ramps, intervals.right_ramps
    ):
        out_slice, mask = mapper(begin, end, left, right, scale)
        tiles.append((begin, end, out_slice, mask))
    return tiles


def decode_with_tiling(
    decoder_fn,
    latents: mx.array,
    tiling_config: Optional[TilingConfig],
    spatial_scale: int = 32,
    temporal_scale: int = 8,
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    debug: bool = False,
) -> mx.array:
    """Decode latents tile by tile and blend the results to reduce peak memory usage.

    The latent volume is split into overlapping tiles along time, height and
    width. Each tile is decoded on its own and weighted by a separable
    trapezoidal mask. The weighted tiles are summed into a float32 buffer,
    which is finally divided by the summed weights.

    Args:
        decoder_fn: Called as ``decoder_fn(tile, causal=..., timestep=..., debug=False)``
            for each latent tile. Expected to return a 5D ``(B, 3, T, H, W)``
            array, since it is added into a 3-channel output buffer.
        latents: Input latents of shape (B, C, F, H, W).
        tiling_config: Tiling configuration. ``None`` (what ``TilingConfig.auto``
            returns when no tiling is needed) decodes everything as one tile.
        spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
        temporal_scale: Temporal scale factor (8 for LTX VAE).
        causal: Forwarded unchanged to ``decoder_fn``.
        timestep: Optional timestep, forwarded unchanged to ``decoder_fn``.
        debug: Whether to print per-tile progress.

    Returns:
        Decoded video of shape ``(B, 3, 1 + (F - 1) * temporal_scale,
        H * spatial_scale, W * spatial_scale)``, cast to ``latents.dtype``.
    """
    import gc
    import itertools

    b, _, f_latent, h_latent, w_latent = latents.shape

    out_f = 1 + (f_latent - 1) * temporal_scale
    out_h = h_latent * spatial_scale
    out_w = w_latent * spatial_scale

    spatial_cfg = tiling_config.spatial_config if tiling_config is not None else None
    temporal_cfg = tiling_config.temporal_config if tiling_config is not None else None

    # Tile geometry in latent units. A disabled axis group gets a tile spanning
    # the whole axis, so it splits into exactly one interval.
    if spatial_cfg is not None:
        spatial_tile_size, spatial_overlap = _latent_tile_params(
            spatial_cfg.tile_size_in_pixels, spatial_cfg.tile_overlap_in_pixels, spatial_scale
        )
    else:
        spatial_tile_size, spatial_overlap = max(h_latent, w_latent), 0

    if temporal_cfg is not None:
        temporal_tile_size, temporal_overlap = _latent_tile_params(
            temporal_cfg.tile_size_in_frames, temporal_cfg.tile_overlap_in_frames, temporal_scale
        )
    else:
        temporal_tile_size, temporal_overlap = f_latent, 0

    # Masks depend only on the per-axis interval, so compute them once per axis.
    t_tiles = _axis_tiles(
        split_in_temporal(temporal_tile_size, temporal_overlap, f_latent), map_temporal_slice, temporal_scale
    )
    h_tiles = _axis_tiles(
        split_in_spatial(spatial_tile_size, spatial_overlap, h_latent), map_spatial_slice, spatial_scale
    )
    w_tiles = _axis_tiles(
        split_in_spatial(spatial_tile_size, spatial_overlap, w_latent), map_spatial_slice, spatial_scale
    )
    num_t_tiles, num_h_tiles, num_w_tiles = len(t_tiles), len(h_tiles), len(w_tiles)
    total_tiles = num_t_tiles * num_h_tiles * num_w_tiles

    if debug:
        print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})")
        print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}")

    # float32 accumulators: sum of mask-weighted tiles and sum of masks.
    output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
    weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
    mx.eval(output, weights)

    # product() keeps the temporal axis outermost and the width axis innermost.
    tile_grid = itertools.product(t_tiles, h_tiles, w_tiles)
    for tile_idx, (t_tile, h_tile, w_tile) in enumerate(tile_grid):
        t_start, t_end, t_slice, t_mask = t_tile
        h_start, h_end, h_slice, h_mask = h_tile
        w_start, w_end, w_slice, w_mask = w_tile

        if debug:
            print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: "
                  f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})")

        tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
        tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False)
        mx.eval(tile_output)
        del tile_latents

        # If the decoder's output size differs from what the slice arithmetic
        # predicts, blend only the region both agree on.
        _, _, decoded_t, decoded_h, decoded_w = tile_output.shape
        n_t = min(decoded_t, t_slice.stop - t_slice.start)
        n_h = min(decoded_h, h_slice.stop - h_slice.start)
        n_w = min(decoded_w, w_slice.stop - w_slice.start)

        # Separable 3D blend mask, broadcast over the batch and channel axes.
        blend_mask = (
            t_mask[:n_t].reshape(1, 1, -1, 1, 1)
            * h_mask[:n_h].reshape(1, 1, 1, -1, 1)
            * w_mask[:n_w].reshape(1, 1, 1, 1, -1)
        )
        tile_pixels = tile_output[:, :, :n_t, :n_h, :n_w].astype(mx.float32)
        del tile_output

        region = (
            slice(None),
            slice(None),
            slice(t_slice.start, t_slice.start + n_t),
            slice(h_slice.start, h_slice.start + n_h),
            slice(w_slice.start, w_slice.start + n_w),
        )
        output[region] = output[region] + tile_pixels * blend_mask
        weights[region] = weights[region] + blend_mask
        # Evaluate now so this tile's intermediates can be released.
        mx.eval(output, weights)
        del tile_pixels, blend_mask

        # Periodic garbage collection and cache clearing.
        if (tile_idx + 1) % 4 == 0:
            gc.collect()
            _clear_mlx_cache()

    # Normalise by the accumulated weights. The floor avoids dividing by zero
    # wherever no tile contributed.
    output = output / mx.maximum(weights, 1e-8)
    mx.eval(output)
    del weights
    gc.collect()

    if debug:
        print(f"[Tiling] Done. Final shape: {output.shape}")

    return output.astype(latents.dtype)