fix(wan): Fix RoPE frequency construction

This commit is contained in:
Daniel
2026-02-28 11:20:36 +01:00
parent f4195f0118
commit dbab95ec45
3 changed files with 386 additions and 30 deletions

View File

@@ -9,12 +9,13 @@ This document records the systematic diagnostic methodology used to debug the Wa
- [Diagnostic Methodology](#diagnostic-methodology)
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
- [Bug 3: RoPE Frequency Computation](#bug-3-rope-frequency-computation)
- [Bug 3: RoPE Frequency Computation (original)](#bug-3-rope-frequency-computation-original)
- [Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)](#bug-6-rope-frequency-distribution-bug-3-fix-was-wrong)
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
- [Verified Correct Components](#verified-correct-components)
- [Performance Optimizations](#performance-optimizations)
- [Open Investigation: CFG Effectiveness](#open-investigation-cfg-effectiveness)
- [Resolved: CFG Effectiveness](#resolved-cfg-effectiveness-was-open-investigation)
- [Reference Implementation](#reference-implementation)
- [Useful Diagnostic Commands](#useful-diagnostic-commands)
@@ -166,34 +167,57 @@ include_encoder = config.model_type in ("ti2v", "i2v")
---
## Bug 3: RoPE Frequency Computation
## Bug 3: RoPE Frequency Computation (original)
**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.
**Root Cause:** The reference creates **one** frequency table via `rope_params(1024, head_dim=128)` producing 64 frequency exponents, which `rope_apply` then splits into temporal (22), height (21), and width (21) portions. This gives temporal axes LOW frequencies and spatial axes progressively HIGHER frequencies.
**Root Cause (original):** Our original code called `rope_params` three times but applied them incorrectly (per-axis in the model init, then rope_apply did NOT split). This was initially "fixed" by switching to a single `rope_params(1024, head_dim=128)` call, which reduced checkerboard but introduced Bug 6 (see below).
**File:** `mlx_video/models/wan/model.py`
**Commit:** `3da4a637`
---
## Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)
**Symptom:** I2V generates input image in frames 03, colorful checkerboard on frame 4, then grey frames. CFG cond/uncond predictions nearly identical. Model cannot produce coherent motion.
**Root Cause:** The Bug 3 "fix" replaced three separate `rope_params` calls with a single `rope_params(1024, 128)`. But the reference (`wan/modules/model.py` lines 400405) actually uses **three separate calls with different dimension normalizations**, concatenated:
Our code called `rope_params` **three times** with different normalizations:
```python
# WRONG: each axis gets full frequency range [0, 1)
freqs_t = rope_params(1024, d_t=44) # 22 exponents normalized by 44
freqs_h = rope_params(1024, d_h=42) # 21 exponents normalized by 42
freqs_w = rope_params(1024, d_w=42) # 21 exponents normalized by 42
# Reference (CORRECT):
d = dim // num_heads # 128
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)), # rope_params(1024, 44)
rope_params(1024, 2 * (d // 6)), # rope_params(1024, 42)
rope_params(1024, 2 * (d // 6)) # rope_params(1024, 42)
], dim=1)
```
The max frequency difference was ~1.0 (not a precision issue — a fundamental design bug). This affected **all** Wan models (T2V, I2V, TI2V).
Each axis gets its own full frequency range [θ^0, θ^(-~0.95)]. The single-call approach gave:
- Temporal: low frequencies only [1.0 → 0.049]
- Height: medium frequencies only [0.042 → 0.002] (should start at 1.0!)
- Width: high frequencies only [0.002 → 0.0001] (should start at 1.0!)
**How Found:** Line-by-line comparison of `rope_params` usage between reference `model.py` (single call) and our `model.py` (three calls). Printed the actual frequency exponents to confirm the numerical divergence.
The height/width position encoding was essentially destroyed — nearby spatial positions were indistinguishable (max diff 0.958 for height, 0.998 for width vs reference).
**How Found:** Direct line-by-line comparison of `WanModel.__init__` freq construction between reference `wan/modules/model.py` and our `models/wan/model.py`. Numerical verification confirmed the three-call approach gives each axis a full [0, ~1) exponent range, while the single-call monotonically assigns low→high across axes.
**Fix:**
```python
# Single unified frequency table, split by rope_apply
self.freqs = rope_params(1024, dim // config.num_heads)
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
```
**Impact:** ~35% reduction in checkerboard metric, 55% reduction in FFT 8px-frequency power.
**Verification:** Max diff vs reference cos/sin: 0.00000000 (exact float32 match).
**File:** `mlx_video/models/wan/model.py` (lines 154156)
**Commit:** `3da4a637`
**Impact:** Affects ALL Wan models (T2V, I2V, TI2V). Resolves the "Open Investigation: CFG Effectiveness" issue — the model could not produce meaningful cond/uncond differences because it couldn't encode spatial positions.
**File:** `mlx_video/models/wan/model.py` (line 155)
---
@@ -298,19 +322,11 @@ Applied alongside bug fixes to improve inference speed:
---
## Open Investigation: CFG Effectiveness
## Resolved: CFG Effectiveness (was Open Investigation)
**Current symptom:** After all bug fixes, generated video shows the input image in frames 03 (latent frame 0), then grey/flat frames for the rest.
**Symptom:** Generated video shows the input image in frames 03 (latent frame 0), then grey/flat frames for the rest. Cond and uncond predictions were nearly identical.
**Finding:** A single forward pass at t=1000 shows cond and uncond predictions are nearly identical (|diff| mean = 0.010.035). With `guide_scale=3.5`, the CFG guidance term barely changes anything.
**Possible causes under investigation:**
1. Cross-attention context flow — both cond and uncond may be receiving equivalent context
2. The model may genuinely produce small cond/uncond differences for I2V (since both share the same y conditioning)
3. The `embed_text` method or `prepare_cross_kv` may not properly separate B=2 batch elements
4. There may be an issue with how cross-attention K/V caches index into batch elements
**Diagnostic approach:** Compare cross-attention K/V cache values between cond (index 0) and uncond (index 1) to confirm they contain different embeddings.
**Resolution:** This was caused by Bug 6 (incorrect RoPE frequency distribution). The single `rope_params(1024, 128)` call gave height frequencies starting at 0.042 and width at 0.002 (instead of 1.0 for both), making the model unable to encode spatial positions. This caused the transformer to produce nearly identical outputs regardless of text conditioning, explaining the tiny cond/uncond differences.
---