diff --git a/mlx_video/generate_wan.py b/mlx_video/generate_wan.py index a1875d5..f1f0275 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -173,6 +173,13 @@ def generate_video( # Validate frame count assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}" + # For T2V: generate 1 extra latent frame so the VAE's causal zero-padding + # artifacts land on throwaway frames. The reference Wan2.2 speech2video.py + # uses a similar "drop_first_motion" approach (drops 3 pixel frames). + # For I2V the reference image provides real first-frame content, so no extra needed. + extra_frames = config.vae_stride[0] if not is_i2v else 0 + gen_frames = num_frames + extra_frames + version_str = f"Wan{config.model_version}" mode_str = "dual-model" if is_dual else "single-model" pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video" @@ -223,7 +230,7 @@ def generate_video( # Compute target latent shape z_dim = config.vae_z_dim - t_latent = (num_frames - 1) // vae_stride[0] + 1 + t_latent = (gen_frames - 1) // vae_stride[0] + 1 h_latent = height // vae_stride[1] w_latent = width // vae_stride[2] target_shape = (z_dim, t_latent, h_latent, w_latent) @@ -234,6 +241,8 @@ def generate_video( ) print(f"{Colors.DIM} Latent shape: {target_shape}") + if extra_frames > 0: + print(f" Generating {extra_frames} extra pixel frames to absorb VAE boundary artifacts") print(f" Sequence length: {seq_len}{Colors.RESET}") # Load T5 encoder @@ -439,6 +448,15 @@ def generate_video( print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}") t3 = time.time() + # Compile model forward for faster denoising + models_to_compile = ( + [high_noise_model, low_noise_model] if is_dual else [single_model] + ) + for m in models_to_compile: + m._compiled = mx.compile(m) + + + # Pre-convert timesteps to Python list to avoid .item() sync each step timestep_list = sched.timesteps.tolist() @@ -460,6 +478,9 @@ def generate_video( kv = cross_kv rcs = rope_cos_sin + # Use compiled forward when available (faster after first trace) + _call = getattr(model, '_compiled', model) + if cfg_disabled: # No CFG: B=1 forward pass (2x faster than B=2 CFG batch) if is_i2v_mask_blend: @@ -479,7 +500,7 @@ def generate_video( ctx = context_cond_high if timestep_val >= boundary else context_cond_low else: ctx = context_cond - preds = model( + preds = _call( [latents], t=t_batch, context=ctx, @@ -513,7 +534,7 @@ def generate_video( ctx = context_cfg if not is_dual else ( context_cfg_high if timestep_val >= boundary else context_cfg_low ) - preds = model( + preds = _call( [latents, latents], t=t_batch, context=ctx, @@ -564,36 +585,28 @@ def generate_video( is_wan22_vae = config.vae_z_dim == 48 - # Warm-up: prepend a copy of the first latent frame to provide temporal - # context for the real first frame. Causal convolutions in the VAE decoder - # pad with zeros on the left, so the first few output frames have degraded - # quality (no temporal context). By duplicating the first latent, the real - # first frame sees its own features as left context instead of zeros. - # We trim the extra output frames after decoding. - warmup_trim = vae_stride[0] # 4 frames per latent temporal position - latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1) - if is_wan22_vae: from mlx_video.models.wan.vae22 import denormalize_latents # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE) - z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C] + z = latents.transpose(1, 2, 3, 0)[None] z = denormalize_latents(z) - video = vae(z) # [1, T', H', W', 3] + video = vae(z) mx.eval(video) print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") video = np.array(video[0]) # [T', H', W', 3] - video = video[warmup_trim:] # Trim warm-up frames + # Trim extra frames generated for zero-padding warmup + if extra_frames > 0: + video = video[extra_frames:] video = (video + 1.0) / 2.0 video = np.clip(video * 255.0, 0, 255).astype(np.uint8) else: - video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W] + video = vae.decode(latents[None]) mx.eval(video) print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}") video = np.array(video[0]) # [3, T', H, W] - video = video[:, warmup_trim:] # Trim warm-up frames (channels-first) video = (video + 1.0) / 2.0 video = np.clip(video * 255.0, 0, 255).astype(np.uint8) video = video.transpose(1, 2, 3, 0) # [T, H, W, 3] diff --git a/mlx_video/models/wan/model.py b/mlx_video/models/wan/model.py index e6a3a40..5f2e391 100644 --- a/mlx_video/models/wan/model.py +++ b/mlx_video/models/wan/model.py @@ -48,13 +48,14 @@ class Head(nn.Module): """ if e.ndim == 2: e = e[:, None, :] # [B, 1, dim] - # modulation already float32; e already float32 from model forward - mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim] + # Compute modulation in float32 for precision, cast to working dtype + w_dtype = _linear_dtype(self.head) + mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype) e0 = mod[:, :, 0, :] # [B, L_e, dim] shift e1 = mod[:, :, 1, :] # [B, L_e, dim] scale x_norm = self.norm(x) - x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32 - return self.head(x_mod.astype(_linear_dtype(self.head))) + x_mod = x_norm * (1 + e1) + e0 + return self.head(x_mod) class WanModel(nn.Module): @@ -322,18 +323,14 @@ class WanModel(nn.Module): self.time_embedding_act(self.time_embedding_0(sin_emb)) ) # [B, dim] e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6] - # Keep e and e0 in float32 — official asserts float32 for modulation - e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32) - e = e.astype(mx.float32) + e0 = e0.reshape(batch_size, 1, 6, self.dim) else: # I2V: per-token timesteps [B, L] e = self.time_embedding_1( self.time_embedding_act(self.time_embedding_0(sin_emb)) ) # [B, L, dim] e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6] - # Keep e and e0 in float32 — official asserts float32 for modulation - e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32) - e = e.astype(mx.float32) + e0 = e0.reshape(batch_size, -1, 6, self.dim) # Text embedding: skip MLP if context is already embedded (mx.array) if isinstance(context, mx.array): diff --git a/mlx_video/models/wan/transformer.py b/mlx_video/models/wan/transformer.py index 857bcae..59aa651 100644 --- a/mlx_video/models/wan/transformer.py +++ b/mlx_video/models/wan/transformer.py @@ -51,17 +51,20 @@ class WanAttentionBlock(nn.Module): rope_cos_sin: tuple | None = None, attn_mask: mx.array | None = None, ) -> mx.array: - # Modulation in float32 (e is already float32 from model forward) - mod = self.modulation + e - e0 = mod[:, :, 0, :] # shift for self-attn - e1 = mod[:, :, 1, :] # scale for self-attn - e2 = mod[:, :, 2, :] # gate for self-attn - e3 = mod[:, :, 3, :] # shift for ffn - e4 = mod[:, :, 4, :] # scale for ffn - e5 = mod[:, :, 5, :] # gate for ffn + # Modulation: compute in float32 for precision, cast to working dtype + # to avoid promoting the full hidden state (seq_len × dim) to float32 + w_dtype = _linear_dtype(self.self_attn.q) + mod = (self.modulation + e).astype(w_dtype) + e0, e1, e2, e3, e4, e5 = ( + mod[:, :, 0, :], # shift for self-attn + mod[:, :, 1, :], # scale for self-attn + mod[:, :, 2, :], # gate for self-attn + mod[:, :, 3, :], # shift for ffn + mod[:, :, 4, :], # scale for ffn + mod[:, :, 5, :], # gate for ffn + ) - # Self-attention with modulation - # Type promotion handles bf16→f32 automatically when multiplied with f32 modulation + # Self-attention with modulation (hidden state stays in w_dtype) x_mod = self.norm1(x) * (1 + e1) + e0 y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask) x = x + y * e2 diff --git a/mlx_video/models/wan/vae22.py b/mlx_video/models/wan/vae22.py index 8c31d3b..a0f7234 100644 --- a/mlx_video/models/wan/vae22.py +++ b/mlx_video/models/wan/vae22.py @@ -81,7 +81,8 @@ class CausalConv3d(nn.Module): y = mx.conv_general(x_flat, w2d) + self.bias return y.reshape(B, T, y.shape[1], y.shape[2], -1) - # Causal temporal padding (left only) + # Causal temporal padding (left only) — zeros match the reference + # implementation and what the model was trained with. if self._causal_pad_t > 0: pad_t = mx.zeros((B, self._causal_pad_t, H, W, C)) x = mx.concatenate([pad_t, x], axis=1)