refactor(wan): move causal_temporal tiling to wan/tiling.py
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -283,7 +283,6 @@ def decode_with_tiling(
|
||||
spatial_scale: int = 32,
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
causal_temporal: bool = True,
|
||||
timestep: Optional[mx.array] = None,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
@@ -297,10 +296,6 @@ def decode_with_tiling(
|
||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
causal_temporal: Whether the decoder uses causal temporal mapping where
|
||||
T input frames produce 1+(T-1)*scale output frames. When False, uses
|
||||
simple scaling where T frames produce T*scale output frames.
|
||||
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
|
||||
timestep: Optional timestep for conditioning.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
@@ -315,7 +310,7 @@ def decode_with_tiling(
|
||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||
|
||||
# Compute output shape
|
||||
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
|
||||
out_f = 1 + (f_latent - 1) * temporal_scale
|
||||
out_h = h_latent * spatial_scale
|
||||
out_w = w_latent * spatial_scale
|
||||
|
||||
@@ -337,10 +332,7 @@ def decode_with_tiling(
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
if causal_temporal:
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
else:
|
||||
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
@@ -363,10 +355,7 @@ def decode_with_tiling(
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
if causal_temporal:
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
else:
|
||||
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
@@ -472,10 +461,8 @@ def decode_with_tiling(
|
||||
# Map to output frame index (first frame of next tile's contribution)
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
elif causal_temporal:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
else:
|
||||
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
|
||||
Reference in New Issue
Block a user