add vae tiling

This commit is contained in:
Prince Canuma
2026-01-17 07:51:54 +01:00
parent f607112407
commit e4cdbb7eab
6 changed files with 632 additions and 5 deletions

View File

@@ -23,6 +23,7 @@ import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
def get_timestep_embedding(
@@ -444,6 +445,75 @@ class LTX2VideoDecoder(nn.Module):
return x
def decode_tiled(
self,
sample: mx.array,
tiling_config: Optional[TilingConfig] = None,
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
This method is useful for decoding large videos that would otherwise
cause out-of-memory errors. It divides the latents into tiles,
decodes each tile separately, and blends them together.
Args:
sample: Input latents of shape (B, C, F, H, W).
tiling_config: Tiling configuration. If None, uses TilingConfig.default().
causal: Whether to use causal convolutions.
timestep: Optional timestep for conditioning.
debug: Whether to print debug info.
Returns:
Decoded video of shape (B, 3, F*8, H*8, W*8).
"""
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
_, _, f, h, w = sample.shape
needs_spatial_tiling = False
needs_temporal_tiling = False
# Spatial scale is 32 (8x VAE upsample + 4x unpatchify)
# Temporal scale is 8
spatial_scale = 32
temporal_scale = 8
if tiling_config.spatial_config is not None:
s_cfg = tiling_config.spatial_config
tile_size_latent = s_cfg.tile_size_in_pixels // spatial_scale
if h > tile_size_latent or w > tile_size_latent:
needs_spatial_tiling = True
if tiling_config.temporal_config is not None:
t_cfg = tiling_config.temporal_config
tile_size_latent = t_cfg.tile_size_in_frames // temporal_scale
if f > tile_size_latent:
needs_temporal_tiling = True
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
if debug:
print("[Tiling] Input fits within tile size, using regular decode")
return self(sample, causal=causal, timestep=timestep, debug=debug)
if debug:
print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")
return decode_with_tiling(
decoder_fn=self,
latents=sample,
tiling_config=tiling_config,
spatial_scale=32, # VAE spatial: 8x upsampling + 4x unpatchify = 32x
temporal_scale=8, # VAE temporal upsampling factor
causal=causal,
timestep=timestep,
debug=debug,
)
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path