feat(wan): Add tiled VAE decoding and fix TI2V quality
This commit is contained in:
@@ -48,9 +48,8 @@ class Head(nn.Module):
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
# Compute modulation in float32 for precision, cast to working dtype
|
||||
w_dtype = _linear_dtype(self.head)
|
||||
mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
|
||||
# Compute modulation in float32 (matching reference's autocast(float32))
|
||||
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x)
|
||||
@@ -120,10 +119,13 @@ class WanModel(nn.Module):
|
||||
], axis=1)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding
|
||||
# Use numpy float64 for precision (matches reference torch.float64),
|
||||
# then store as float32 since MLX GPU doesn't support float64.
|
||||
half = config.freq_dim // 2
|
||||
self._inv_freq = mx.power(
|
||||
10000.0, -mx.arange(half).astype(mx.float32) / half
|
||||
inv_freq_np = np.power(
|
||||
10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
)
|
||||
self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
|
||||
@@ -51,10 +51,11 @@ class WanAttentionBlock(nn.Module):
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
# Modulation: compute in float32 for precision, cast to working dtype
|
||||
# to avoid promoting the full hidden state (seq_len × dim) to float32
|
||||
w_dtype = _linear_dtype(self.self_attn.q)
|
||||
mod = (self.modulation + e).astype(w_dtype)
|
||||
# Modulation: compute in float32 for precision, matching the reference
|
||||
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
|
||||
# By keeping modulation in float32, type promotion ensures the residual
|
||||
# stream stays float32 throughout all 30 layers (gate * output + x → float32).
|
||||
mod = self.modulation + e # float32
|
||||
e0, e1, e2, e3, e4, e5 = (
|
||||
mod[:, :, 0, :], # shift for self-attn
|
||||
mod[:, :, 1, :], # scale for self-attn
|
||||
|
||||
@@ -534,3 +534,56 @@ class WanVAE(nn.Module):
|
||||
x = self.conv2(z)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
|
||||
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
|
||||
"""Decode latent to video using tiling to reduce memory usage.
|
||||
|
||||
Splits the latent tensor into overlapping spatial/temporal tiles,
|
||||
decodes each tile independently, and blends them with trapezoidal
|
||||
masks. Reuses the LTX-2 tiling infrastructure.
|
||||
|
||||
Args:
|
||||
z: Normalized latent [B, z_dim, T, H, W]
|
||||
tiling_config: Optional TilingConfig. If None, uses default.
|
||||
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
# Check if tiling is actually needed
|
||||
_, _, f, h, w = z.shape
|
||||
needs_tiling = False
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
|
||||
if h > s_tile or w > s_tile:
|
||||
needs_tiling = True
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
|
||||
if f > t_tile:
|
||||
needs_tiling = True
|
||||
|
||||
if not needs_tiling:
|
||||
return self.decode(z)
|
||||
|
||||
# Denormalize once (small tensor), then tile the denormalized latents
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
z_denorm = z / inv_std + mean
|
||||
|
||||
def tile_decode(tile_latents, **kwargs):
|
||||
x = self.conv2(tile_latents)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_denorm,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
||||
)
|
||||
|
||||
@@ -709,6 +709,67 @@ class Wan22VAEDecoder(nn.Module):
|
||||
|
||||
return mx.clip(out, -1.0, 1.0)
|
||||
|
||||
def decode_tiled(self, z, tiling_config=None):
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
Splits the latent tensor into overlapping spatial/temporal tiles,
|
||||
decodes each tile independently, and blends them with trapezoidal
|
||||
masks. Reuses the LTX-2 tiling infrastructure with channels-first
|
||||
adapter (future: refactor tiling.py to be layout-agnostic).
|
||||
|
||||
Args:
|
||||
z: [B, T, H, W, C=48] latent tensor (already denormalized)
|
||||
tiling_config: Optional TilingConfig. If None, uses default.
|
||||
|
||||
Returns:
|
||||
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
# Check if tiling is actually needed
|
||||
b, t, h_px, w_px, c = z.shape
|
||||
# Latent dimensions (before conv2/decoder upsampling)
|
||||
h_lat, w_lat = h_px, w_px
|
||||
needs_tiling = False
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 16
|
||||
if h_lat > s_tile or w_lat > s_tile:
|
||||
needs_tiling = True
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
|
||||
if t > t_tile:
|
||||
needs_tiling = True
|
||||
|
||||
if not needs_tiling:
|
||||
return self(z)
|
||||
|
||||
# Transpose to channels-first for decode_with_tiling: [B,T,H,W,C] → [B,C,T,H,W]
|
||||
z_cf = z.transpose(0, 4, 1, 2, 3)
|
||||
|
||||
# Tile decoder: receives (B,C,T,H,W) channels-first, returns (B,3,T',H',W')
|
||||
def tile_decode(tile_latents, **kwargs):
|
||||
tile_cl = tile_latents.transpose(0, 2, 3, 4, 1) # → [B,T,H,W,C]
|
||||
x = self.conv2(tile_cl)
|
||||
out = self.decoder(x, first_chunk=True)
|
||||
out = _unpatchify(out, patch_size=2)
|
||||
out = mx.clip(out, -1.0, 1.0)
|
||||
return out.transpose(0, 4, 1, 2, 3) # → [B,3,T',H',W']
|
||||
|
||||
result_cf = decode_with_tiling(
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_cf,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||
causal_temporal=True,
|
||||
)
|
||||
|
||||
# Back to channels-last: [B,3,T',H',W'] → [B,T',H',W',3]
|
||||
return result_cf.transpose(0, 2, 3, 4, 1)
|
||||
|
||||
|
||||
def denormalize_latents(z, mean=None, std=None):
|
||||
"""Denormalize latents: z = z / (1/std) + mean."""
|
||||
|
||||
Reference in New Issue
Block a user