diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py index 512edff..72d32e4 100644 --- a/mlx_video/models/ltx/video_vae/tiling.py +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -283,7 +283,6 @@ def decode_with_tiling( spatial_scale: int = 32, temporal_scale: int = 8, causal: bool = False, - causal_temporal: bool = True, timestep: Optional[mx.array] = None, chunked_conv: bool = False, on_frames_ready: Optional[Callable[[mx.array, int], None]] = None, @@ -297,10 +296,6 @@ def decode_with_tiling( spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify). temporal_scale: Temporal scale factor (8 for LTX VAE). causal: Whether to use causal convolutions. - causal_temporal: Whether the decoder uses causal temporal mapping where - T input frames produce 1+(T-1)*scale output frames. When False, uses - simple scaling where T frames produce T*scale output frames. - Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1). timestep: Optional timestep for conditioning. chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory). on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized. @@ -315,7 +310,7 @@ def decode_with_tiling( b, c, f_latent, h_latent, w_latent = latents.shape # Compute output shape - out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale) + out_f = 1 + (f_latent - 1) * temporal_scale out_h = h_latent * spatial_scale out_w = w_latent * spatial_scale @@ -337,10 +332,7 @@ def decode_with_tiling( temporal_overlap = 0 # Compute intervals for each dimension - if causal_temporal: - temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) - else: - temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent) + temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) @@ -363,10 +355,7 @@ def decode_with_tiling( t_right = temporal_intervals.right_ramps[t_idx] # Map temporal coordinates - if causal_temporal: - out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) - else: - out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale) + out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) for h_idx in range(num_h_tiles): h_start = height_intervals.starts[h_idx] @@ -472,10 +461,8 @@ def decode_with_tiling( # Map to output frame index (first frame of next tile's contribution) if next_tile_start_latent == 0: next_tile_start_out = 0 - elif causal_temporal: - next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale else: - next_tile_start_out = next_tile_start_latent * temporal_scale + next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale # We need to track how many frames we've already emitted if not hasattr(decode_with_tiling, '_emitted_frames'): diff --git a/mlx_video/models/wan/tiling.py b/mlx_video/models/wan/tiling.py new file mode 100644 index 0000000..73f2624 --- /dev/null +++ b/mlx_video/models/wan/tiling.py @@ -0,0 +1,281 @@ +"""Wan-specific tiled VAE decoding. + +Re-exports all tiling utilities from the LTX VAE tiling module and provides +a Wan-specific ``decode_with_tiling`` that adds ``causal_temporal`` support +for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale +output frames rather than LTX's 1+(T-1)*scale mapping). + +# TODO: This function can be refactored to consolidate with +# mlx_video.models.ltx.video_vae.tiling.decode_with_tiling once the +# causal_temporal generalisation is accepted upstream. +""" + +from typing import Callable, Optional + +import mlx.core as mx + +from mlx_video.models.ltx.video_vae.tiling import ( + SpatialTilingConfig, + TemporalTilingConfig, + TilingConfig, + map_spatial_slice, + map_temporal_slice, + split_in_spatial, + split_in_temporal, +) + +__all__ = [ + "SpatialTilingConfig", + "TemporalTilingConfig", + "TilingConfig", + "decode_with_tiling", + "map_spatial_slice", + "map_temporal_slice", + "split_in_spatial", + "split_in_temporal", +] + + +def decode_with_tiling( + decoder_fn, + latents: mx.array, + tiling_config: TilingConfig, + spatial_scale: int = 32, + temporal_scale: int = 8, + causal: bool = False, + causal_temporal: bool = True, + timestep: Optional[mx.array] = None, + chunked_conv: bool = False, + on_frames_ready: Optional[Callable[[mx.array, int], None]] = None, +) -> mx.array: + """Decode latents using tiling to reduce memory usage. + + Args: + decoder_fn: Decoder function to call for each tile. + latents: Input latents of shape (B, C, F, H, W). + tiling_config: Tiling configuration. + spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify). + temporal_scale: Temporal scale factor (8 for LTX VAE). + causal: Whether to use causal convolutions. + causal_temporal: Whether the decoder uses causal temporal mapping where + T input frames produce 1+(T-1)*scale output frames. When False, uses + simple scaling where T frames produce T*scale output frames. + Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1). + timestep: Optional timestep for conditioning. + chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory). + on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized. + frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames. + start_idx: Starting frame index in the full video. + + Returns: + Decoded video. + """ + import gc + + b, c, f_latent, h_latent, w_latent = latents.shape + + # Compute output shape + out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale) + out_h = h_latent * spatial_scale + out_w = w_latent * spatial_scale + + # Get tile size and overlap in latent space + if tiling_config.spatial_config is not None: + s_cfg = tiling_config.spatial_config + spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale + spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale + else: + spatial_tile_size = max(h_latent, w_latent) + spatial_overlap = 0 + + if tiling_config.temporal_config is not None: + t_cfg = tiling_config.temporal_config + temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale + temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale + else: + temporal_tile_size = f_latent + temporal_overlap = 0 + + # Compute intervals for each dimension + if causal_temporal: + temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) + else: + temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent) + height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) + width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) + + num_t_tiles = len(temporal_intervals.starts) + num_h_tiles = len(height_intervals.starts) + num_w_tiles = len(width_intervals.starts) + total_tiles = num_t_tiles * num_h_tiles * num_w_tiles # noqa: F841 + + # Initialize output and weight accumulator + # Use float32 for accumulation to avoid precision issues + output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32) + weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32) + mx.eval(output, weights) + + tile_idx = 0 + for t_idx in range(num_t_tiles): + t_start = temporal_intervals.starts[t_idx] + t_end = temporal_intervals.ends[t_idx] + t_left = temporal_intervals.left_ramps[t_idx] + t_right = temporal_intervals.right_ramps[t_idx] + + # Map temporal coordinates + if causal_temporal: + out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) + else: + out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale) + + for h_idx in range(num_h_tiles): + h_start = height_intervals.starts[h_idx] + h_end = height_intervals.ends[h_idx] + h_left = height_intervals.left_ramps[h_idx] + h_right = height_intervals.right_ramps[h_idx] + + # Map height coordinates + out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale) + + for w_idx in range(num_w_tiles): + w_start = width_intervals.starts[w_idx] + w_end = width_intervals.ends[w_idx] + w_left = width_intervals.left_ramps[w_idx] + w_right = width_intervals.right_ramps[w_idx] + + # Map width coordinates + out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) + + # Extract tile latents (small slice) + tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] + + # Decode tile + tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv) + mx.eval(tile_output) + + # Clear tile_latents reference + del tile_latents + + # Get actual decoded dimensions + _, _, decoded_t, decoded_h, decoded_w = tile_output.shape + expected_t = out_t_slice.stop - out_t_slice.start + expected_h = out_h_slice.stop - out_h_slice.start + expected_w = out_w_slice.stop - out_w_slice.start + + # Handle potential size mismatches (use minimum) + actual_t = min(decoded_t, expected_t) + actual_h = min(decoded_h, expected_h) + actual_w = min(decoded_w, expected_w) + + # Build blend mask + t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask + h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask + w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask + + blend_mask = ( + t_mask_slice.reshape(1, 1, -1, 1, 1) * + h_mask_slice.reshape(1, 1, 1, -1, 1) * + w_mask_slice.reshape(1, 1, 1, 1, -1) + ) + + # Slice tile output to match + tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32) + + # Clear full tile_output + del tile_output + + # Compute output coordinates + t_out_start = out_t_slice.start + t_out_end = t_out_start + actual_t + h_out_start = out_h_slice.start + h_out_end = h_out_start + actual_h + w_out_start = out_w_slice.start + w_out_end = w_out_start + actual_w + + # Weighted accumulation + weighted_tile = tile_output_slice * blend_mask + + # Update output using slice assignment + output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( + output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile + ) + weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( + weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask + ) + + # Force evaluation to free memory + mx.eval(output, weights) + + # Clean up tile-specific arrays + del tile_output_slice, weighted_tile, blend_mask + del t_mask_slice, h_mask_slice, w_mask_slice + + tile_idx += 1 + + # Periodic garbage collection and cache clearing + if tile_idx % 4 == 0: + gc.collect() + try: + mx.clear_cache() + except Exception: + pass # May not be available on all platforms + + # After completing all spatial tiles for this temporal tile, + # check if any frames are now finalized (no future tiles will contribute) + if on_frames_ready is not None and num_t_tiles > 1: + # Determine the finalized frame boundary + # Frames before the start of the next tile's output region are finalized + if t_idx < num_t_tiles - 1: + # Next tile starts at temporal_intervals.starts[t_idx + 1] + next_tile_start_latent = temporal_intervals.starts[t_idx + 1] + # Map to output frame index (first frame of next tile's contribution) + if next_tile_start_latent == 0: + next_tile_start_out = 0 + elif causal_temporal: + next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale + else: + next_tile_start_out = next_tile_start_latent * temporal_scale + + # We need to track how many frames we've already emitted + if not hasattr(decode_with_tiling, '_emitted_frames'): + decode_with_tiling._emitted_frames = 0 + emitted = decode_with_tiling._emitted_frames + + if next_tile_start_out > emitted: + # Normalize and emit frames [emitted, next_tile_start_out) + finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :] + finalized_weights = mx.maximum(finalized_weights, 1e-8) + finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights + finalized_output = finalized_output.astype(latents.dtype) + mx.eval(finalized_output) + + on_frames_ready(finalized_output, emitted) + decode_with_tiling._emitted_frames = next_tile_start_out + + del finalized_output, finalized_weights + gc.collect() + + # Normalize by weights + weights = mx.maximum(weights, 1e-8) + output = output / weights + mx.eval(output) + + # Emit remaining frames if callback provided + if on_frames_ready is not None: + emitted = getattr(decode_with_tiling, '_emitted_frames', 0) + if emitted < out_f: + remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype) + mx.eval(remaining_output) + on_frames_ready(remaining_output, emitted) + del remaining_output + + # Reset emitted frames counter for next call + if hasattr(decode_with_tiling, '_emitted_frames'): + del decode_with_tiling._emitted_frames + + # Clean up weights + del weights + gc.collect() + + # Convert back to original dtype if needed + return output.astype(latents.dtype) diff --git a/mlx_video/models/wan/vae.py b/mlx_video/models/wan/vae.py index f4d83e7..faa2372 100644 --- a/mlx_video/models/wan/vae.py +++ b/mlx_video/models/wan/vae.py @@ -549,7 +549,7 @@ class WanVAE(nn.Module): Returns: Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1] """ - from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling + from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling if tiling_config is None: tiling_config = TilingConfig.default() diff --git a/mlx_video/models/wan/vae22.py b/mlx_video/models/wan/vae22.py index 72a4f04..a1b233f 100644 --- a/mlx_video/models/wan/vae22.py +++ b/mlx_video/models/wan/vae22.py @@ -847,7 +847,7 @@ class Wan22VAEDecoder(nn.Module): Returns: video: [B, T', H', W', 3] decoded RGB in [-1, 1] """ - from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling + from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling if tiling_config is None: tiling_config = TilingConfig.default()