feat(wan): Add tiled VAE decoding and fix TI2V quality

This commit is contained in:
Daniel
2026-03-04 14:32:45 +01:00
parent 9597b7c9c5
commit 9bdda9f22e
7 changed files with 407 additions and 34 deletions

View File

@@ -283,6 +283,7 @@ def decode_with_tiling(
spatial_scale: int = 32,
temporal_scale: int = 8,
causal: bool = False,
causal_temporal: bool = True,
timestep: Optional[mx.array] = None,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
@@ -296,6 +297,10 @@ def decode_with_tiling(
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
temporal_scale: Temporal scale factor (8 for LTX VAE).
causal: Whether to use causal convolutions.
causal_temporal: Whether the decoder uses causal temporal mapping where
T input frames produce 1+(T-1)*scale output frames. When False, uses
simple scaling where T frames produce T*scale output frames.
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
timestep: Optional timestep for conditioning.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
@@ -310,7 +315,7 @@ def decode_with_tiling(
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = 1 + (f_latent - 1) * temporal_scale
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
@@ -332,7 +337,10 @@ def decode_with_tiling(
temporal_overlap = 0
# Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -355,7 +363,10 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -461,8 +472,10 @@ def decode_with_tiling(
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
else:
next_tile_start_out = next_tile_start_latent * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):