feat(wan): Add tiled VAE decoding and fix TI2V quality

This commit is contained in:
Daniel
2026-03-04 14:32:45 +01:00
parent 9597b7c9c5
commit 9bdda9f22e
7 changed files with 407 additions and 34 deletions

View File

@@ -534,3 +534,56 @@ class WanVAE(nn.Module):
x = self.conv2(z)
out = self.decoder(x)
return mx.clip(out, -1, 1)
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
"""Decode latent to video using tiling to reduce memory usage.
Splits the latent tensor into overlapping spatial/temporal tiles,
decodes each tile independently, and blends them with trapezoidal
masks. Reuses the LTX-2 tiling infrastructure.
Args:
z: Normalized latent [B, z_dim, T, H, W]
tiling_config: Optional TilingConfig. If None, uses default.
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
_, _, f, h, w = z.shape
needs_tiling = False
if tiling_config.spatial_config is not None:
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
if h > s_tile or w > s_tile:
needs_tiling = True
if tiling_config.temporal_config is not None:
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
if f > t_tile:
needs_tiling = True
if not needs_tiling:
return self.decode(z)
# Denormalize once (small tensor), then tile the denormalized latents
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
z_denorm = z / inv_std + mean
def tile_decode(tile_latents, **kwargs):
x = self.conv2(tile_latents)
out = self.decoder(x)
return mx.clip(out, -1, 1)
return decode_with_tiling(
decoder_fn=tile_decode,
latents=z_denorm,
tiling_config=tiling_config,
spatial_scale=8, # 3× spatial 2× upsamples = 8×
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
)