perf(wan): Add mx.compile and fix first-frame artifacts

This commit is contained in:
Daniel
2026-03-01 18:15:25 +01:00
parent 849cc45d84
commit 9597b7c9c5
4 changed files with 52 additions and 38 deletions

View File

@@ -173,6 +173,13 @@ def generate_video(
# Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
# For T2V: generate 1 extra latent frame so the VAE's causal zero-padding
# artifacts land on throwaway frames. The reference Wan2.2 speech2video.py
# uses a similar "drop_first_motion" approach (drops 3 pixel frames).
# For I2V the reference image provides real first-frame content, so no extra needed.
extra_frames = config.vae_stride[0] if not is_i2v else 0
gen_frames = num_frames + extra_frames
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
@@ -223,7 +230,7 @@ def generate_video(
# Compute target latent shape
z_dim = config.vae_z_dim
t_latent = (num_frames - 1) // vae_stride[0] + 1
t_latent = (gen_frames - 1) // vae_stride[0] + 1
h_latent = height // vae_stride[1]
w_latent = width // vae_stride[2]
target_shape = (z_dim, t_latent, h_latent, w_latent)
@@ -234,6 +241,8 @@ def generate_video(
)
print(f"{Colors.DIM} Latent shape: {target_shape}")
if extra_frames > 0:
print(f" Generating {extra_frames} extra pixel frames to absorb VAE boundary artifacts")
print(f" Sequence length: {seq_len}{Colors.RESET}")
# Load T5 encoder
@@ -439,6 +448,15 @@ def generate_video(
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time()
# Compile model forward for faster denoising
models_to_compile = (
[high_noise_model, low_noise_model] if is_dual else [single_model]
)
for m in models_to_compile:
m._compiled = mx.compile(m)
# Pre-convert timesteps to Python list to avoid .item() sync each step
timestep_list = sched.timesteps.tolist()
@@ -460,6 +478,9 @@ def generate_video(
kv = cross_kv
rcs = rope_cos_sin
# Use compiled forward when available (faster after first trace)
_call = getattr(model, '_compiled', model)
if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
if is_i2v_mask_blend:
@@ -479,7 +500,7 @@ def generate_video(
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
else:
ctx = context_cond
preds = model(
preds = _call(
[latents],
t=t_batch,
context=ctx,
@@ -513,7 +534,7 @@ def generate_video(
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = model(
preds = _call(
[latents, latents],
t=t_batch,
context=ctx,
@@ -564,36 +585,28 @@ def generate_video(
is_wan22_vae = config.vae_z_dim == 48
# Warm-up: prepend a copy of the first latent frame to provide temporal
# context for the real first frame. Causal convolutions in the VAE decoder
# pad with zeros on the left, so the first few output frames have degraded
# quality (no temporal context). By duplicating the first latent, the real
# first frame sees its own features as left context instead of zeros.
# We trim the extra output frames after decoding.
warmup_trim = vae_stride[0] # 4 frames per latent temporal position
latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1)
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C]
z = latents.transpose(1, 2, 3, 0)[None]
z = denormalize_latents(z)
video = vae(z) # [1, T', H', W', 3]
video = vae(z)
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3]
video = video[warmup_trim:] # Trim warm-up frames
# Trim extra frames generated for zero-padding warmup
if extra_frames > 0:
video = video[extra_frames:]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else:
video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W]
video = vae.decode(latents[None])
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [3, T', H, W]
video = video[:, warmup_trim:] # Trim warm-up frames (channels-first)
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]