refactor(wan): move causal_temporal tiling to wan/tiling.py

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Daniel
2026-03-11 12:02:54 +01:00
parent 1cf878f5e0
commit c144c8817c
4 changed files with 287 additions and 19 deletions

View File

@@ -283,7 +283,6 @@ def decode_with_tiling(
spatial_scale: int = 32,
temporal_scale: int = 8,
causal: bool = False,
causal_temporal: bool = True,
timestep: Optional[mx.array] = None,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
@@ -297,10 +296,6 @@ def decode_with_tiling(
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
temporal_scale: Temporal scale factor (8 for LTX VAE).
causal: Whether to use causal convolutions.
causal_temporal: Whether the decoder uses causal temporal mapping where
T input frames produce 1+(T-1)*scale output frames. When False, uses
simple scaling where T frames produce T*scale output frames.
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
timestep: Optional timestep for conditioning.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
@@ -315,7 +310,7 @@ def decode_with_tiling(
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
out_f = 1 + (f_latent - 1) * temporal_scale
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
@@ -337,10 +332,7 @@ def decode_with_tiling(
temporal_overlap = 0
# Compute intervals for each dimension
if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -363,10 +355,7 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -472,10 +461,8 @@ def decode_with_tiling(
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
else:
next_tile_start_out = next_tile_start_latent * temporal_scale
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):