refactor(wan): move causal_temporal tiling to wan/tiling.py
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -283,7 +283,6 @@ def decode_with_tiling(
|
|||||||
spatial_scale: int = 32,
|
spatial_scale: int = 32,
|
||||||
temporal_scale: int = 8,
|
temporal_scale: int = 8,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
causal_temporal: bool = True,
|
|
||||||
timestep: Optional[mx.array] = None,
|
timestep: Optional[mx.array] = None,
|
||||||
chunked_conv: bool = False,
|
chunked_conv: bool = False,
|
||||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||||
@@ -297,10 +296,6 @@ def decode_with_tiling(
|
|||||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||||
causal: Whether to use causal convolutions.
|
causal: Whether to use causal convolutions.
|
||||||
causal_temporal: Whether the decoder uses causal temporal mapping where
|
|
||||||
T input frames produce 1+(T-1)*scale output frames. When False, uses
|
|
||||||
simple scaling where T frames produce T*scale output frames.
|
|
||||||
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
|
|
||||||
timestep: Optional timestep for conditioning.
|
timestep: Optional timestep for conditioning.
|
||||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||||
@@ -315,7 +310,7 @@ def decode_with_tiling(
|
|||||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||||
|
|
||||||
# Compute output shape
|
# Compute output shape
|
||||||
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
|
out_f = 1 + (f_latent - 1) * temporal_scale
|
||||||
out_h = h_latent * spatial_scale
|
out_h = h_latent * spatial_scale
|
||||||
out_w = w_latent * spatial_scale
|
out_w = w_latent * spatial_scale
|
||||||
|
|
||||||
@@ -337,10 +332,7 @@ def decode_with_tiling(
|
|||||||
temporal_overlap = 0
|
temporal_overlap = 0
|
||||||
|
|
||||||
# Compute intervals for each dimension
|
# Compute intervals for each dimension
|
||||||
if causal_temporal:
|
|
||||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||||
else:
|
|
||||||
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
|
|
||||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||||
|
|
||||||
@@ -363,10 +355,7 @@ def decode_with_tiling(
|
|||||||
t_right = temporal_intervals.right_ramps[t_idx]
|
t_right = temporal_intervals.right_ramps[t_idx]
|
||||||
|
|
||||||
# Map temporal coordinates
|
# Map temporal coordinates
|
||||||
if causal_temporal:
|
|
||||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||||
else:
|
|
||||||
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
|
||||||
|
|
||||||
for h_idx in range(num_h_tiles):
|
for h_idx in range(num_h_tiles):
|
||||||
h_start = height_intervals.starts[h_idx]
|
h_start = height_intervals.starts[h_idx]
|
||||||
@@ -472,10 +461,8 @@ def decode_with_tiling(
|
|||||||
# Map to output frame index (first frame of next tile's contribution)
|
# Map to output frame index (first frame of next tile's contribution)
|
||||||
if next_tile_start_latent == 0:
|
if next_tile_start_latent == 0:
|
||||||
next_tile_start_out = 0
|
next_tile_start_out = 0
|
||||||
elif causal_temporal:
|
|
||||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
|
||||||
else:
|
else:
|
||||||
next_tile_start_out = next_tile_start_latent * temporal_scale
|
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||||
|
|
||||||
# We need to track how many frames we've already emitted
|
# We need to track how many frames we've already emitted
|
||||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||||
|
|||||||
281
mlx_video/models/wan/tiling.py
Normal file
281
mlx_video/models/wan/tiling.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""Wan-specific tiled VAE decoding.
|
||||||
|
|
||||||
|
Re-exports all tiling utilities from the LTX VAE tiling module and provides
|
||||||
|
a Wan-specific ``decode_with_tiling`` that adds ``causal_temporal`` support
|
||||||
|
for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale
|
||||||
|
output frames rather than LTX's 1+(T-1)*scale mapping).
|
||||||
|
|
||||||
|
# TODO: This function can be refactored to consolidate with
|
||||||
|
# mlx_video.models.ltx.video_vae.tiling.decode_with_tiling once the
|
||||||
|
# causal_temporal generalisation is accepted upstream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.video_vae.tiling import (
|
||||||
|
SpatialTilingConfig,
|
||||||
|
TemporalTilingConfig,
|
||||||
|
TilingConfig,
|
||||||
|
map_spatial_slice,
|
||||||
|
map_temporal_slice,
|
||||||
|
split_in_spatial,
|
||||||
|
split_in_temporal,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SpatialTilingConfig",
|
||||||
|
"TemporalTilingConfig",
|
||||||
|
"TilingConfig",
|
||||||
|
"decode_with_tiling",
|
||||||
|
"map_spatial_slice",
|
||||||
|
"map_temporal_slice",
|
||||||
|
"split_in_spatial",
|
||||||
|
"split_in_temporal",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def decode_with_tiling(
|
||||||
|
decoder_fn,
|
||||||
|
latents: mx.array,
|
||||||
|
tiling_config: TilingConfig,
|
||||||
|
spatial_scale: int = 32,
|
||||||
|
temporal_scale: int = 8,
|
||||||
|
causal: bool = False,
|
||||||
|
causal_temporal: bool = True,
|
||||||
|
timestep: Optional[mx.array] = None,
|
||||||
|
chunked_conv: bool = False,
|
||||||
|
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Decode latents using tiling to reduce memory usage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_fn: Decoder function to call for each tile.
|
||||||
|
latents: Input latents of shape (B, C, F, H, W).
|
||||||
|
tiling_config: Tiling configuration.
|
||||||
|
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||||
|
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||||
|
causal: Whether to use causal convolutions.
|
||||||
|
causal_temporal: Whether the decoder uses causal temporal mapping where
|
||||||
|
T input frames produce 1+(T-1)*scale output frames. When False, uses
|
||||||
|
simple scaling where T frames produce T*scale output frames.
|
||||||
|
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
|
||||||
|
timestep: Optional timestep for conditioning.
|
||||||
|
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||||
|
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||||
|
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||||
|
start_idx: Starting frame index in the full video.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decoded video.
|
||||||
|
"""
|
||||||
|
import gc
|
||||||
|
|
||||||
|
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||||
|
|
||||||
|
# Compute output shape
|
||||||
|
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
|
||||||
|
out_h = h_latent * spatial_scale
|
||||||
|
out_w = w_latent * spatial_scale
|
||||||
|
|
||||||
|
# Get tile size and overlap in latent space
|
||||||
|
if tiling_config.spatial_config is not None:
|
||||||
|
s_cfg = tiling_config.spatial_config
|
||||||
|
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
|
||||||
|
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
|
||||||
|
else:
|
||||||
|
spatial_tile_size = max(h_latent, w_latent)
|
||||||
|
spatial_overlap = 0
|
||||||
|
|
||||||
|
if tiling_config.temporal_config is not None:
|
||||||
|
t_cfg = tiling_config.temporal_config
|
||||||
|
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
|
||||||
|
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
|
||||||
|
else:
|
||||||
|
temporal_tile_size = f_latent
|
||||||
|
temporal_overlap = 0
|
||||||
|
|
||||||
|
# Compute intervals for each dimension
|
||||||
|
if causal_temporal:
|
||||||
|
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||||
|
else:
|
||||||
|
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
|
||||||
|
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||||
|
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||||
|
|
||||||
|
num_t_tiles = len(temporal_intervals.starts)
|
||||||
|
num_h_tiles = len(height_intervals.starts)
|
||||||
|
num_w_tiles = len(width_intervals.starts)
|
||||||
|
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles # noqa: F841
|
||||||
|
|
||||||
|
# Initialize output and weight accumulator
|
||||||
|
# Use float32 for accumulation to avoid precision issues
|
||||||
|
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||||
|
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
|
||||||
|
mx.eval(output, weights)
|
||||||
|
|
||||||
|
tile_idx = 0
|
||||||
|
for t_idx in range(num_t_tiles):
|
||||||
|
t_start = temporal_intervals.starts[t_idx]
|
||||||
|
t_end = temporal_intervals.ends[t_idx]
|
||||||
|
t_left = temporal_intervals.left_ramps[t_idx]
|
||||||
|
t_right = temporal_intervals.right_ramps[t_idx]
|
||||||
|
|
||||||
|
# Map temporal coordinates
|
||||||
|
if causal_temporal:
|
||||||
|
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||||
|
else:
|
||||||
|
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||||
|
|
||||||
|
for h_idx in range(num_h_tiles):
|
||||||
|
h_start = height_intervals.starts[h_idx]
|
||||||
|
h_end = height_intervals.ends[h_idx]
|
||||||
|
h_left = height_intervals.left_ramps[h_idx]
|
||||||
|
h_right = height_intervals.right_ramps[h_idx]
|
||||||
|
|
||||||
|
# Map height coordinates
|
||||||
|
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||||
|
|
||||||
|
for w_idx in range(num_w_tiles):
|
||||||
|
w_start = width_intervals.starts[w_idx]
|
||||||
|
w_end = width_intervals.ends[w_idx]
|
||||||
|
w_left = width_intervals.left_ramps[w_idx]
|
||||||
|
w_right = width_intervals.right_ramps[w_idx]
|
||||||
|
|
||||||
|
# Map width coordinates
|
||||||
|
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||||
|
|
||||||
|
# Extract tile latents (small slice)
|
||||||
|
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||||
|
|
||||||
|
# Decode tile
|
||||||
|
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||||
|
mx.eval(tile_output)
|
||||||
|
|
||||||
|
# Clear tile_latents reference
|
||||||
|
del tile_latents
|
||||||
|
|
||||||
|
# Get actual decoded dimensions
|
||||||
|
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
|
||||||
|
expected_t = out_t_slice.stop - out_t_slice.start
|
||||||
|
expected_h = out_h_slice.stop - out_h_slice.start
|
||||||
|
expected_w = out_w_slice.stop - out_w_slice.start
|
||||||
|
|
||||||
|
# Handle potential size mismatches (use minimum)
|
||||||
|
actual_t = min(decoded_t, expected_t)
|
||||||
|
actual_h = min(decoded_h, expected_h)
|
||||||
|
actual_w = min(decoded_w, expected_w)
|
||||||
|
|
||||||
|
# Build blend mask
|
||||||
|
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
|
||||||
|
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
|
||||||
|
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||||
|
|
||||||
|
blend_mask = (
|
||||||
|
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||||
|
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||||
|
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Slice tile output to match
|
||||||
|
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||||
|
|
||||||
|
# Clear full tile_output
|
||||||
|
del tile_output
|
||||||
|
|
||||||
|
# Compute output coordinates
|
||||||
|
t_out_start = out_t_slice.start
|
||||||
|
t_out_end = t_out_start + actual_t
|
||||||
|
h_out_start = out_h_slice.start
|
||||||
|
h_out_end = h_out_start + actual_h
|
||||||
|
w_out_start = out_w_slice.start
|
||||||
|
w_out_end = w_out_start + actual_w
|
||||||
|
|
||||||
|
# Weighted accumulation
|
||||||
|
weighted_tile = tile_output_slice * blend_mask
|
||||||
|
|
||||||
|
# Update output using slice assignment
|
||||||
|
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||||
|
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||||
|
)
|
||||||
|
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||||
|
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force evaluation to free memory
|
||||||
|
mx.eval(output, weights)
|
||||||
|
|
||||||
|
# Clean up tile-specific arrays
|
||||||
|
del tile_output_slice, weighted_tile, blend_mask
|
||||||
|
del t_mask_slice, h_mask_slice, w_mask_slice
|
||||||
|
|
||||||
|
tile_idx += 1
|
||||||
|
|
||||||
|
# Periodic garbage collection and cache clearing
|
||||||
|
if tile_idx % 4 == 0:
|
||||||
|
gc.collect()
|
||||||
|
try:
|
||||||
|
mx.clear_cache()
|
||||||
|
except Exception:
|
||||||
|
pass # May not be available on all platforms
|
||||||
|
|
||||||
|
# After completing all spatial tiles for this temporal tile,
|
||||||
|
# check if any frames are now finalized (no future tiles will contribute)
|
||||||
|
if on_frames_ready is not None and num_t_tiles > 1:
|
||||||
|
# Determine the finalized frame boundary
|
||||||
|
# Frames before the start of the next tile's output region are finalized
|
||||||
|
if t_idx < num_t_tiles - 1:
|
||||||
|
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||||
|
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||||
|
# Map to output frame index (first frame of next tile's contribution)
|
||||||
|
if next_tile_start_latent == 0:
|
||||||
|
next_tile_start_out = 0
|
||||||
|
elif causal_temporal:
|
||||||
|
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||||
|
else:
|
||||||
|
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||||
|
|
||||||
|
# We need to track how many frames we've already emitted
|
||||||
|
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||||
|
decode_with_tiling._emitted_frames = 0
|
||||||
|
emitted = decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
|
if next_tile_start_out > emitted:
|
||||||
|
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||||
|
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||||
|
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||||
|
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||||
|
finalized_output = finalized_output.astype(latents.dtype)
|
||||||
|
mx.eval(finalized_output)
|
||||||
|
|
||||||
|
on_frames_ready(finalized_output, emitted)
|
||||||
|
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||||
|
|
||||||
|
del finalized_output, finalized_weights
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Normalize by weights
|
||||||
|
weights = mx.maximum(weights, 1e-8)
|
||||||
|
output = output / weights
|
||||||
|
mx.eval(output)
|
||||||
|
|
||||||
|
# Emit remaining frames if callback provided
|
||||||
|
if on_frames_ready is not None:
|
||||||
|
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||||
|
if emitted < out_f:
|
||||||
|
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||||
|
mx.eval(remaining_output)
|
||||||
|
on_frames_ready(remaining_output, emitted)
|
||||||
|
del remaining_output
|
||||||
|
|
||||||
|
# Reset emitted frames counter for next call
|
||||||
|
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||||
|
del decode_with_tiling._emitted_frames
|
||||||
|
|
||||||
|
# Clean up weights
|
||||||
|
del weights
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Convert back to original dtype if needed
|
||||||
|
return output.astype(latents.dtype)
|
||||||
@@ -549,7 +549,7 @@ class WanVAE(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
|
||||||
|
|
||||||
if tiling_config is None:
|
if tiling_config is None:
|
||||||
tiling_config = TilingConfig.default()
|
tiling_config = TilingConfig.default()
|
||||||
|
|||||||
@@ -847,7 +847,7 @@ class Wan22VAEDecoder(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
|
||||||
|
|
||||||
if tiling_config is None:
|
if tiling_config is None:
|
||||||
tiling_config = TilingConfig.default()
|
tiling_config = TilingConfig.default()
|
||||||
|
|||||||
Reference in New Issue
Block a user