feat(wan): Add tiled VAE decoding and fix TI2V quality

This commit is contained in:
Daniel
2026-03-04 14:32:45 +01:00
parent 9597b7c9c5
commit 9bdda9f22e
7 changed files with 407 additions and 34 deletions

View File

@@ -46,8 +46,10 @@ def generate_video(
loras: list | None = None,
loras_high: list | None = None,
loras_low: list | None = None,
tiling: str = "auto",
no_compile: bool = False,
):
"""Generate video using Wan pipeline (supports T2V and I2V).
Args:
@@ -67,6 +69,13 @@ def generate_video(
loras: Optional list of (path, strength) tuples applied to all models
loras_high: Optional list of (path, strength) tuples for high-noise model only
loras_low: Optional list of (path, strength) tuples for low-noise model only
tiling: Tiling mode for VAE decoding. Options:
- "auto": Automatically determine tiling based on video size (default)
- "none": Disable tiling
- "default", "aggressive", "conservative": Preset tiling configs
- "spatial": Spatial tiling only
- "temporal": Temporal tiling only
no_compile: If True, skip mx.compile on models (useful for debugging)
"""
import json
@@ -173,12 +182,7 @@ def generate_video(
# Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
# For T2V: generate 1 extra latent frame so the VAE's causal zero-padding
# artifacts land on throwaway frames. The reference Wan2.2 speech2video.py
# uses a similar "drop_first_motion" approach (drops 3 pixel frames).
# For I2V the reference image provides real first-frame content, so no extra needed.
extra_frames = config.vae_stride[0] if not is_i2v else 0
gen_frames = num_frames + extra_frames
gen_frames = num_frames
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
@@ -241,8 +245,6 @@ def generate_video(
)
print(f"{Colors.DIM} Latent shape: {target_shape}")
if extra_frames > 0:
print(f" Generating {extra_frames} extra pixel frames to absorb VAE boundary artifacts")
print(f" Sequence length: {seq_len}{Colors.RESET}")
# Load T5 encoder
@@ -419,7 +421,7 @@ def generate_video(
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else:
rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes)
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin)
# Setup scheduler
@@ -448,12 +450,13 @@ def generate_video(
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time()
# Compile model forward for faster denoising
models_to_compile = (
[high_noise_model, low_noise_model] if is_dual else [single_model]
)
for m in models_to_compile:
m._compiled = mx.compile(m)
if not no_compile:
models_to_compile = (
[high_noise_model, low_noise_model] if is_dual else [single_model]
)
for m in models_to_compile:
m._compiled = mx.compile(m)
@@ -585,24 +588,53 @@ def generate_video(
is_wan22_vae = config.vae_z_dim == 48
# Select tiling configuration
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None]
z = denormalize_latents(z)
video = vae(z)
if tiling_config is not None:
video = vae.decode_tiled(z, tiling_config)
else:
video = vae(z)
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3]
# Trim extra frames generated for zero-padding warmup
if extra_frames > 0:
video = video[extra_frames:]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else:
video = vae.decode(latents[None])
if tiling_config is not None:
video = vae.decode_tiled(latents[None], tiling_config)
else:
video = vae.decode(latents[None])
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
@@ -651,6 +683,17 @@ def main():
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="VAE tiling mode to reduce memory during decoding (default: auto)",
)
parser.add_argument(
"--no-compile", action="store_true",
help="Disable mx.compile on models (for debugging)",
)
args = parser.parse_args()
@@ -688,6 +731,8 @@ def main():
loras=_parse_lora_args(args.lora),
loras_high=_parse_lora_args(args.lora_high),
loras_low=_parse_lora_args(args.lora_low),
tiling=args.tiling,
no_compile=args.no_compile,
)

View File

@@ -283,6 +283,7 @@ def decode_with_tiling(
spatial_scale: int = 32,
temporal_scale: int = 8,
causal: bool = False,
causal_temporal: bool = True,
timestep: Optional[mx.array] = None,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
@@ -296,6 +297,10 @@ def decode_with_tiling(
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
temporal_scale: Temporal scale factor (8 for LTX VAE).
causal: Whether to use causal convolutions.
causal_temporal: Whether the decoder uses causal temporal mapping where
T input frames produce 1+(T-1)*scale output frames. When False, uses
simple scaling where T frames produce T*scale output frames.
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
timestep: Optional timestep for conditioning.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
@@ -310,7 +315,7 @@ def decode_with_tiling(
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = 1 + (f_latent - 1) * temporal_scale
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
@@ -332,7 +337,10 @@ def decode_with_tiling(
temporal_overlap = 0
# Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -355,7 +363,10 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -461,8 +472,10 @@ def decode_with_tiling(
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
else:
next_tile_start_out = next_tile_start_latent * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):

View File

@@ -48,9 +48,8 @@ class Head(nn.Module):
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
# Compute modulation in float32 for precision, cast to working dtype
w_dtype = _linear_dtype(self.head)
mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
# Compute modulation in float32 (matching reference's autocast(float32))
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
@@ -120,10 +119,13 @@ class WanModel(nn.Module):
], axis=1)
# Precompute sinusoidal inv_freq for time embedding
# Use numpy float64 for precision (matches reference torch.float64),
# then store as float32 since MLX GPU doesn't support float64.
half = config.freq_dim // 2
self._inv_freq = mx.power(
10000.0, -mx.arange(half).astype(mx.float32) / half
inv_freq_np = np.power(
10000.0, -np.arange(half, dtype=np.float64) / half
)
self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
def _patchify(self, x: mx.array) -> tuple:

View File

@@ -51,10 +51,11 @@ class WanAttentionBlock(nn.Module):
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array:
# Modulation: compute in float32 for precision, cast to working dtype
# to avoid promoting the full hidden state (seq_len × dim) to float32
w_dtype = _linear_dtype(self.self_attn.q)
mod = (self.modulation + e).astype(w_dtype)
# Modulation: compute in float32 for precision, matching the reference
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
# By keeping modulation in float32, type promotion ensures the residual
# stream stays float32 throughout all 30 layers (gate * output + x → float32).
mod = self.modulation + e # float32
e0, e1, e2, e3, e4, e5 = (
mod[:, :, 0, :], # shift for self-attn
mod[:, :, 1, :], # scale for self-attn

View File

@@ -534,3 +534,56 @@ class WanVAE(nn.Module):
x = self.conv2(z)
out = self.decoder(x)
return mx.clip(out, -1, 1)
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
"""Decode latent to video using tiling to reduce memory usage.
Splits the latent tensor into overlapping spatial/temporal tiles,
decodes each tile independently, and blends them with trapezoidal
masks. Reuses the LTX-2 tiling infrastructure.
Args:
z: Normalized latent [B, z_dim, T, H, W]
tiling_config: Optional TilingConfig. If None, uses default.
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
_, _, f, h, w = z.shape
needs_tiling = False
if tiling_config.spatial_config is not None:
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
if h > s_tile or w > s_tile:
needs_tiling = True
if tiling_config.temporal_config is not None:
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
if f > t_tile:
needs_tiling = True
if not needs_tiling:
return self.decode(z)
# Denormalize once (small tensor), then tile the denormalized latents
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
z_denorm = z / inv_std + mean
def tile_decode(tile_latents, **kwargs):
x = self.conv2(tile_latents)
out = self.decoder(x)
return mx.clip(out, -1, 1)
return decode_with_tiling(
decoder_fn=tile_decode,
latents=z_denorm,
tiling_config=tiling_config,
spatial_scale=8, # 3× spatial 2× upsamples = 8×
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
)

View File

@@ -709,6 +709,67 @@ class Wan22VAEDecoder(nn.Module):
return mx.clip(out, -1.0, 1.0)
def decode_tiled(self, z, tiling_config=None):
"""Decode latents using tiling to reduce memory usage.
Splits the latent tensor into overlapping spatial/temporal tiles,
decodes each tile independently, and blends them with trapezoidal
masks. Reuses the LTX-2 tiling infrastructure with channels-first
adapter (future: refactor tiling.py to be layout-agnostic).
Args:
z: [B, T, H, W, C=48] latent tensor (already denormalized)
tiling_config: Optional TilingConfig. If None, uses default.
Returns:
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
"""
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
b, t, h_px, w_px, c = z.shape
# Latent dimensions (before conv2/decoder upsampling)
h_lat, w_lat = h_px, w_px
needs_tiling = False
if tiling_config.spatial_config is not None:
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 16
if h_lat > s_tile or w_lat > s_tile:
needs_tiling = True
if tiling_config.temporal_config is not None:
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
if t > t_tile:
needs_tiling = True
if not needs_tiling:
return self(z)
# Transpose to channels-first for decode_with_tiling: [B,T,H,W,C] → [B,C,T,H,W]
z_cf = z.transpose(0, 4, 1, 2, 3)
# Tile decoder: receives (B,C,T,H,W) channels-first, returns (B,3,T',H',W')
def tile_decode(tile_latents, **kwargs):
tile_cl = tile_latents.transpose(0, 2, 3, 4, 1) # → [B,T,H,W,C]
x = self.conv2(tile_cl)
out = self.decoder(x, first_chunk=True)
out = _unpatchify(out, patch_size=2)
out = mx.clip(out, -1.0, 1.0)
return out.transpose(0, 4, 1, 2, 3) # → [B,3,T',H',W']
result_cf = decode_with_tiling(
decoder_fn=tile_decode,
latents=z_cf,
tiling_config=tiling_config,
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
causal_temporal=True,
)
# Back to channels-last: [B,3,T',H',W'] → [B,T',H',W',3]
return result_cf.transpose(0, 2, 3, 4, 1)
def denormalize_latents(z, mean=None, std=None):
"""Denormalize latents: z = z / (1/std) + mean."""