feat(wan): Add tiled VAE decoding and fix TI2V quality
This commit is contained in:
@@ -534,3 +534,56 @@ class WanVAE(nn.Module):
|
||||
x = self.conv2(z)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
|
||||
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
    """Decode a latent into video tile-by-tile to lower peak memory.

    When the latent is small enough to fit in a single tile along every
    configured axis, this simply forwards to ``self.decode``. Otherwise
    the latent is denormalized once and handed to ``decode_with_tiling``
    (shared with the LTX-2 VAE), which is expected to split it into
    overlapping spatial/temporal tiles, decode each one through
    ``conv2`` + ``decoder``, and blend the results.

    Args:
        z: Normalized latent [B, z_dim, T, H, W].
        tiling_config: Optional TilingConfig. ``TilingConfig.default()``
            is used when omitted.

    Returns:
        Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1].
    """
    # Imported at call time, as in the original implementation.
    from mlx_video.models.ltx.video_vae.tiling import (
        TilingConfig,
        decode_with_tiling,
    )

    if tiling_config is None:
        tiling_config = TilingConfig.default()

    _, _, num_frames, height, width = z.shape

    # Tile sizes are configured in pixels / frames; convert them to latent
    # units with the same 8x spatial and 4x temporal factors used below.
    spatial_cfg = tiling_config.spatial_config
    exceeds_spatial = False
    if spatial_cfg is not None:
        max_latent_side = spatial_cfg.tile_size_in_pixels // 8
        exceeds_spatial = max(height, width) > max_latent_side

    temporal_cfg = tiling_config.temporal_config
    exceeds_temporal = False
    if temporal_cfg is not None:
        max_latent_frames = temporal_cfg.tile_size_in_frames // 4
        exceeds_temporal = num_frames > max_latent_frames

    # Everything fits in one tile: take the regular decode path.
    if not (exceeds_spatial or exceeds_temporal):
        return self.decode(z)

    # Undo the latent normalization once on the full (small) latent so
    # each tile can go straight into conv2/decoder.
    mean = self.mean.reshape(1, -1, 1, 1, 1)
    inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
    denormalized = z / inv_std + mean

    def _decode_single_tile(tile_latents, **kwargs):
        # Extra keyword arguments supplied by the tiling helper are
        # accepted and ignored.
        return mx.clip(self.decoder(self.conv2(tile_latents)), -1, 1)

    return decode_with_tiling(
        decoder_fn=_decode_single_tile,
        latents=denormalized,
        tiling_config=tiling_config,
        spatial_scale=8,  # three 2x spatial upsamples -> 8x
        temporal_scale=4,  # two 2x temporal upsamples -> 4x
        causal_temporal=False,  # Wan2.1 uses non-causal temporal (T -> 4T)
    )
|
||||
|
||||
Reference in New Issue
Block a user