mlx-video/docs/wan22-implementation-notes.md
2026-03-11 09:08:10 +01:00


Wan2.2 MLX Implementation Notes

Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / T2V-1.3B) to Apple MLX.

Architecture Overview

Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early reports, the T2V/TI2V models do not use Mixture-of-Experts — they are dense DiT models with a dual-model architecture for the 14B variant (separate high-noise and low-noise denoisers with a boundary timestep).

Key Parameters

| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride |
|---|---|---|---|---|---|---|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) |

Codebase Structure (~3900 lines of Wan2.2 code)

mlx_video/
├── generate_wan.py           # 483L - Generation pipeline (T2V + I2V)
├── convert_wan.py            # 564L - Weight conversion from HuggingFace
└── models/wan/
    ├── config.py             # 113L - Model configs (dataclass presets)
    ├── model.py              # 320L - DiT model (time embed, patchify, unpatchify)
    ├── transformer.py        #  91L - Attention block + FFN
    ├── attention.py          # 211L - Self-attention + cross-attention
    ├── rope.py               # 100L - 3D Rotary Position Embeddings
    ├── text_encoder.py       # 240L - T5 encoder (UMT5-XXL)
    ├── scheduler.py          # 428L - Euler, DPM++ 2M, UniPC schedulers
    ├── vae.py                # 315L - Wan2.1 VAE decoder (4×8×8)
    ├── vae22.py              # 836L - Wan2.2 VAE encoder + decoder (4×16×16)
    ├── loading.py            # 154L - Model loading utilities
    └── i2v_utils.py          #  58L - I2V mask/preprocessing

Critical Bugs & Fixes

1. MLX Underscore Attribute Gotcha

Problem: MLX's nn.Module silently ignores underscore-prefixed attributes (_layer_0, _layer_1, etc.) in parameters() and load_weights(). The Wan2.2 VAE had layers named _layer_N, causing 87 out of 110 weights to be silently dropped during loading.

Fix: Rename all _layer_N attributes to layer_N. MLX treats underscore-prefixed attributes as "private" and excludes them from the parameter tree.

Lesson: Never use underscore-prefixed names for nn.Module sub-modules in MLX.
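
A conversion script can guard against this class of bug by renaming the offending keys and then checking that no underscore-prefixed path component survives. The sketch below is a minimal illustration, not the code in convert_wan.py; `sanitize_key`, `find_private_keys` and the example keys are hypothetical names.

```python
import re

# Matches a "_layer_N" path component at the start of a key or after a dot.
_PRIVATE_LAYER = re.compile(r"(^|\.)_(layer_\d+)(?=\.|$)")


def sanitize_key(key: str) -> str:
    """Drop the leading underscore from every `_layer_N` path component."""
    return _PRIVATE_LAYER.sub(r"\1\2", key)


def find_private_keys(keys):
    """Return keys that still contain an underscore-prefixed path component."""
    return [k for k in keys if any(part.startswith("_") for part in k.split("."))]


# Hypothetical checkpoint keys, for illustration only.
checkpoint_keys = [
    "decoder._layer_0.weight",
    "decoder._layer_1.conv.bias",
    "decoder.head.weight",
]
renamed = [sanitize_key(k) for k in checkpoint_keys]
print(renamed)
print(find_private_keys(renamed))  # an empty list means nothing would be dropped
```

Comparing the number of keys in the checkpoint against the number of parameters the model actually exposes is a cheap second check, since a silent drop shows up as a count mismatch.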

2. Patchify Channel Ordering

Problem: The patchify/unpatchify operations transposed channels incorrectly — producing [C fastest] layout instead of [C slowest], causing completely garbled video output.

Fix: Changed reshape to produce correct [B, T', H', W', pt*ph*pw*C] ordering matching PyTorch's contiguous memory layout.

Lesson: When porting PyTorch reshape/view operations to MLX, pay close attention to axis order. Both frameworks store arrays row-major, so a reshape only reproduces the PyTorch result when the preceding transpose/permute leaves the axes in the same order.
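
A round-trip test is a quick way to catch this kind of ordering bug. The NumPy sketch below produces the [B, T', H', W', pt*ph*pw*C] layout named in the fix and checks that unpatchify inverts it. It assumes a channels-first [B, C, T, H, W] input, and the function names are illustrative rather than taken from model.py.

```python
import numpy as np


def patchify(x, patch):
    """[B, C, T, H, W] -> [B, T', H', W', pt*ph*pw*C], channels varying fastest."""
    b, c, t, h, w = x.shape
    pt, ph, pw = patch
    x = x.reshape(b, c, t // pt, pt, h // ph, ph, w // pw, pw)
    # Grid axes first, then the in-patch offsets, then channels last.
    x = x.transpose(0, 2, 4, 6, 3, 5, 7, 1)
    return x.reshape(b, t // pt, h // ph, w // pw, pt * ph * pw * c)


def unpatchify(tokens, patch, channels):
    """Inverse of patchify: back to [B, C, T, H, W]."""
    b, tp, hp, wp, _ = tokens.shape
    pt, ph, pw = patch
    x = tokens.reshape(b, tp, hp, wp, pt, ph, pw, channels)
    # Undo the transpose used in patchify.
    x = x.transpose(0, 7, 1, 4, 2, 5, 3, 6)
    return x.reshape(b, channels, tp * pt, hp * ph, wp * pw)


video = np.arange(2 * 2 * 4 * 4).reshape(1, 2, 2, 4, 4)  # [B, C, T, H, W]
tokens = patchify(video, (1, 2, 2))
restored = unpatchify(tokens, (1, 2, 2), channels=2)
print(tokens.shape)                      # (1, 2, 2, 2, 8)
print(np.array_equal(restored, video))   # True
```

A round trip alone does not prove the layout matches the reference weights, so comparing one token against the PyTorch output is still worthwhile.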

3. VAE AttentionBlock Reshape

Problem: Attention block merged batch (B) with channels (C) instead of batch with temporal (T), producing a green checker pattern in output.

Fix: Correct reshape from [B*C, T, H, W] to [B*T, C, H, W] for spatial attention.
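
The corrected reshape can be sketched in NumPy as follows. This assumes a channels-first [B, C, T, H, W] input, matching the notation above; the actual MLX code may use a channels-last layout, and the function names are illustrative.

```python
import numpy as np


def fold_time_into_batch(x):
    """[B, C, T, H, W] -> [B*T, C, H, W]; every frame becomes its own batch item."""
    b, c, t, h, w = x.shape
    # Move T next to B before merging; merging without this transpose mixes B with C.
    return x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)


def unfold_time(y, batch):
    """[B*T, C, H, W] -> [B, C, T, H, W]."""
    bt, c, h, w = y.shape
    return y.reshape(batch, bt // batch, c, h, w).transpose(0, 2, 1, 3, 4)


x = np.arange(1 * 2 * 3 * 2 * 2).reshape(1, 2, 3, 2, 2)
frames = fold_time_into_batch(x)
print(frames.shape)                              # (3, 2, 2, 2)
print(np.array_equal(frames[1], x[0, :, 1]))     # True: item 1 is frame 1
```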

4. RMS Norm vs L2 Norm

Problem: The Wan2.2 VAE uses a class named RMS_norm in PyTorch, but it actually computes L2 normalization (divide by L2 norm), not RMS normalization (divide by RMS). Using actual RMS norm caused exponential value explosion.

Fix: Implement as x / ||x||₂ instead of x / sqrt(mean(x²)).

Lesson: Don't trust class names in reference code — read the actual computation.
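
The two normalizations differ by a constant factor of sqrt(n) for a vector of length n, which is easy to confirm on a small example. The sketch below is plain Python for illustration only; it leaves out any learned gain or constant scale that the reference class may apply afterwards.

```python
import math


def l2_normalize(x, eps=1e-12):
    """Divide by the L2 norm: x / ||x||_2."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


def rms_normalize(x, eps=1e-12):
    """Divide by the root mean square: x / sqrt(mean(x^2))."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / max(rms, eps) for v in x]


x = [3.0, 4.0, 0.0, 0.0]
print(l2_normalize(x))   # [0.6, 0.8, 0.0, 0.0]
print(rms_normalize(x))  # [1.2, 1.6, 0.0, 0.0], larger by sqrt(4) = 2
```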

5. Video Codec Green Output

Problem: OpenCV's mp4v codec on macOS produces green-tinted video.

Fix: Switch to imageio with libx264 codec. Fallback chain: imageio → cv2 (avc1) → PNG frames.


Precision & Dtype Flow

The bfloat16 Autocast Pattern

The official PyTorch implementation uses torch.autocast("cuda", dtype=torch.bfloat16) which automatically casts matmul inputs. In MLX, we replicate this manually:

| Operation | Official (PyTorch) | MLX Implementation |
|---|---|---|
| Modulation/gates | float32 (explicit autocast(enabled=False)) | x.astype(mx.float32) before modulation |
| QKV projections | bfloat16 (outer autocast) | Cast input to self.q.weight.dtype |
| RoPE computation | float64 → float32 | float32 (MLX lacks float64 on GPU) |
| Q/K after RoPE | bfloat16 (q.to(v.dtype)) | Cast back to weight dtype after RoPE |
| FFN matmuls | bfloat16 (outer autocast) | Cast input to self.fc1.weight.dtype |
| Residual stream | float32 | float32 (no cast) |

Result: ~16% speedup (47s vs 56s for 20 steps at 480p) with no quality regression.

Key insight: Modulation parameters (scale, shift, gate) must stay in float32. They are small values (roughly 0.01 to 0.1) that lose significant precision in bfloat16, and the official code explicitly disables autocast for these computations.
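
One way to see the size of the loss is to emulate bfloat16 rounding in plain Python. The sketch below assumes the modulation has the adaLN-style form x * (1 + scale) + shift, which is an assumption here rather than something stated above. Near 1.0 the bfloat16 grid spacing is 2^-7, so adding a scale of 0.01 to 1 lands on the nearest grid point instead. `to_bfloat16` is a hypothetical helper that ignores NaN and infinity.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias, then keep only the top 16 bits of the float32 pattern.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


scale = 0.01
rounded = to_bfloat16(1.0 + scale)
print(rounded)        # 1.0078125
print(rounded - 1.0)  # 0.0078125: the effective scale is about 22% too small
```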

T5 Encoder Precision

The T5 text encoder must run in float32. Bfloat16 weights cause the attention softmax to produce degenerate distributions, which corrupts text conditioning and manifests as blurry patches in generated video. Since T5 only runs once per generation, the performance cost is negligible.

VAE Decoder Precision

VAE weights must be float32. Bfloat16 VAE decode introduces visible quality loss in the decoded video frames.


Scheduler Implementation Details

Three Schedulers: Euler, DPM++ 2M, UniPC

All operate in the flow-matching formulation where sigma represents the noise level (1.0 = pure noise, 0.0 = clean).

Euler: Simple first-order ODE solver. Most stable, recommended for debugging.

DPM++ 2M: Second-order multistep solver. Uses previous step's model output for higher-order correction. Requires special handling at boundaries (return ±inf from _lambda() when sigma is 0 or 1).

UniPC (default, matches official): Second-order predictor-corrector. The "C" (corrector) part is critical — it refines each step using the already-computed model output at zero additional model evaluation cost.

UniPC Corrector: Must Be Enabled

Discovery: Our implementation had use_corrector=False by default, but the official Wan2.2 code always enables it (there's no flag — the corrector runs whenever step_index > 0).

Impact: Without the corrector, UniPC degrades to a simple predictor, losing its second-order accuracy advantage.

UniPC Corrector Coefficients

The corrector coefficients (rhos_c) must be computed by solving a linear system, not hardcoded. For order ≥ 2, hardcoding rhos_c[-1] = 0.5 introduces ~6–13% error in the correction term across 47+ steps. The fix uses np.linalg.solve() to compute exact coefficients.

Sigma Schedule

# Flow-matching sigma schedule with shift
sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
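
Wrapped in a function, the same two lines can be run on their own to inspect the effect of the shift. `flow_sigmas` is an illustrative name, not necessarily the one used in scheduler.py.

```python
import numpy as np


def flow_sigmas(num_steps: int, shift: float) -> np.ndarray:
    """Linear sigmas from 1.0 down to 1/num_steps, warped by the shift."""
    sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)


print(flow_sigmas(4, 1.0))  # shift 1.0 leaves the linear schedule unchanged
print(flow_sigmas(4, 3.0))  # approximately [1.0, 0.9, 0.75, 0.5]
```

A shift above 1 keeps the schedule at high noise levels for longer, which is why the larger model uses the larger default.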


Image-to-Video (I2V) Pipeline

Per-Token Timesteps

I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:

# mask_tokens: [1, L] — 0 for first-frame patches, 1 for rest
t_tokens = mask_tokens * current_timestep  # first-frame → t=0

The model receives 2D timestep input [B, L] instead of scalar, enabling per-token noise levels.
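
Building the token mask comes down to counting how many patches the first latent frame contributes. The sketch below assumes tokens are ordered frame-major (all patches of frame 0 first), which is consistent with a [B, T', H', W', ...] patch layout. The function names and the latent sizes are hypothetical.

```python
def first_frame_token_mask(lat_t, lat_h, lat_w, patch=(1, 2, 2)):
    """Return a flat mask: 0.0 for first-frame patches, 1.0 for all others."""
    pt, ph, pw = patch
    tokens_per_frame = (lat_h // ph) * (lat_w // pw)
    num_frames = lat_t // pt
    return [0.0] * tokens_per_frame + [1.0] * (tokens_per_frame * (num_frames - 1))


def per_token_timesteps(mask_tokens, current_timestep):
    """First-frame tokens get t=0; every other token gets the current timestep."""
    return [m * current_timestep for m in mask_tokens]


mask_tokens = first_frame_token_mask(lat_t=3, lat_h=4, lat_w=4)
t_tokens = per_token_timesteps(mask_tokens, current_timestep=900.0)
print(len(t_tokens))  # 12 tokens: 4 per latent frame, 3 latent frames
print(t_tokens[:5])   # [0.0, 0.0, 0.0, 0.0, 900.0]
```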

Mask Re-application

After each scheduler step, the first-frame latent is re-injected to prevent drift:

latents = (1.0 - mask) * z_img + mask * latents

VAE Encoder Temporal Downsample Order

The Wan2.2 VAE encoder has temporal_downsample = (False, True, True):

  • Stage 0: Spatial-only downsampling
  • Stages 1–2: Spatial + temporal downsampling

This was incorrectly set to (True, True, False) initially, causing wrong spatial processing paths.


Dimension Constraints

Patchify Alignment

Video dimensions must be divisible by patch_size × vae_stride:

  • TI2V-5B: patch=(1,2,2), stride=(4,16,16) → alignment = 32 pixels
  • T2V-14B: patch=(1,2,2), stride=(4,8,8) → alignment = 16 pixels

Example (TI2V-5B): 720p (1280×720) → 720 % 32 = 16, so the height auto-aligns down to 704; the width 1280 is already a multiple of 32.
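
The alignment rule is a one-line computation per spatial axis. The helper below is a sketch with an illustrative name; rounding down, rather than up or to the nearest multiple, is assumed from the 720 → 704 example.

```python
def align_down(size: int, patch: int, stride: int) -> int:
    """Round a pixel dimension down to a multiple of patch * stride."""
    alignment = patch * stride
    aligned = (size // alignment) * alignment
    if aligned == 0:
        raise ValueError(f"{size} is smaller than the alignment {alignment}")
    return aligned


# TI2V-5B: patch 2, VAE stride 16 -> alignment 32
print(align_down(720, 2, 16), align_down(1280, 2, 16))  # 704 1280
# T2V-14B: patch 2, VAE stride 8 -> alignment 16
print(align_down(720, 2, 8))                            # 720
```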

Frame Count

Frames must satisfy num_frames = 4n + 1 (e.g., 5, 9, 13, ..., 81) due to temporal VAE stride of 4.
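
A requested frame count can be snapped to the nearest valid value below it, and the matching latent frame count follows from the same formula. Both helpers are illustrative; snapping down (rather than up) is an assumption, and the latent count assumes the first frame maps to its own latent frame.

```python
def snap_frame_count(num_frames: int, temporal_stride: int = 4) -> int:
    """Largest value of the form stride * n + 1 that does not exceed num_frames."""
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    return ((num_frames - 1) // temporal_stride) * temporal_stride + 1


def latent_frame_count(num_frames: int, temporal_stride: int = 4) -> int:
    """Number of latent frames for a valid stride * n + 1 frame count."""
    return (num_frames - 1) // temporal_stride + 1


print(snap_frame_count(81), snap_frame_count(80))  # 81 77
print(latent_frame_count(81))                      # 21
```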


Performance Optimizations

Batched CFG

Instead of two separate forward passes for conditional and unconditional predictions, batch them into a single B=2 forward pass:

preds = model([latents, latents], t=t_batch, context=context_cfg, ...)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]

Result: ~40% speedup by amortizing attention overhead.
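
After the batched pass, the two predictions are combined with the usual classifier-free guidance formula, uncond + scale × (cond − uncond). The sketch below uses a toy stand-in for the model so it runs without MLX; `toy_model`, `combine_cfg` and the numbers are illustrative only.

```python
def toy_model(batch_latents, batch_context):
    """Stand-in denoiser: one prediction per (latent, context) pair in the batch."""
    return [[x + ctx for x in lat] for lat, ctx in zip(batch_latents, batch_context)]


def combine_cfg(cond, uncond, guidance_scale):
    """Classifier-free guidance: uncond + scale * (cond - uncond)."""
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


latents = [0.5, -1.0]
# One call with B=2 instead of two calls with B=1.
cond, uncond = toy_model([latents, latents], [1.0, 0.0])
noise_pred = combine_cfg(cond, uncond, guidance_scale=5.0)
print(noise_pred)  # [5.5, 4.0]
```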

Precomputed Text Embeddings & Cross-Attention KV Cache

Text embeddings and cross-attention K/V projections are constant across all diffusion steps. Computing them once and passing as caches eliminates redundant computation.
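
The caching idea can be sketched as a per-layer memo keyed by layer index. This assumes the text context stays fixed for the whole generation, so the cache never needs invalidating. `CrossAttnKVCache` and `project_kv` are hypothetical names, with strings standing in for the projected arrays.

```python
class CrossAttnKVCache:
    """Compute each layer's cross-attention K/V once and reuse it on later steps."""

    def __init__(self, project_kv):
        self._project_kv = project_kv  # callable: (layer_idx, context) -> (k, v)
        self._cache = {}

    def get(self, layer_idx, context):
        if layer_idx not in self._cache:
            self._cache[layer_idx] = self._project_kv(layer_idx, context)
        return self._cache[layer_idx]


calls = []


def project_kv(layer_idx, context):
    calls.append(layer_idx)  # record every real projection
    return (f"K{layer_idx}({context})", f"V{layer_idx}({context})")


cache = CrossAttnKVCache(project_kv)
for _step in range(20):          # diffusion steps
    for layer_idx in range(2):   # transformer layers
        k, v = cache.get(layer_idx, "text-embedding")
print(len(calls))  # 2 projections instead of 40
```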

Memory Management in Diffusion Loop

# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
mx.eval(latents)

MLX's lazy evaluation means mx.eval() triggers the full computation graph. Deleting intermediate arrays before eval allows MLX to reuse their memory during execution.


Weight Conversion

Key Mapping Patterns

The PyTorch → MLX conversion (convert_wan.py) handles several systematic transforms:

  1. Conv3d weight transposition: PyTorch (out, in, D, H, W) → MLX (out, D, H, W, in)
  2. Linear weights: no transposition needed; PyTorch (out, in) and MLX (out, in) share the same nn.Linear convention
  3. Nested module paths: blocks.0.self_attn.q.weight → same paths, MLX loads by dotted key
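
The transforms above can be sketched in NumPy as a single per-tensor function. Dispatching on the number of dimensions is an assumption made for this example; the real convert_wan.py may select tensors by key name instead.

```python
import numpy as np


def convert_weight(weight: np.ndarray) -> np.ndarray:
    """Apply the layout change a PyTorch tensor needs before loading into MLX."""
    if weight.ndim == 5:
        # Conv3d: (out, in, D, H, W) -> (out, D, H, W, in)
        return np.transpose(weight, (0, 2, 3, 4, 1))
    # Linear weights and biases keep their layout.
    return weight


conv = np.arange(8 * 4 * 3 * 3 * 3).reshape(8, 4, 3, 3, 3)
linear = np.zeros((16, 32))
print(convert_weight(conv).shape)    # (8, 3, 3, 3, 4)
print(convert_weight(linear).shape)  # (16, 32)
```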

Dual-Model Splitting

The T2V-14B uses dual models (high-noise and low-noise). The conversion script splits a single checkpoint into separate files or handles pre-split checkpoints from HuggingFace.
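
At generation time the two converted models are switched on the boundary timestep. A minimal sketch of that dispatch follows; the `>=` comparison and the boundary value used in the example are assumptions, not values taken from the official configuration.

```python
def select_denoiser(timestep, boundary_timestep, high_noise_model, low_noise_model):
    """Use the high-noise model at or above the boundary, the low-noise model below."""
    if timestep >= boundary_timestep:
        return high_noise_model
    return low_noise_model


boundary = 500  # hypothetical value, for illustration only
for t in (999, 500, 499, 0):
    print(t, select_denoiser(t, boundary, "high_noise", "low_noise"))
```

Because only one model is active per step, the inactive one does not need to take part in that step's forward pass.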


Testing Strategy

260 tests across 9 files, all running in ~4 seconds:

| File | Focus |
|---|---|
| test_wan_config.py | Config presets, field validation |
| test_wan_attention.py | Self/cross attention, RMSNorm, bf16 autocast |
| test_wan_transformer.py | FFN, attention block, float32 modulation |
| test_wan_model.py | Full DiT forward pass, per-token timesteps |
| test_wan_t5.py | T5 encoder layers and full encoding |
| test_wan_vae.py | VAE 2.1 decoder, VAE 2.2 encoder + decoder |
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
| test_wan_convert.py | Weight sanitization and conversion |
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |

Tests use a tiny config (dim=64, heads=2, layers=2) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.


Known Issues

I2V Quality Degradation

Frames 2–13 gradually degrade, and frame 14 often shows a "flash" artifact. A comparison of all implementation details against the official PyTorch code found no discrepancies, so the cause is still open. Possible causes:

  • Subtle numerical differences from float32 vs float64 RoPE (MLX lacks float64 on GPU)
  • MLX-specific attention precision behavior

Mitigation: better prompts and 720p resolution (the model's native resolution) help reduce the artifacts.

Chinese Negative Prompt

The official Wan2.2 uses a Chinese negative prompt that prevents oversaturation and comic-style artifacts. Correct tokenization requires ftfy.fix_text() to normalize fullwidth characters and double HTML unescaping. Without proper text cleaning, the negative prompt tokens don't match the training distribution, causing blurry patches.
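
The cleaning step described above can be sketched as follows. `clean_prompt` is an illustrative name. The NFKC fallback is only there so the example runs without ftfy installed; it also folds fullwidth characters but is not equivalent to ftfy.fix_text(), so real use should rely on ftfy.

```python
import html
import unicodedata

try:
    import ftfy  # third-party package named in the notes above

    _fix_text = ftfy.fix_text
except ImportError:
    # Illustration-only fallback: NFKC folds fullwidth characters,
    # but it does not replicate everything ftfy.fix_text() does.
    def _fix_text(text):
        return unicodedata.normalize("NFKC", text)


def clean_prompt(text: str) -> str:
    """Normalize character widths, undo double HTML escaping, trim whitespace."""
    text = _fix_text(text)
    text = html.unescape(html.unescape(text))  # second pass handles "&amp;amp;"
    return text.strip()


print(clean_prompt("  &amp;amp;  "))  # &
print(clean_prompt("ＡＢＣ"))          # ABC
```

Running the same cleaning on both the positive and the negative prompt keeps their tokenization consistent.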