Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
338
mlx_video/models/wan_2/tiling.py
Normal file
338
mlx_video/models/wan_2/tiling.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Wan-specific tiled VAE decoding.
|
||||
|
||||
Re-exports all tiling utilities from the LTX VAE tiling module and provides
|
||||
a Wan-specific ``decode_with_tiling`` that adds ``causal_temporal`` support
|
||||
for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale
|
||||
output frames rather than LTX's 1+(T-1)*scale mapping).
|
||||
|
||||
# TODO: This function can be refactored to consolidate with
|
||||
# mlx_video.models.ltx_2.video_vae.tiling.decode_with_tiling once the
|
||||
# causal_temporal generalisation is accepted upstream.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
TilingConfig,
|
||||
map_spatial_slice,
|
||||
map_temporal_slice,
|
||||
split_in_spatial,
|
||||
split_in_temporal,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SpatialTilingConfig",
|
||||
"TemporalTilingConfig",
|
||||
"TilingConfig",
|
||||
"decode_with_tiling",
|
||||
"map_spatial_slice",
|
||||
"map_temporal_slice",
|
||||
"split_in_spatial",
|
||||
"split_in_temporal",
|
||||
]
|
||||
|
||||
|
||||
def decode_with_tiling(
|
||||
decoder_fn,
|
||||
latents: mx.array,
|
||||
tiling_config: TilingConfig,
|
||||
spatial_scale: int = 32,
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
causal_temporal: bool = True,
|
||||
timestep: Optional[mx.array] = None,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
Args:
|
||||
decoder_fn: Decoder function to call for each tile.
|
||||
latents: Input latents of shape (B, C, F, H, W).
|
||||
tiling_config: Tiling configuration.
|
||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
causal_temporal: Whether the decoder uses causal temporal mapping where
|
||||
T input frames produce 1+(T-1)*scale output frames. When False, uses
|
||||
simple scaling where T frames produce T*scale output frames.
|
||||
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
|
||||
timestep: Optional timestep for conditioning.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
start_idx: Starting frame index in the full video.
|
||||
|
||||
Returns:
|
||||
Decoded video.
|
||||
"""
|
||||
import gc
|
||||
|
||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||
|
||||
# Compute output shape
|
||||
out_f = (
|
||||
(1 + (f_latent - 1) * temporal_scale)
|
||||
if causal_temporal
|
||||
else (f_latent * temporal_scale)
|
||||
)
|
||||
out_h = h_latent * spatial_scale
|
||||
out_w = w_latent * spatial_scale
|
||||
|
||||
# Get tile size and overlap in latent space
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_cfg = tiling_config.spatial_config
|
||||
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
|
||||
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
|
||||
else:
|
||||
spatial_tile_size = max(h_latent, w_latent)
|
||||
spatial_overlap = 0
|
||||
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_cfg = tiling_config.temporal_config
|
||||
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
|
||||
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
|
||||
else:
|
||||
temporal_tile_size = f_latent
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
if causal_temporal:
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
else:
|
||||
temporal_intervals = split_in_spatial(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
num_t_tiles = len(temporal_intervals.starts)
|
||||
num_h_tiles = len(height_intervals.starts)
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles # noqa: F841
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
|
||||
mx.eval(output, weights)
|
||||
|
||||
tile_idx = 0
|
||||
for t_idx in range(num_t_tiles):
|
||||
t_start = temporal_intervals.starts[t_idx]
|
||||
t_end = temporal_intervals.ends[t_idx]
|
||||
t_left = temporal_intervals.left_ramps[t_idx]
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
if causal_temporal:
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
else:
|
||||
out_t_slice, t_mask = map_spatial_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
h_end = height_intervals.ends[h_idx]
|
||||
h_left = height_intervals.left_ramps[h_idx]
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
w_end = width_intervals.ends[w_idx]
|
||||
w_left = width_intervals.left_ramps[w_idx]
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
del tile_latents
|
||||
|
||||
# Get actual decoded dimensions
|
||||
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
|
||||
expected_t = out_t_slice.stop - out_t_slice.start
|
||||
expected_h = out_h_slice.stop - out_h_slice.start
|
||||
expected_w = out_w_slice.stop - out_w_slice.start
|
||||
|
||||
# Handle potential size mismatches (use minimum)
|
||||
actual_t = min(decoded_t, expected_t)
|
||||
actual_h = min(decoded_h, expected_h)
|
||||
actual_w = min(decoded_w, expected_w)
|
||||
|
||||
# Build blend mask
|
||||
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
|
||||
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
|
||||
# Compute output coordinates
|
||||
t_out_start = out_t_slice.start
|
||||
t_out_end = t_out_start + actual_t
|
||||
h_out_start = out_h_slice.start
|
||||
h_out_end = h_out_start + actual_h
|
||||
w_out_start = out_w_slice.start
|
||||
w_out_end = w_out_start + actual_w
|
||||
|
||||
# Weighted accumulation
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
mx.eval(output, weights)
|
||||
|
||||
# Clean up tile-specific arrays
|
||||
del tile_output_slice, weighted_tile, blend_mask
|
||||
del t_mask_slice, h_mask_slice, w_mask_slice
|
||||
|
||||
tile_idx += 1
|
||||
|
||||
# Periodic garbage collection and cache clearing
|
||||
if tile_idx % 4 == 0:
|
||||
gc.collect()
|
||||
try:
|
||||
mx.clear_cache()
|
||||
except Exception:
|
||||
pass # May not be available on all platforms
|
||||
|
||||
# After completing all spatial tiles for this temporal tile,
|
||||
# check if any frames are now finalized (no future tiles will contribute)
|
||||
if on_frames_ready is not None and num_t_tiles > 1:
|
||||
# Determine the finalized frame boundary
|
||||
# Frames before the start of the next tile's output region are finalized
|
||||
if t_idx < num_t_tiles - 1:
|
||||
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||
# Map to output frame index (first frame of next tile's contribution)
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
elif causal_temporal:
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
else:
|
||||
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
if next_tile_start_out > emitted:
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
del finalized_output, finalized_weights
|
||||
gc.collect()
|
||||
|
||||
# Normalize by weights
|
||||
weights = mx.maximum(weights, 1e-8)
|
||||
output = output / weights
|
||||
mx.eval(output)
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
Reference in New Issue
Block a user