Add VAE tiling
This commit is contained in:
@@ -23,6 +23,7 @@ import mlx.nn as nn
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
@@ -444,6 +445,75 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def decode_tiled(
    self,
    sample: mx.array,
    tiling_config: Optional[TilingConfig] = None,
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    debug: bool = False,
) -> mx.array:
    """Decode latents tile by tile to reduce peak memory usage.

    Intended for large videos that would otherwise run out of memory when
    decoded in one pass. When the latents already fit inside a single tile
    along every configured axis, this falls back to the regular decode path
    (calling the decoder directly) and no tiling is performed.

    Args:
        sample: Input latents of shape (B, C, F, H, W).
        tiling_config: Tiling configuration. If None, uses
            TilingConfig.default().
        causal: Whether to use causal convolutions.
        timestep: Optional timestep for conditioning.
        debug: Whether to print debug info.

    Returns:
        Decoded video. With the scale factors used here the output is
        approximately (B, 3, F*8, H*32, W*32); the exact frame count
        depends on the decoder -- TODO confirm.
    """
    if tiling_config is None:
        tiling_config = TilingConfig.default()

    # Latent-to-pixel scale factors. Defined once so the "is tiling needed"
    # check below and the tiled decode call cannot drift apart.
    # Spatial scale is 32 (8x VAE upsample + 4x unpatchify); temporal is 8.
    spatial_scale = 32
    temporal_scale = 8

    _, _, f, h, w = sample.shape

    # Spatial tiling is needed when either latent dimension exceeds the
    # tile size expressed in latent units.
    needs_spatial_tiling = False
    s_cfg = tiling_config.spatial_config
    if s_cfg is not None:
        spatial_tile_latent = s_cfg.tile_size_in_pixels // spatial_scale
        needs_spatial_tiling = h > spatial_tile_latent or w > spatial_tile_latent

    # Temporal tiling is needed when the latent frame count exceeds the
    # tile length expressed in latent frames.
    needs_temporal_tiling = False
    t_cfg = tiling_config.temporal_config
    if t_cfg is not None:
        temporal_tile_latent = t_cfg.tile_size_in_frames // temporal_scale
        needs_temporal_tiling = f > temporal_tile_latent

    if not (needs_spatial_tiling or needs_temporal_tiling):
        # No tiling needed, use regular decode.
        if debug:
            print("[Tiling] Input fits within tile size, using regular decode")
        return self(sample, causal=causal, timestep=timestep, debug=debug)

    if debug:
        print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})")

    return decode_with_tiling(
        decoder_fn=self,
        latents=sample,
        tiling_config=tiling_config,
        spatial_scale=spatial_scale,
        temporal_scale=temporal_scale,
        causal=causal,
        timestep=timestep,
        debug=debug,
    )
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
|
||||
Reference in New Issue
Block a user