feat(wan): Add tiled VAE decoding and fix TI2V quality
This commit is contained in:
@@ -709,6 +709,67 @@ class Wan22VAEDecoder(nn.Module):
|
||||
|
||||
return mx.clip(out, -1.0, 1.0)
|
||||
|
||||
def decode_tiled(self, z, tiling_config=None):
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
Splits the latent tensor into overlapping spatial/temporal tiles,
|
||||
decodes each tile independently, and blends them with trapezoidal
|
||||
masks. Reuses the LTX-2 tiling infrastructure with channels-first
|
||||
adapter (future: refactor tiling.py to be layout-agnostic).
|
||||
|
||||
Args:
|
||||
z: [B, T, H, W, C=48] latent tensor (already denormalized)
|
||||
tiling_config: Optional TilingConfig. If None, uses default.
|
||||
|
||||
Returns:
|
||||
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
# Check if tiling is actually needed
|
||||
b, t, h_px, w_px, c = z.shape
|
||||
# Latent dimensions (before conv2/decoder upsampling)
|
||||
h_lat, w_lat = h_px, w_px
|
||||
needs_tiling = False
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 16
|
||||
if h_lat > s_tile or w_lat > s_tile:
|
||||
needs_tiling = True
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
|
||||
if t > t_tile:
|
||||
needs_tiling = True
|
||||
|
||||
if not needs_tiling:
|
||||
return self(z)
|
||||
|
||||
# Transpose to channels-first for decode_with_tiling: [B,T,H,W,C] → [B,C,T,H,W]
|
||||
z_cf = z.transpose(0, 4, 1, 2, 3)
|
||||
|
||||
# Tile decoder: receives (B,C,T,H,W) channels-first, returns (B,3,T',H',W')
|
||||
def tile_decode(tile_latents, **kwargs):
|
||||
tile_cl = tile_latents.transpose(0, 2, 3, 4, 1) # → [B,T,H,W,C]
|
||||
x = self.conv2(tile_cl)
|
||||
out = self.decoder(x, first_chunk=True)
|
||||
out = _unpatchify(out, patch_size=2)
|
||||
out = mx.clip(out, -1.0, 1.0)
|
||||
return out.transpose(0, 4, 1, 2, 3) # → [B,3,T',H',W']
|
||||
|
||||
result_cf = decode_with_tiling(
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_cf,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||
causal_temporal=True,
|
||||
)
|
||||
|
||||
# Back to channels-last: [B,3,T',H',W'] → [B,T',H',W',3]
|
||||
return result_cf.transpose(0, 2, 3, 4, 1)
|
||||
|
||||
|
||||
def denormalize_latents(z, mean=None, std=None):
|
||||
"""Denormalize latents: z = z / (1/std) + mean."""
|
||||
|
||||
Reference in New Issue
Block a user