perf(wan): Add mx.compile and fix first-frame artifacts
This commit is contained in:
@@ -173,6 +173,13 @@ def generate_video(
|
|||||||
# Validate frame count
|
# Validate frame count
|
||||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||||
|
|
||||||
|
# For T2V: generate 1 extra latent frame so the VAE's causal zero-padding
|
||||||
|
# artifacts land on throwaway frames. The reference Wan2.2 speech2video.py
|
||||||
|
# uses a similar "drop_first_motion" approach (drops 3 pixel frames).
|
||||||
|
# For I2V the reference image provides real first-frame content, so no extra frames are needed.
|
||||||
|
extra_frames = config.vae_stride[0] if not is_i2v else 0
|
||||||
|
gen_frames = num_frames + extra_frames
|
||||||
|
|
||||||
version_str = f"Wan{config.model_version}"
|
version_str = f"Wan{config.model_version}"
|
||||||
mode_str = "dual-model" if is_dual else "single-model"
|
mode_str = "dual-model" if is_dual else "single-model"
|
||||||
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
|
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
|
||||||
@@ -223,7 +230,7 @@ def generate_video(
|
|||||||
|
|
||||||
# Compute target latent shape
|
# Compute target latent shape
|
||||||
z_dim = config.vae_z_dim
|
z_dim = config.vae_z_dim
|
||||||
t_latent = (num_frames - 1) // vae_stride[0] + 1
|
t_latent = (gen_frames - 1) // vae_stride[0] + 1
|
||||||
h_latent = height // vae_stride[1]
|
h_latent = height // vae_stride[1]
|
||||||
w_latent = width // vae_stride[2]
|
w_latent = width // vae_stride[2]
|
||||||
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
||||||
@@ -234,6 +241,8 @@ def generate_video(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
||||||
|
if extra_frames > 0:
|
||||||
|
print(f" Generating {extra_frames} extra pixel frames to absorb VAE boundary artifacts")
|
||||||
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
||||||
|
|
||||||
# Load T5 encoder
|
# Load T5 encoder
|
||||||
@@ -439,6 +448,15 @@ def generate_video(
|
|||||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
|
|
||||||
|
# Compile model forward for faster denoising
|
||||||
|
models_to_compile = (
|
||||||
|
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||||
|
)
|
||||||
|
for m in models_to_compile:
|
||||||
|
m._compiled = mx.compile(m)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
||||||
timestep_list = sched.timesteps.tolist()
|
timestep_list = sched.timesteps.tolist()
|
||||||
|
|
||||||
@@ -460,6 +478,9 @@ def generate_video(
|
|||||||
kv = cross_kv
|
kv = cross_kv
|
||||||
rcs = rope_cos_sin
|
rcs = rope_cos_sin
|
||||||
|
|
||||||
|
# Use compiled forward when available (faster after first trace)
|
||||||
|
_call = getattr(model, '_compiled', model)
|
||||||
|
|
||||||
if cfg_disabled:
|
if cfg_disabled:
|
||||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||||
if is_i2v_mask_blend:
|
if is_i2v_mask_blend:
|
||||||
@@ -479,7 +500,7 @@ def generate_video(
|
|||||||
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
||||||
else:
|
else:
|
||||||
ctx = context_cond
|
ctx = context_cond
|
||||||
preds = model(
|
preds = _call(
|
||||||
[latents],
|
[latents],
|
||||||
t=t_batch,
|
t=t_batch,
|
||||||
context=ctx,
|
context=ctx,
|
||||||
@@ -513,7 +534,7 @@ def generate_video(
|
|||||||
ctx = context_cfg if not is_dual else (
|
ctx = context_cfg if not is_dual else (
|
||||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||||
)
|
)
|
||||||
preds = model(
|
preds = _call(
|
||||||
[latents, latents],
|
[latents, latents],
|
||||||
t=t_batch,
|
t=t_batch,
|
||||||
context=ctx,
|
context=ctx,
|
||||||
@@ -564,36 +585,28 @@ def generate_video(
|
|||||||
|
|
||||||
is_wan22_vae = config.vae_z_dim == 48
|
is_wan22_vae = config.vae_z_dim == 48
|
||||||
|
|
||||||
# Warm-up: prepend a copy of the first latent frame to provide temporal
|
|
||||||
# context for the real first frame. Causal convolutions in the VAE decoder
|
|
||||||
# pad with zeros on the left, so the first few output frames have degraded
|
|
||||||
# quality (no temporal context). By duplicating the first latent, the real
|
|
||||||
# first frame sees its own features as left context instead of zeros.
|
|
||||||
# We trim the extra output frames after decoding.
|
|
||||||
warmup_trim = vae_stride[0] # 4 frames per latent temporal position
|
|
||||||
latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1)
|
|
||||||
|
|
||||||
if is_wan22_vae:
|
if is_wan22_vae:
|
||||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
|
|
||||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||||
z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C]
|
z = latents.transpose(1, 2, 3, 0)[None]
|
||||||
z = denormalize_latents(z)
|
z = denormalize_latents(z)
|
||||||
video = vae(z) # [1, T', H', W', 3]
|
video = vae(z)
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
video = np.array(video[0]) # [T', H', W', 3]
|
video = np.array(video[0]) # [T', H', W', 3]
|
||||||
video = video[warmup_trim:] # Trim warm-up frames
|
# Trim the leading extra frames that were generated only to absorb the VAE's causal zero-padding artifacts
|
||||||
|
if extra_frames > 0:
|
||||||
|
video = video[extra_frames:]
|
||||||
video = (video + 1.0) / 2.0
|
video = (video + 1.0) / 2.0
|
||||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||||
else:
|
else:
|
||||||
video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W]
|
video = vae.decode(latents[None])
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
video = np.array(video[0]) # [3, T', H, W]
|
video = np.array(video[0]) # [3, T', H, W]
|
||||||
video = video[:, warmup_trim:] # Trim warm-up frames (channels-first)
|
|
||||||
video = (video + 1.0) / 2.0
|
video = (video + 1.0) / 2.0
|
||||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||||
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
||||||
|
|||||||
@@ -48,13 +48,14 @@ class Head(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if e.ndim == 2:
|
if e.ndim == 2:
|
||||||
e = e[:, None, :] # [B, 1, dim]
|
e = e[:, None, :] # [B, 1, dim]
|
||||||
# modulation already float32; e already float32 from model forward
|
# Compute modulation in float32 for precision, cast to working dtype
|
||||||
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim]
|
w_dtype = _linear_dtype(self.head)
|
||||||
|
mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
|
||||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||||
x_norm = self.norm(x)
|
x_norm = self.norm(x)
|
||||||
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
x_mod = x_norm * (1 + e1) + e0
|
||||||
return self.head(x_mod.astype(_linear_dtype(self.head)))
|
return self.head(x_mod)
|
||||||
|
|
||||||
|
|
||||||
class WanModel(nn.Module):
|
class WanModel(nn.Module):
|
||||||
@@ -322,18 +323,14 @@ class WanModel(nn.Module):
|
|||||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||||
) # [B, dim]
|
) # [B, dim]
|
||||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
e0 = e0.reshape(batch_size, 1, 6, self.dim)
|
||||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
|
|
||||||
e = e.astype(mx.float32)
|
|
||||||
else:
|
else:
|
||||||
# I2V: per-token timesteps [B, L]
|
# I2V: per-token timesteps [B, L]
|
||||||
e = self.time_embedding_1(
|
e = self.time_embedding_1(
|
||||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||||
) # [B, L, dim]
|
) # [B, L, dim]
|
||||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
||||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
e0 = e0.reshape(batch_size, -1, 6, self.dim)
|
||||||
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
|
|
||||||
e = e.astype(mx.float32)
|
|
||||||
|
|
||||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||||
if isinstance(context, mx.array):
|
if isinstance(context, mx.array):
|
||||||
|
|||||||
@@ -51,17 +51,20 @@ class WanAttentionBlock(nn.Module):
|
|||||||
rope_cos_sin: tuple | None = None,
|
rope_cos_sin: tuple | None = None,
|
||||||
attn_mask: mx.array | None = None,
|
attn_mask: mx.array | None = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
# Modulation in float32 (e is already float32 from model forward)
|
# Modulation: compute in float32 for precision, cast to working dtype
|
||||||
mod = self.modulation + e
|
# to avoid promoting the full hidden state (seq_len × dim) to float32
|
||||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
w_dtype = _linear_dtype(self.self_attn.q)
|
||||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
mod = (self.modulation + e).astype(w_dtype)
|
||||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
e0, e1, e2, e3, e4, e5 = (
|
||||||
e3 = mod[:, :, 3, :] # shift for ffn
|
mod[:, :, 0, :], # shift for self-attn
|
||||||
e4 = mod[:, :, 4, :] # scale for ffn
|
mod[:, :, 1, :], # scale for self-attn
|
||||||
e5 = mod[:, :, 5, :] # gate for ffn
|
mod[:, :, 2, :], # gate for self-attn
|
||||||
|
mod[:, :, 3, :], # shift for ffn
|
||||||
|
mod[:, :, 4, :], # scale for ffn
|
||||||
|
mod[:, :, 5, :], # gate for ffn
|
||||||
|
)
|
||||||
|
|
||||||
# Self-attention with modulation
|
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||||
# Type promotion handles bf16→f32 automatically when multiplied with f32 modulation
|
|
||||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
||||||
x = x + y * e2
|
x = x + y * e2
|
||||||
|
|||||||
@@ -81,7 +81,8 @@ class CausalConv3d(nn.Module):
|
|||||||
y = mx.conv_general(x_flat, w2d) + self.bias
|
y = mx.conv_general(x_flat, w2d) + self.bias
|
||||||
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
||||||
|
|
||||||
# Causal temporal padding (left only)
|
# Causal temporal padding (left only) — zeros match the reference
|
||||||
|
# implementation and what the model was trained with.
|
||||||
if self._causal_pad_t > 0:
|
if self._causal_pad_t > 0:
|
||||||
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
|
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
|
||||||
x = mx.concatenate([pad_t, x], axis=1)
|
x = mx.concatenate([pad_t, x], axis=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user