feat(wan): Add tiled VAE decoding and fix TI2V quality

This commit is contained in:
Daniel
2026-03-04 14:32:45 +01:00
parent 9597b7c9c5
commit 9bdda9f22e
7 changed files with 407 additions and 34 deletions

View File

@@ -709,6 +709,67 @@ class Wan22VAEDecoder(nn.Module):
return mx.clip(out, -1.0, 1.0)
def decode_tiled(self, z, tiling_config=None):
"""Decode latents using tiling to reduce memory usage.
Splits the latent tensor into overlapping spatial/temporal tiles,
decodes each tile independently, and blends them with trapezoidal
masks. Reuses the LTX-2 tiling infrastructure with channels-first
adapter (future: refactor tiling.py to be layout-agnostic).
Args:
z: [B, T, H, W, C=48] latent tensor (already denormalized)
tiling_config: Optional TilingConfig. If None, uses default.
Returns:
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
"""
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
b, t, h_px, w_px, c = z.shape
# Latent dimensions (before conv2/decoder upsampling)
h_lat, w_lat = h_px, w_px
needs_tiling = False
if tiling_config.spatial_config is not None:
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 16
if h_lat > s_tile or w_lat > s_tile:
needs_tiling = True
if tiling_config.temporal_config is not None:
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
if t > t_tile:
needs_tiling = True
if not needs_tiling:
return self(z)
# Transpose to channels-first for decode_with_tiling: [B,T,H,W,C] → [B,C,T,H,W]
z_cf = z.transpose(0, 4, 1, 2, 3)
# Tile decoder: receives (B,C,T,H,W) channels-first, returns (B,3,T',H',W')
def tile_decode(tile_latents, **kwargs):
tile_cl = tile_latents.transpose(0, 2, 3, 4, 1) # → [B,T,H,W,C]
x = self.conv2(tile_cl)
out = self.decoder(x, first_chunk=True)
out = _unpatchify(out, patch_size=2)
out = mx.clip(out, -1.0, 1.0)
return out.transpose(0, 4, 1, 2, 3) # → [B,3,T',H',W']
result_cf = decode_with_tiling(
decoder_fn=tile_decode,
latents=z_cf,
tiling_config=tiling_config,
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
causal_temporal=True,
)
# Back to channels-last: [B,3,T',H',W'] → [B,T',H',W',3]
return result_cf.transpose(0, 2, 3, 4, 1)
def denormalize_latents(z, mean=None, std=None):
"""Denormalize latents: z = z / (1/std) + mean."""