feat(wan): Add tiled VAE decoding and fix TI2V quality
This commit is contained in:
@@ -46,8 +46,10 @@ def generate_video(
|
||||
loras: list | None = None,
|
||||
loras_high: list | None = None,
|
||||
loras_low: list | None = None,
|
||||
|
||||
tiling: str = "auto",
|
||||
no_compile: bool = False,
|
||||
):
|
||||
|
||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||
|
||||
Args:
|
||||
@@ -67,6 +69,13 @@ def generate_video(
|
||||
loras: Optional list of (path, strength) tuples applied to all models
|
||||
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
||||
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
||||
tiling: Tiling mode for VAE decoding; unrecognized values fall back to "auto" with a warning. Options:
|
||||
- "auto": Automatically determine tiling based on video size (default)
|
||||
- "none": Disable tiling
|
||||
- "default", "aggressive", "conservative": Preset tiling configs
|
||||
- "spatial": Spatial tiling only
|
||||
- "temporal": Temporal tiling only
|
||||
no_compile: If True, skip mx.compile on models (useful for debugging)
|
||||
|
||||
"""
|
||||
import json
|
||||
@@ -173,12 +182,7 @@ def generate_video(
|
||||
# Validate frame count
|
||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||
|
||||
# For T2V: generate 1 extra latent frame so the VAE's causal zero-padding
|
||||
# artifacts land on throwaway frames. The reference Wan2.2 speech2video.py
|
||||
# uses a similar "drop_first_motion" approach (drops 3 pixel frames).
|
||||
# For I2V the reference image provides real first-frame content, so no extra needed.
|
||||
extra_frames = config.vae_stride[0] if not is_i2v else 0
|
||||
gen_frames = num_frames + extra_frames
|
||||
gen_frames = num_frames
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
@@ -241,8 +245,6 @@ def generate_video(
|
||||
)
|
||||
|
||||
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
||||
if extra_frames > 0:
|
||||
print(f" Generating {extra_frames} extra pixel frames to absorb VAE boundary artifacts")
|
||||
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
||||
|
||||
# Load T5 encoder
|
||||
@@ -419,7 +421,7 @@ def generate_video(
|
||||
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||
else:
|
||||
rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes)
|
||||
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin)
|
||||
|
||||
# Setup scheduler
|
||||
@@ -448,12 +450,13 @@ def generate_video(
|
||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||
t3 = time.time()
|
||||
|
||||
# Compile model forward for faster denoising (skipped when no_compile is set)
|
||||
models_to_compile = (
|
||||
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||
)
|
||||
for m in models_to_compile:
|
||||
m._compiled = mx.compile(m)
|
||||
if not no_compile:
|
||||
models_to_compile = (
|
||||
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||
)
|
||||
for m in models_to_compile:
|
||||
m._compiled = mx.compile(m)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -585,24 +588,53 @@ def generate_video(
|
||||
|
||||
is_wan22_vae = config.vae_z_dim == 48
|
||||
|
||||
# Select tiling configuration
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
elif tiling == "auto":
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
elif tiling == "default":
|
||||
tiling_config = TilingConfig.default()
|
||||
elif tiling == "aggressive":
|
||||
tiling_config = TilingConfig.aggressive()
|
||||
elif tiling == "conservative":
|
||||
tiling_config = TilingConfig.conservative()
|
||||
elif tiling == "spatial":
|
||||
tiling_config = TilingConfig.spatial_only()
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
|
||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||
z = latents.transpose(1, 2, 3, 0)[None]
|
||||
z = denormalize_latents(z)
|
||||
video = vae(z)
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(z, tiling_config)
|
||||
else:
|
||||
video = vae(z)
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [T', H', W', 3]
|
||||
# Trim extra frames generated for zero-padding warmup
|
||||
if extra_frames > 0:
|
||||
video = video[extra_frames:]
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
video = vae.decode(latents[None])
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(latents[None], tiling_config)
|
||||
else:
|
||||
video = vae.decode(latents[None])
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
@@ -651,6 +683,17 @@ def main():
|
||||
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiling",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-compile", action="store_true",
|
||||
help="Disable mx.compile on models (for debugging)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -688,6 +731,8 @@ def main():
|
||||
loras=_parse_lora_args(args.lora),
|
||||
loras_high=_parse_lora_args(args.lora_high),
|
||||
loras_low=_parse_lora_args(args.lora_low),
|
||||
tiling=args.tiling,
|
||||
no_compile=args.no_compile,
|
||||
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user