diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 8a6c5d5..55f0965 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -27,6 +27,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder +from mlx_video.models.ltx.video_vae.tiling import TilingConfig from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state @@ -207,6 +208,7 @@ def generate_video( image: Optional[str] = None, image_strength: float = 1.0, image_frame_idx: int = 0, + tiling: str = "auto", ): """Generate video from text prompt, optionally conditioned on an image. @@ -228,6 +230,14 @@ def generate_video( image: Path to conditioning image for I2V (Image-to-Video) image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original) image_frame_idx: Frame index to condition (0 = first frame) + tiling: Tiling mode for VAE decoding. Options: + - "auto": Automatically determine based on video size (default) + - "none": Disable tiling + - "default": 512px spatial, 64 frame temporal + - "aggressive": 256px spatial, 32 frame temporal (lowest memory) + - "conservative": 768px spatial, 96 frame temporal (faster) + - "spatial": Spatial tiling only + - "temporal": Temporal tiling only """ start_time = time.time() @@ -435,9 +445,36 @@ def generate_video( del transformer mx.clear_cache() - # Decode to video + # Decode to video with tiling print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") - video = vae_decoder(latents) + + # Select tiling configuration + if tiling == "none": + tiling_config = None + elif tiling == "auto": + tiling_config = TilingConfig.auto(height, width, num_frames) + elif tiling == "default": + tiling_config = TilingConfig.default() + elif tiling == "aggressive": + tiling_config = TilingConfig.aggressive() + elif tiling == "conservative": + tiling_config = TilingConfig.conservative() + elif tiling == "spatial": + tiling_config = TilingConfig.spatial_only() + elif tiling == "temporal": + tiling_config = TilingConfig.temporal_only() + else: + print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") + tiling_config = TilingConfig.auto(height, width, num_frames) + + if tiling_config is not None: + spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" + temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" + print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") + video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose) + else: + print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") + video = vae_decoder(latents) mx.eval(video) mx.clear_cache() @@ -594,6 +631,15 @@ Examples: default=0, help="Frame index to condition for I2V (0 = first frame, default: 0)" ) + parser.add_argument( + "--tiling", + type=str, + default="auto", + choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], + help="Tiling mode for VAE decoding (default: auto). " + "auto=based on video size, none=disabled, default=512px/64f, " + "aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only" + ) args = parser.parse_args() generate_video( diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py index a481c23..b7cba4a 100644 --- a/mlx_video/generate_av.py +++ b/mlx_video/generate_av.py @@ -30,6 +30,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_w from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder +from mlx_video.models.ltx.video_vae.tiling import TilingConfig from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning from mlx_video.conditioning.latent import LatentState, apply_denoise_mask @@ -363,6 +364,7 @@ def generate_video_with_audio( image: Optional[str] = None, image_strength: float = 1.0, image_frame_idx: int = 0, + tiling: str = "auto", ): """Generate video with synchronized audio from text prompt, optionally conditioned on an image. @@ -384,6 +386,7 @@ def generate_video_with_audio( image: Path to conditioning image for I2V image_strength: Conditioning strength (1.0 = full denoise) image_frame_idx: Frame index to condition (0 = first frame) + tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal) """ start_time = time.time() @@ -623,9 +626,36 @@ def generate_video_with_audio( del transformer mx.clear_cache() - # Decode video + # Decode video with tiling print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") - video = vae_decoder(video_latents) + + # Select tiling configuration + if tiling == "none": + tiling_config = None + elif tiling == "auto": + tiling_config = TilingConfig.auto(height, width, num_frames) + elif tiling == "default": + tiling_config = TilingConfig.default() + elif tiling == "aggressive": + tiling_config = TilingConfig.aggressive() + elif tiling == "conservative": + tiling_config = TilingConfig.conservative() + elif tiling == "spatial": + tiling_config = TilingConfig.spatial_only() + elif tiling == "temporal": + tiling_config = TilingConfig.temporal_only() + else: + print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") + tiling_config = TilingConfig.auto(height, width, num_frames) + + if tiling_config is not None: + spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" + temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" + print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") + video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose) + else: + print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") + video = vae_decoder(video_latents) mx.eval(video) # Convert video to uint8 frames @@ -762,6 +792,11 @@ Examples: help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)") parser.add_argument("--image-frame-idx", type=int, default=0, help="Frame index to condition for I2V (0 = first frame, default: 0)") + parser.add_argument("--tiling", type=str, default="auto", + choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], + help="Tiling mode for VAE decoding (default: auto). " + "auto=based on size, none=disabled, default=512px/64f, " + "aggressive=256px/32f (lowest memory), conservative=768px/96f") args = parser.parse_args() @@ -783,6 +818,7 @@ Examples: image=args.image, image_strength=args.image_strength, image_frame_idx=args.image_frame_idx, + tiling=args.tiling, ) diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index bcd7cf4..29993fb 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -918,7 +918,7 @@ class LTX2TextEncoder(nn.Module): if response.token == 1 or response.token == 107: # EOS tokens break - + mx.clear_cache() # Decode only the new tokens diff --git a/mlx_video/models/ltx/video_vae/__init__.py b/mlx_video/models/ltx/video_vae/__init__.py index d620f17..bac1644 100644 --- a/mlx_video/models/ltx/video_vae/__init__.py +++ b/mlx_video/models/ltx/video_vae/__init__.py @@ -1,3 +1,8 @@ from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder +from mlx_video.models.ltx.video_vae.tiling import ( + TilingConfig, + SpatialTilingConfig, + TemporalTilingConfig, +) diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 1bc0983..390a92b 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -23,6 +23,7 @@ import mlx.nn as nn from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx.video_vae.ops import unpatchify from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample +from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling def get_timestep_embedding( @@ -444,6 +445,75 @@ class LTX2VideoDecoder(nn.Module): return x + def decode_tiled( + self, + sample: mx.array, + tiling_config: Optional[TilingConfig] = None, + causal: bool = False, + timestep: Optional[mx.array] = None, + debug: bool = False, + ) -> mx.array: + """Decode latents using tiling to reduce memory usage. + + This method is useful for decoding large videos that would otherwise + cause out-of-memory errors. It divides the latents into tiles, + decodes each tile separately, and blends them together. + + Args: + sample: Input latents of shape (B, C, F, H, W). + tiling_config: Tiling configuration. If None, uses TilingConfig.default(). + causal: Whether to use causal convolutions. + timestep: Optional timestep for conditioning. + debug: Whether to print debug info. + + Returns: + Decoded video of shape (B, 3, F*8, H*8, W*8). + """ + if tiling_config is None: + tiling_config = TilingConfig.default() + + # Check if tiling is actually needed + _, _, f, h, w = sample.shape + needs_spatial_tiling = False + needs_temporal_tiling = False + + # Spatial scale is 32 (8x VAE upsample + 4x unpatchify) + # Temporal scale is 8 + spatial_scale = 32 + temporal_scale = 8 + + if tiling_config.spatial_config is not None: + s_cfg = tiling_config.spatial_config + tile_size_latent = s_cfg.tile_size_in_pixels // spatial_scale + if h > tile_size_latent or w > tile_size_latent: + needs_spatial_tiling = True + + if tiling_config.temporal_config is not None: + t_cfg = tiling_config.temporal_config + tile_size_latent = t_cfg.tile_size_in_frames // temporal_scale + if f > tile_size_latent: + needs_temporal_tiling = True + + if not needs_spatial_tiling and not needs_temporal_tiling: + # No tiling needed, use regular decode + if debug: + print("[Tiling] Input fits within tile size, using regular decode") + return self(sample, causal=causal, timestep=timestep, debug=debug) + + if debug: + print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})") + + return decode_with_tiling( + decoder_fn=self, + latents=sample, + tiling_config=tiling_config, + spatial_scale=32, # VAE spatial: 8x upsampling + 4x unpatchify = 32x + temporal_scale=8, # VAE temporal upsampling factor + causal=causal, + timestep=timestep, + debug=debug, + ) + def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder: from pathlib import Path diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py new file mode 100644 index 0000000..20950fc --- /dev/null +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -0,0 +1,470 @@ +"""VAE Tiling Configuration for decoding large videos. + +Implements spatial and temporal tiling with trapezoidal blending masks +to decode large videos without running out of memory. + +Default configuration (from PyTorch): +- Spatial: 512px tiles with 64px overlap +- Temporal: 64 frames with 24 frame overlap +""" + +from dataclasses import dataclass, replace +from typing import List, Optional, Tuple + +import mlx.core as mx + + +def compute_trapezoidal_mask_1d( + length: int, + ramp_left: int, + ramp_right: int, + left_starts_from_0: bool = False, +) -> mx.array: + """Generate a 1D trapezoidal blending mask with linear ramps. + + Args: + length: Output length of the mask. + ramp_left: Fade-in length on the left. + ramp_right: Fade-out length on the right. + left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. + Useful for temporal tiles where the first tile is causal. + + Returns: + A 1D array of shape (length,) with values in [0, 1]. + """ + if length <= 0: + raise ValueError("Mask length must be positive.") + + ramp_left = max(0, min(ramp_left, length)) + ramp_right = max(0, min(ramp_right, length)) + + # Start with ones + mask = [1.0] * length + + # Apply left ramp (fade in) + if ramp_left > 0: + interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 + # Create fade_in values using linspace logic + fade_in_full = [i / (interval_length - 1) for i in range(interval_length)] + fade_in = fade_in_full[:-1] # Remove last element + if not left_starts_from_0: + fade_in = fade_in[1:] # Remove first element too + for i in range(min(ramp_left, len(fade_in))): + mask[i] *= fade_in[i] + + # Apply right ramp (fade out) + if ramp_right > 0: + # Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1] + fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)] + for i in range(ramp_right): + mask[length - ramp_right + i] *= fade_out[i] + + return mx.clip(mx.array(mask), 0, 1) + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap.""" + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" + ) + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles.""" + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + if self.tile_overlap_in_frames >= self.tile_size_in_frames: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" + ) + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap.""" + + spatial_config: Optional[SpatialTilingConfig] = None + temporal_config: Optional[TemporalTilingConfig] = None + + @classmethod + def default(cls) -> "TilingConfig": + """Default tiling: 512px spatial, 64 frame temporal.""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) + + @classmethod + def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig": + """Spatial tiling only (for short videos with large resolution).""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap), + temporal_config=None, + ) + + @classmethod + def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig": + """Temporal tiling only (for long videos with small resolution).""" + return cls( + spatial_config=None, + temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap), + ) + + @classmethod + def aggressive(cls) -> "TilingConfig": + """Aggressive tiling for very large videos (smaller tiles, much lower memory).""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8), + ) + + @classmethod + def conservative(cls) -> "TilingConfig": + """Conservative tiling (larger tiles, less memory savings but faster).""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24), + ) + + @classmethod + def auto( + cls, + height: int, + width: int, + num_frames: int, + spatial_threshold: int = 512, + temporal_threshold: int = 65, + ) -> Optional["TilingConfig"]: + """Automatically determine tiling config based on video dimensions. + + Args: + height: Video height in pixels + width: Video width in pixels + num_frames: Number of video frames + spatial_threshold: Enable spatial tiling if either dimension exceeds this + temporal_threshold: Enable temporal tiling if frames exceed this + + Returns: + TilingConfig if tiling is needed, None otherwise + """ + needs_spatial = height > spatial_threshold or width > spatial_threshold + needs_temporal = num_frames > temporal_threshold + + if not needs_spatial and not needs_temporal: + return None + + # Estimate memory requirement (rough heuristic) + # Output size in bytes (float32): B * 3 * F * H * W * 4 + estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3) + + # For very large videos, use aggressive tiling + if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100): + return cls.aggressive() + + spatial_config = None + temporal_config = None + + if needs_spatial: + # Choose tile size based on resolution + max_dim = max(height, width) + if max_dim > 1024: + tile_size = 384 # Smaller tiles for very large resolutions + elif max_dim > 768: + tile_size = 512 + else: + tile_size = 384 + spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64) + + if needs_temporal: + # Choose tile size based on frame count + if num_frames > 200: + tile_size, overlap = 32, 8 # Aggressive for very long videos + elif num_frames > 100: + tile_size, overlap = 48, 16 + else: + tile_size, overlap = 64, 24 + temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap) + + return cls(spatial_config=spatial_config, temporal_config=temporal_config) + + +@dataclass +class DimensionIntervals: + """Intervals for splitting a single dimension.""" + starts: List[int] + ends: List[int] + left_ramps: List[int] + right_ramps: List[int] + + +def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: + """Split a spatial dimension into intervals.""" + if dimension_size <= size: + return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) + + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + + return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) + + +def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: + """Split a temporal dimension into intervals with causal adjustment.""" + if dimension_size <= size: + return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) + + # Start with spatial split + intervals = split_in_spatial(size, overlap, dimension_size) + + # Adjust for temporal: starts[1:] -= 1, left_ramps[1:] += 1 + starts = intervals.starts.copy() + left_ramps = intervals.left_ramps.copy() + + for i in range(1, len(starts)): + starts[i] = starts[i] - 1 + left_ramps[i] = left_ramps[i] + 1 + + return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps) + + +def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: + """Map temporal latent interval to output coordinates and mask.""" + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0 + right_ramp_scaled = right_ramp * scale + + mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True) + return slice(start, stop), mask + + +def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: + """Map spatial latent interval to output coordinates and mask.""" + start = begin * scale + stop = end * scale + left_ramp_scaled = left_ramp * scale + right_ramp_scaled = right_ramp * scale + + mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False) + return slice(start, stop), mask + + +def decode_with_tiling( + decoder_fn, + latents: mx.array, + tiling_config: TilingConfig, + spatial_scale: int = 32, + temporal_scale: int = 8, + causal: bool = False, + timestep: Optional[mx.array] = None, + debug: bool = False, +) -> mx.array: + """Decode latents using tiling to reduce memory usage. + + Args: + decoder_fn: Decoder function to call for each tile. + latents: Input latents of shape (B, C, F, H, W). + tiling_config: Tiling configuration. + spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify). + temporal_scale: Temporal scale factor (8 for LTX VAE). + causal: Whether to use causal convolutions. + timestep: Optional timestep for conditioning. + debug: Whether to print debug info. + + Returns: + Decoded video. + """ + import gc + + b, c, f_latent, h_latent, w_latent = latents.shape + + # Compute output shape + out_f = 1 + (f_latent - 1) * temporal_scale + out_h = h_latent * spatial_scale + out_w = w_latent * spatial_scale + + # Get tile size and overlap in latent space + if tiling_config.spatial_config is not None: + s_cfg = tiling_config.spatial_config + spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale + spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale + else: + spatial_tile_size = max(h_latent, w_latent) + spatial_overlap = 0 + + if tiling_config.temporal_config is not None: + t_cfg = tiling_config.temporal_config + temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale + temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale + else: + temporal_tile_size = f_latent + temporal_overlap = 0 + + # Compute intervals for each dimension + temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) + height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) + width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) + + num_t_tiles = len(temporal_intervals.starts) + num_h_tiles = len(height_intervals.starts) + num_w_tiles = len(width_intervals.starts) + total_tiles = num_t_tiles * num_h_tiles * num_w_tiles + + if debug: + print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})") + print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}") + + # Initialize output and weight accumulator + # Use float32 for accumulation to avoid precision issues + output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32) + weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32) + mx.eval(output, weights) + + tile_idx = 0 + for t_idx in range(num_t_tiles): + t_start = temporal_intervals.starts[t_idx] + t_end = temporal_intervals.ends[t_idx] + t_left = temporal_intervals.left_ramps[t_idx] + t_right = temporal_intervals.right_ramps[t_idx] + + # Map temporal coordinates + out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) + + for h_idx in range(num_h_tiles): + h_start = height_intervals.starts[h_idx] + h_end = height_intervals.ends[h_idx] + h_left = height_intervals.left_ramps[h_idx] + h_right = height_intervals.right_ramps[h_idx] + + # Map height coordinates + out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale) + + for w_idx in range(num_w_tiles): + w_start = width_intervals.starts[w_idx] + w_end = width_intervals.ends[w_idx] + w_left = width_intervals.left_ramps[w_idx] + w_right = width_intervals.right_ramps[w_idx] + + # Map width coordinates + out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) + + if debug: + print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: " + f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})") + + # Extract tile latents (small slice) + tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] + + # Decode tile + tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False) + mx.eval(tile_output) + + # Clear tile_latents reference + del tile_latents + + # Get actual decoded dimensions + _, _, decoded_t, decoded_h, decoded_w = tile_output.shape + expected_t = out_t_slice.stop - out_t_slice.start + expected_h = out_h_slice.stop - out_h_slice.start + expected_w = out_w_slice.stop - out_w_slice.start + + # Handle potential size mismatches (use minimum) + actual_t = min(decoded_t, expected_t) + actual_h = min(decoded_h, expected_h) + actual_w = min(decoded_w, expected_w) + + # Build blend mask + t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask + h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask + w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask + + blend_mask = ( + t_mask_slice.reshape(1, 1, -1, 1, 1) * + h_mask_slice.reshape(1, 1, 1, -1, 1) * + w_mask_slice.reshape(1, 1, 1, 1, -1) + ) + + # Slice tile output to match + tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32) + + # Clear full tile_output + del tile_output + + # Compute output coordinates + t_out_start = out_t_slice.start + t_out_end = t_out_start + actual_t + h_out_start = out_h_slice.start + h_out_end = h_out_start + actual_h + w_out_start = out_w_slice.start + w_out_end = w_out_start + actual_w + + # Use direct slice assignment (MLX supports this) + # Weighted accumulation + weighted_tile = tile_output_slice * blend_mask + + # Update output using slice assignment + output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( + output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile + ) + weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( + weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask + ) + + # Force evaluation to free memory + mx.eval(output, weights) + + # Clean up tile-specific arrays + del tile_output_slice, weighted_tile, blend_mask + del t_mask_slice, h_mask_slice, w_mask_slice + + tile_idx += 1 + + # Periodic garbage collection and cache clearing + if tile_idx % 4 == 0: + gc.collect() + try: + mx.clear_cache() + except Exception: + pass # May not be available on all platforms + + # Normalize by weights + weights = mx.maximum(weights, 1e-8) + output = output / weights + mx.eval(output) + + # Clean up weights + del weights + gc.collect() + + if debug: + print(f"[Tiling] Done. Final shape: {output.shape}") + + # Convert back to original dtype if needed + return output.astype(latents.dtype)