feat(wan): Add tiled VAE decoding and fix TI2V quality
This commit is contained in:
@@ -46,8 +46,10 @@ def generate_video(
|
|||||||
loras: list | None = None,
|
loras: list | None = None,
|
||||||
loras_high: list | None = None,
|
loras_high: list | None = None,
|
||||||
loras_low: list | None = None,
|
loras_low: list | None = None,
|
||||||
|
tiling: str = "auto",
|
||||||
|
no_compile: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -67,6 +69,13 @@ def generate_video(
|
|||||||
loras: Optional list of (path, strength) tuples applied to all models
|
loras: Optional list of (path, strength) tuples applied to all models
|
||||||
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
||||||
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
||||||
|
tiling: Tiling mode for VAE decoding. Options:
|
||||||
|
- "auto": Automatically determine tiling based on video size (default)
|
||||||
|
- "none": Disable tiling
|
||||||
|
- "default", "aggressive", "conservative": Preset tiling configs
|
||||||
|
- "spatial": Spatial tiling only
|
||||||
|
- "temporal": Temporal tiling only
|
||||||
|
no_compile: If True, skip mx.compile on models (useful for debugging)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
@@ -173,12 +182,7 @@ def generate_video(
|
|||||||
# Validate frame count
|
# Validate frame count
|
||||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||||
|
|
||||||
# For T2V: generate 1 extra latent frame so the VAE's causal zero-padding
|
gen_frames = num_frames
|
||||||
# artifacts land on throwaway frames. The reference Wan2.2 speech2video.py
|
|
||||||
# uses a similar "drop_first_motion" approach (drops 3 pixel frames).
|
|
||||||
# For I2V the reference image provides real first-frame content, so no extra needed.
|
|
||||||
extra_frames = config.vae_stride[0] if not is_i2v else 0
|
|
||||||
gen_frames = num_frames + extra_frames
|
|
||||||
|
|
||||||
version_str = f"Wan{config.model_version}"
|
version_str = f"Wan{config.model_version}"
|
||||||
mode_str = "dual-model" if is_dual else "single-model"
|
mode_str = "dual-model" if is_dual else "single-model"
|
||||||
@@ -241,8 +245,6 @@ def generate_video(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
||||||
if extra_frames > 0:
|
|
||||||
print(f" Generating {extra_frames} extra pixel frames to absorb VAE boundary artifacts")
|
|
||||||
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
||||||
|
|
||||||
# Load T5 encoder
|
# Load T5 encoder
|
||||||
@@ -419,7 +421,7 @@ def generate_video(
|
|||||||
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
||||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||||
else:
|
else:
|
||||||
rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes)
|
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
|
||||||
mx.eval(rope_cos_sin)
|
mx.eval(rope_cos_sin)
|
||||||
|
|
||||||
# Setup scheduler
|
# Setup scheduler
|
||||||
@@ -448,12 +450,13 @@ def generate_video(
|
|||||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
|
|
||||||
# Compile model forward for faster denoising
|
if not no_compile:
|
||||||
models_to_compile = (
|
models_to_compile = (
|
||||||
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||||
)
|
)
|
||||||
for m in models_to_compile:
|
for m in models_to_compile:
|
||||||
m._compiled = mx.compile(m)
|
m._compiled = mx.compile(m)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -585,24 +588,53 @@ def generate_video(
|
|||||||
|
|
||||||
is_wan22_vae = config.vae_z_dim == 48
|
is_wan22_vae = config.vae_z_dim == 48
|
||||||
|
|
||||||
|
# Select tiling configuration
|
||||||
|
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||||
|
|
||||||
|
if tiling == "none":
|
||||||
|
tiling_config = None
|
||||||
|
elif tiling == "auto":
|
||||||
|
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||||
|
elif tiling == "default":
|
||||||
|
tiling_config = TilingConfig.default()
|
||||||
|
elif tiling == "aggressive":
|
||||||
|
tiling_config = TilingConfig.aggressive()
|
||||||
|
elif tiling == "conservative":
|
||||||
|
tiling_config = TilingConfig.conservative()
|
||||||
|
elif tiling == "spatial":
|
||||||
|
tiling_config = TilingConfig.spatial_only()
|
||||||
|
elif tiling == "temporal":
|
||||||
|
tiling_config = TilingConfig.temporal_only()
|
||||||
|
else:
|
||||||
|
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||||
|
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||||
|
|
||||||
|
if tiling_config is not None:
|
||||||
|
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||||
|
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||||
|
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||||
|
|
||||||
if is_wan22_vae:
|
if is_wan22_vae:
|
||||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
|
|
||||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||||
z = latents.transpose(1, 2, 3, 0)[None]
|
z = latents.transpose(1, 2, 3, 0)[None]
|
||||||
z = denormalize_latents(z)
|
z = denormalize_latents(z)
|
||||||
video = vae(z)
|
if tiling_config is not None:
|
||||||
|
video = vae.decode_tiled(z, tiling_config)
|
||||||
|
else:
|
||||||
|
video = vae(z)
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
video = np.array(video[0]) # [T', H', W', 3]
|
video = np.array(video[0]) # [T', H', W', 3]
|
||||||
# Trim extra frames generated for zero-padding warmup
|
|
||||||
if extra_frames > 0:
|
|
||||||
video = video[extra_frames:]
|
|
||||||
video = (video + 1.0) / 2.0
|
video = (video + 1.0) / 2.0
|
||||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||||
else:
|
else:
|
||||||
video = vae.decode(latents[None])
|
if tiling_config is not None:
|
||||||
|
video = vae.decode_tiled(latents[None], tiling_config)
|
||||||
|
else:
|
||||||
|
video = vae.decode(latents[None])
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
@@ -651,6 +683,17 @@ def main():
|
|||||||
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tiling",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||||
|
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-compile", action="store_true",
|
||||||
|
help="Disable mx.compile on models (for debugging)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -688,6 +731,8 @@ def main():
|
|||||||
loras=_parse_lora_args(args.lora),
|
loras=_parse_lora_args(args.lora),
|
||||||
loras_high=_parse_lora_args(args.lora_high),
|
loras_high=_parse_lora_args(args.lora_high),
|
||||||
loras_low=_parse_lora_args(args.lora_low),
|
loras_low=_parse_lora_args(args.lora_low),
|
||||||
|
tiling=args.tiling,
|
||||||
|
no_compile=args.no_compile,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -283,6 +283,7 @@ def decode_with_tiling(
|
|||||||
spatial_scale: int = 32,
|
spatial_scale: int = 32,
|
||||||
temporal_scale: int = 8,
|
temporal_scale: int = 8,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
|
causal_temporal: bool = True,
|
||||||
timestep: Optional[mx.array] = None,
|
timestep: Optional[mx.array] = None,
|
||||||
chunked_conv: bool = False,
|
chunked_conv: bool = False,
|
||||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||||
@@ -296,6 +297,10 @@ def decode_with_tiling(
|
|||||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||||
causal: Whether to use causal convolutions.
|
causal: Whether to use causal convolutions.
|
||||||
|
causal_temporal: Whether the decoder uses causal temporal mapping where
|
||||||
|
T input frames produce 1+(T-1)*scale output frames. When False, uses
|
||||||
|
simple scaling where T frames produce T*scale output frames.
|
||||||
|
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
|
||||||
timestep: Optional timestep for conditioning.
|
timestep: Optional timestep for conditioning.
|
||||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||||
@@ -310,7 +315,7 @@ def decode_with_tiling(
|
|||||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||||
|
|
||||||
# Compute output shape
|
# Compute output shape
|
||||||
out_f = 1 + (f_latent - 1) * temporal_scale
|
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
|
||||||
out_h = h_latent * spatial_scale
|
out_h = h_latent * spatial_scale
|
||||||
out_w = w_latent * spatial_scale
|
out_w = w_latent * spatial_scale
|
||||||
|
|
||||||
@@ -332,7 +337,10 @@ def decode_with_tiling(
|
|||||||
temporal_overlap = 0
|
temporal_overlap = 0
|
||||||
|
|
||||||
# Compute intervals for each dimension
|
# Compute intervals for each dimension
|
||||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
if causal_temporal:
|
||||||
|
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||||
|
else:
|
||||||
|
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
|
||||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||||
|
|
||||||
@@ -355,7 +363,10 @@ def decode_with_tiling(
|
|||||||
t_right = temporal_intervals.right_ramps[t_idx]
|
t_right = temporal_intervals.right_ramps[t_idx]
|
||||||
|
|
||||||
# Map temporal coordinates
|
# Map temporal coordinates
|
||||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
if causal_temporal:
|
||||||
|
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||||
|
else:
|
||||||
|
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||||
|
|
||||||
for h_idx in range(num_h_tiles):
|
for h_idx in range(num_h_tiles):
|
||||||
h_start = height_intervals.starts[h_idx]
|
h_start = height_intervals.starts[h_idx]
|
||||||
@@ -461,8 +472,10 @@ def decode_with_tiling(
|
|||||||
# Map to output frame index (first frame of next tile's contribution)
|
# Map to output frame index (first frame of next tile's contribution)
|
||||||
if next_tile_start_latent == 0:
|
if next_tile_start_latent == 0:
|
||||||
next_tile_start_out = 0
|
next_tile_start_out = 0
|
||||||
else:
|
elif causal_temporal:
|
||||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||||
|
else:
|
||||||
|
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||||
|
|
||||||
# We need to track how many frames we've already emitted
|
# We need to track how many frames we've already emitted
|
||||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||||
|
|||||||
@@ -48,9 +48,8 @@ class Head(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if e.ndim == 2:
|
if e.ndim == 2:
|
||||||
e = e[:, None, :] # [B, 1, dim]
|
e = e[:, None, :] # [B, 1, dim]
|
||||||
# Compute modulation in float32 for precision, cast to working dtype
|
# Compute modulation in float32 (matching reference's autocast(float32))
|
||||||
w_dtype = _linear_dtype(self.head)
|
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
|
||||||
mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
|
|
||||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||||
x_norm = self.norm(x)
|
x_norm = self.norm(x)
|
||||||
@@ -120,10 +119,13 @@ class WanModel(nn.Module):
|
|||||||
], axis=1)
|
], axis=1)
|
||||||
|
|
||||||
# Precompute sinusoidal inv_freq for time embedding
|
# Precompute sinusoidal inv_freq for time embedding
|
||||||
|
# Use numpy float64 for precision (matches reference torch.float64),
|
||||||
|
# then store as float32 since MLX GPU doesn't support float64.
|
||||||
half = config.freq_dim // 2
|
half = config.freq_dim // 2
|
||||||
self._inv_freq = mx.power(
|
inv_freq_np = np.power(
|
||||||
10000.0, -mx.arange(half).astype(mx.float32) / half
|
10000.0, -np.arange(half, dtype=np.float64) / half
|
||||||
)
|
)
|
||||||
|
self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
def _patchify(self, x: mx.array) -> tuple:
|
def _patchify(self, x: mx.array) -> tuple:
|
||||||
|
|||||||
@@ -51,10 +51,11 @@ class WanAttentionBlock(nn.Module):
|
|||||||
rope_cos_sin: tuple | None = None,
|
rope_cos_sin: tuple | None = None,
|
||||||
attn_mask: mx.array | None = None,
|
attn_mask: mx.array | None = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
# Modulation: compute in float32 for precision, cast to working dtype
|
# Modulation: compute in float32 for precision, matching the reference
|
||||||
# to avoid promoting the full hidden state (seq_len × dim) to float32
|
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
|
||||||
w_dtype = _linear_dtype(self.self_attn.q)
|
# By keeping modulation in float32, type promotion ensures the residual
|
||||||
mod = (self.modulation + e).astype(w_dtype)
|
# stream stays float32 throughout all 30 layers (gate * output + x → float32).
|
||||||
|
mod = self.modulation + e # float32
|
||||||
e0, e1, e2, e3, e4, e5 = (
|
e0, e1, e2, e3, e4, e5 = (
|
||||||
mod[:, :, 0, :], # shift for self-attn
|
mod[:, :, 0, :], # shift for self-attn
|
||||||
mod[:, :, 1, :], # scale for self-attn
|
mod[:, :, 1, :], # scale for self-attn
|
||||||
|
|||||||
@@ -534,3 +534,56 @@ class WanVAE(nn.Module):
|
|||||||
x = self.conv2(z)
|
x = self.conv2(z)
|
||||||
out = self.decoder(x)
|
out = self.decoder(x)
|
||||||
return mx.clip(out, -1, 1)
|
return mx.clip(out, -1, 1)
|
||||||
|
|
||||||
|
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
|
||||||
|
"""Decode latent to video using tiling to reduce memory usage.
|
||||||
|
|
||||||
|
Splits the latent tensor into overlapping spatial/temporal tiles,
|
||||||
|
decodes each tile independently, and blends them with trapezoidal
|
||||||
|
masks. Reuses the LTX-2 tiling infrastructure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
z: Normalized latent [B, z_dim, T, H, W]
|
||||||
|
tiling_config: Optional TilingConfig. If None, uses default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||||
|
"""
|
||||||
|
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||||
|
|
||||||
|
if tiling_config is None:
|
||||||
|
tiling_config = TilingConfig.default()
|
||||||
|
|
||||||
|
# Check if tiling is actually needed
|
||||||
|
_, _, f, h, w = z.shape
|
||||||
|
needs_tiling = False
|
||||||
|
if tiling_config.spatial_config is not None:
|
||||||
|
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
|
||||||
|
if h > s_tile or w > s_tile:
|
||||||
|
needs_tiling = True
|
||||||
|
if tiling_config.temporal_config is not None:
|
||||||
|
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
|
||||||
|
if f > t_tile:
|
||||||
|
needs_tiling = True
|
||||||
|
|
||||||
|
if not needs_tiling:
|
||||||
|
return self.decode(z)
|
||||||
|
|
||||||
|
# Denormalize once (small tensor), then tile the denormalized latents
|
||||||
|
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||||
|
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||||
|
z_denorm = z / inv_std + mean
|
||||||
|
|
||||||
|
def tile_decode(tile_latents, **kwargs):
|
||||||
|
x = self.conv2(tile_latents)
|
||||||
|
out = self.decoder(x)
|
||||||
|
return mx.clip(out, -1, 1)
|
||||||
|
|
||||||
|
return decode_with_tiling(
|
||||||
|
decoder_fn=tile_decode,
|
||||||
|
latents=z_denorm,
|
||||||
|
tiling_config=tiling_config,
|
||||||
|
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||||
|
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||||
|
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
||||||
|
)
|
||||||
|
|||||||
@@ -709,6 +709,67 @@ class Wan22VAEDecoder(nn.Module):
|
|||||||
|
|
||||||
return mx.clip(out, -1.0, 1.0)
|
return mx.clip(out, -1.0, 1.0)
|
||||||
|
|
||||||
|
def decode_tiled(self, z, tiling_config=None):
|
||||||
|
"""Decode latents using tiling to reduce memory usage.
|
||||||
|
|
||||||
|
Splits the latent tensor into overlapping spatial/temporal tiles,
|
||||||
|
decodes each tile independently, and blends them with trapezoidal
|
||||||
|
masks. Reuses the LTX-2 tiling infrastructure with channels-first
|
||||||
|
adapter (future: refactor tiling.py to be layout-agnostic).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
z: [B, T, H, W, C=48] latent tensor (already denormalized)
|
||||||
|
tiling_config: Optional TilingConfig. If None, uses default.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
||||||
|
"""
|
||||||
|
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||||
|
|
||||||
|
if tiling_config is None:
|
||||||
|
tiling_config = TilingConfig.default()
|
||||||
|
|
||||||
|
# Check if tiling is actually needed
|
||||||
|
b, t, h_px, w_px, c = z.shape
|
||||||
|
# Latent dimensions (before conv2/decoder upsampling)
|
||||||
|
h_lat, w_lat = h_px, w_px
|
||||||
|
needs_tiling = False
|
||||||
|
if tiling_config.spatial_config is not None:
|
||||||
|
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 16
|
||||||
|
if h_lat > s_tile or w_lat > s_tile:
|
||||||
|
needs_tiling = True
|
||||||
|
if tiling_config.temporal_config is not None:
|
||||||
|
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
|
||||||
|
if t > t_tile:
|
||||||
|
needs_tiling = True
|
||||||
|
|
||||||
|
if not needs_tiling:
|
||||||
|
return self(z)
|
||||||
|
|
||||||
|
# Transpose to channels-first for decode_with_tiling: [B,T,H,W,C] → [B,C,T,H,W]
|
||||||
|
z_cf = z.transpose(0, 4, 1, 2, 3)
|
||||||
|
|
||||||
|
# Tile decoder: receives (B,C,T,H,W) channels-first, returns (B,3,T',H',W')
|
||||||
|
def tile_decode(tile_latents, **kwargs):
|
||||||
|
tile_cl = tile_latents.transpose(0, 2, 3, 4, 1) # → [B,T,H,W,C]
|
||||||
|
x = self.conv2(tile_cl)
|
||||||
|
out = self.decoder(x, first_chunk=True)
|
||||||
|
out = _unpatchify(out, patch_size=2)
|
||||||
|
out = mx.clip(out, -1.0, 1.0)
|
||||||
|
return out.transpose(0, 4, 1, 2, 3) # → [B,3,T',H',W']
|
||||||
|
|
||||||
|
result_cf = decode_with_tiling(
|
||||||
|
decoder_fn=tile_decode,
|
||||||
|
latents=z_cf,
|
||||||
|
tiling_config=tiling_config,
|
||||||
|
spatial_scale=16, # 8× conv upsample + 2× unpatchify
|
||||||
|
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
|
||||||
|
causal_temporal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Back to channels-last: [B,3,T',H',W'] → [B,T',H',W',3]
|
||||||
|
return result_cf.transpose(0, 2, 3, 4, 1)
|
||||||
|
|
||||||
|
|
||||||
def denormalize_latents(z, mean=None, std=None):
|
def denormalize_latents(z, mean=None, std=None):
|
||||||
"""Denormalize latents: z = z / (1/std) + mean."""
|
"""Denormalize latents: z = z / (1/std) + mean."""
|
||||||
|
|||||||
198
tests/test_wan_tiling.py
Normal file
198
tests/test_wan_tiling.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""Tests for Wan VAE tiled decoding."""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.video_vae.tiling import (
|
||||||
|
TilingConfig,
|
||||||
|
decode_with_tiling,
|
||||||
|
split_in_spatial,
|
||||||
|
split_in_temporal,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNonCausalTemporal:
|
||||||
|
"""Tests for the causal_temporal=False path in decode_with_tiling."""
|
||||||
|
|
||||||
|
def test_split_spatial_for_temporal(self):
|
||||||
|
"""Non-causal temporal should use split_in_spatial (no causal shift)."""
|
||||||
|
intervals = split_in_spatial(8, 2, 20)
|
||||||
|
# No causal adjustment: starts should be evenly spaced
|
||||||
|
assert intervals.starts[0] == 0
|
||||||
|
for i in range(1, len(intervals.starts)):
|
||||||
|
assert intervals.starts[i] == intervals.starts[i - 1] + (8 - 2)
|
||||||
|
|
||||||
|
def test_causal_vs_noncausal_output_size(self):
|
||||||
|
"""Causal temporal gives 1+(T-1)*S frames, non-causal gives T*S."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
b, c, t, h, w = 1, 4, 4, 4, 4
|
||||||
|
latents = mx.random.normal((b, c, t, h, w))
|
||||||
|
scale = 4
|
||||||
|
|
||||||
|
# Simple passthrough decoder: just repeat along dimensions
|
||||||
|
def dummy_decoder_causal(x, **kwargs):
|
||||||
|
b, c, t, h, w = x.shape
|
||||||
|
out_t = 1 + (t - 1) * scale
|
||||||
|
out_h = h * scale
|
||||||
|
out_w = w * scale
|
||||||
|
return mx.ones((b, 3, out_t, out_h, out_w))
|
||||||
|
|
||||||
|
def dummy_decoder_noncausal(x, **kwargs):
|
||||||
|
b, c, t, h, w = x.shape
|
||||||
|
out_t = t * scale
|
||||||
|
out_h = h * scale
|
||||||
|
out_w = w * scale
|
||||||
|
return mx.ones((b, 3, out_t, out_h, out_w))
|
||||||
|
|
||||||
|
config = TilingConfig.spatial_only(tile_size=128, overlap=64)
|
||||||
|
|
||||||
|
# Causal: 1 + (4-1)*4 = 13
|
||||||
|
out_causal = decode_with_tiling(
|
||||||
|
dummy_decoder_causal, latents, config,
|
||||||
|
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
|
||||||
|
)
|
||||||
|
mx.eval(out_causal)
|
||||||
|
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
|
||||||
|
|
||||||
|
# Non-causal: 4*4 = 16
|
||||||
|
out_noncausal = decode_with_tiling(
|
||||||
|
dummy_decoder_noncausal, latents, config,
|
||||||
|
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
|
||||||
|
)
|
||||||
|
mx.eval(out_noncausal)
|
||||||
|
assert out_noncausal.shape[2] == t * scale # 16
|
||||||
|
|
||||||
|
|
||||||
|
class TestWan22TiledDecoding:
|
||||||
|
"""Tests for Wan2.2 VAE tiled decoding."""
|
||||||
|
|
||||||
|
def _make_small_wan22_decoder(self):
|
||||||
|
"""Create a small Wan2.2 decoder for testing."""
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||||
|
|
||||||
|
# Use very small dimensions for fast testing
|
||||||
|
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
|
||||||
|
mx.eval(vae.parameters())
|
||||||
|
return vae
|
||||||
|
|
||||||
|
def test_decode_tiled_output_shape(self):
|
||||||
|
"""Tiled decode should produce same shape as non-tiled."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
vae = self._make_small_wan22_decoder()
|
||||||
|
|
||||||
|
# Small input: [B=1, T=3, H=2, W=2, C=48]
|
||||||
|
z = mx.random.normal((1, 3, 2, 2, 48))
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
# Non-tiled
|
||||||
|
out_regular = vae(z)
|
||||||
|
mx.eval(out_regular)
|
||||||
|
|
||||||
|
# Tiled (force tiling with very small tile sizes)
|
||||||
|
# Use spatial tile=32px (2 latent @ scale 16) and temporal=8 frames (2 latent @ scale 4)
|
||||||
|
config = TilingConfig(
|
||||||
|
spatial_config=None, # Don't tile spatially (input is tiny)
|
||||||
|
temporal_config=None, # Don't tile temporally (input is tiny)
|
||||||
|
)
|
||||||
|
# With no tiling config, decode_tiled should fall through to regular decode
|
||||||
|
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||||
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
|
# Both should produce the same shape
|
||||||
|
assert out_regular.shape == out_tiled.shape, (
|
||||||
|
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_decode_tiled_falls_through_when_small(self):
|
||||||
|
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
vae = self._make_small_wan22_decoder()
|
||||||
|
|
||||||
|
# Input smaller than any tile size
|
||||||
|
z = mx.random.normal((1, 2, 2, 2, 48))
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
out_regular = vae(z)
|
||||||
|
mx.eval(out_regular)
|
||||||
|
|
||||||
|
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||||
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(out_regular), np.array(out_tiled),
|
||||||
|
rtol=1e-4, atol=1e-4,
|
||||||
|
err_msg="Tiled decode should match regular decode for small inputs",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWan21TiledDecoding:
|
||||||
|
"""Tests for Wan2.1 VAE tiled decoding."""
|
||||||
|
|
||||||
|
def _make_small_wan21_vae(self):
|
||||||
|
"""Create a small Wan2.1 VAE for testing."""
|
||||||
|
from mlx_video.models.wan.vae import WanVAE
|
||||||
|
|
||||||
|
vae = WanVAE(z_dim=16)
|
||||||
|
mx.eval(vae.parameters())
|
||||||
|
return vae
|
||||||
|
|
||||||
|
def test_decode_tiled_output_shape(self):
|
||||||
|
"""Tiled decode should produce correct output shape."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
vae = self._make_small_wan21_vae()
|
||||||
|
|
||||||
|
# [B=1, C=16, T=3, H=4, W=4]
|
||||||
|
z = mx.random.normal((1, 16, 3, 4, 4))
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
out_regular = vae.decode(z)
|
||||||
|
mx.eval(out_regular)
|
||||||
|
|
||||||
|
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||||
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
|
assert out_regular.shape == out_tiled.shape, (
|
||||||
|
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_decode_tiled_falls_through_when_small(self):
|
||||||
|
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
vae = self._make_small_wan21_vae()
|
||||||
|
|
||||||
|
z = mx.random.normal((1, 16, 2, 4, 4))
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
out_regular = vae.decode(z)
|
||||||
|
mx.eval(out_regular)
|
||||||
|
|
||||||
|
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||||
|
mx.eval(out_tiled)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(out_regular), np.array(out_tiled),
|
||||||
|
rtol=1e-4, atol=1e-4,
|
||||||
|
err_msg="Tiled decode should match regular decode for small inputs",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWan21TemporalScale:
|
||||||
|
"""Verify Wan2.1 decoder temporal output is T*4 (non-causal)."""
|
||||||
|
|
||||||
|
def test_wan21_decoder_temporal_output(self):
|
||||||
|
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
|
||||||
|
from mlx_video.models.wan.vae import Decoder3d
|
||||||
|
|
||||||
|
# Small decoder for fast test
|
||||||
|
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
|
||||||
|
temporal_upsample=[True, True, False])
|
||||||
|
mx.eval(dec.parameters())
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
|
||||||
|
mx.eval(x)
|
||||||
|
out = dec(x)
|
||||||
|
mx.eval(out)
|
||||||
|
|
||||||
|
# With two temporal 2× upsamples: T=3 → 6 → 12
|
||||||
|
assert out.shape[2] == 3 * 4, f"Expected T=12, got T={out.shape[2]}"
|
||||||
Reference in New Issue
Block a user