fix(wan): Fix RoPE frequency construction

Daniel
2026-02-28 11:20:36 +01:00
parent f4195f0118
commit dbab95ec45
3 changed files with 386 additions and 30 deletions


@@ -9,12 +9,13 @@ This document records the systematic diagnostic methodology used to debug the Wa
- [Diagnostic Methodology](#diagnostic-methodology)
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
- [Bug 3: RoPE Frequency Computation](#bug-3-rope-frequency-computation)
- [Bug 3: RoPE Frequency Computation (original)](#bug-3-rope-frequency-computation-original)
- [Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)](#bug-6-rope-frequency-distribution-bug-3-fix-was-wrong)
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
- [Verified Correct Components](#verified-correct-components)
- [Performance Optimizations](#performance-optimizations)
- [Open Investigation: CFG Effectiveness](#open-investigation-cfg-effectiveness)
- [Resolved: CFG Effectiveness](#resolved-cfg-effectiveness-was-open-investigation)
- [Reference Implementation](#reference-implementation)
- [Useful Diagnostic Commands](#useful-diagnostic-commands)
@@ -166,34 +167,57 @@ include_encoder = config.model_type in ("ti2v", "i2v")
---
## Bug 3: RoPE Frequency Computation
## Bug 3: RoPE Frequency Computation (original)
**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.
**Root Cause:** The reference creates **one** frequency table via `rope_params(1024, head_dim=128)` producing 64 frequency exponents, which `rope_apply` then splits into temporal (22), height (21), and width (21) portions. This gives temporal axes LOW frequencies and spatial axes progressively HIGHER frequencies.
**Root Cause (original):** Our original code called `rope_params` three times but applied them incorrectly (per-axis in the model init, then rope_apply did NOT split). This was initially "fixed" by switching to a single `rope_params(1024, head_dim=128)` call, which reduced checkerboard but introduced Bug 6 (see below).
**File:** `mlx_video/models/wan/model.py`
**Commit:** `3da4a637`
---
## Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)
**Symptom:** I2V generates input image in frames 0-3, colorful checkerboard on frame 4, then grey frames. CFG cond/uncond predictions nearly identical. Model cannot produce coherent motion.
**Root Cause:** The Bug 3 "fix" replaced three separate `rope_params` calls with a single `rope_params(1024, 128)`. But the reference (`wan/modules/model.py` lines 400-405) actually uses **three separate calls with different dimension normalizations**, concatenated:
Our code called `rope_params` **three times** with different normalizations:
```python
# WRONG: each axis gets full frequency range [0, 1)
freqs_t = rope_params(1024, d_t=44) # 22 exponents normalized by 44
freqs_h = rope_params(1024, d_h=42) # 21 exponents normalized by 42
freqs_w = rope_params(1024, d_w=42) # 21 exponents normalized by 42
# Reference (CORRECT):
d = dim // num_heads # 128
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)), # rope_params(1024, 44)
rope_params(1024, 2 * (d // 6)), # rope_params(1024, 42)
rope_params(1024, 2 * (d // 6)) # rope_params(1024, 42)
], dim=1)
```
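The three dimension arguments follow from the head dimension alone. The arithmetic can be sketched as below, using a hypothetical helper `axis_dims` that is not part of either codebase. It assumes, as the tests in this commit do, that each `rope_params(max_len, dim)` call yields `dim // 2` frequency pairs.

```python
def axis_dims(head_dim: int) -> tuple[int, int, int]:
    """Return the (temporal, height, width) dims passed to rope_params.

    Hypothetical helper for illustration only.
    """
    spatial = 2 * (head_dim // 6)
    temporal = head_dim - 2 * spatial  # same as head_dim - 4 * (head_dim // 6)
    return temporal, spatial, spatial


for head_dim in (128, 16):
    d_t, d_h, d_w = axis_dims(head_dim)
    # Each table contributes dim // 2 frequency pairs to the concatenation.
    pairs = (d_t // 2, d_h // 2, d_w // 2)
    print(head_dim, (d_t, d_h, d_w), pairs, sum(pairs) == head_dim // 2)
```

For `head_dim=128` this gives dims (44, 42, 42) and pairs (22, 21, 21), which is the same split that `rope_apply` uses.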
The max frequency difference was ~1.0 (not a precision issue — a fundamental design bug). This affected **all** Wan models (T2V, I2V, TI2V).
Each axis gets its own full frequency range [θ^0, θ^(-~0.95)]. The single-call approach gave:
- Temporal: highest inverse frequencies only [1.0 → 0.049]
- Height: mid-range inverse frequencies only [0.042 → 0.002] (should start at 1.0!)
- Width: lowest inverse frequencies only [0.002 → 0.0001] (should start at 1.0!)
**How Found:** Line-by-line comparison of `rope_params` usage between reference `model.py` (single call) and our `model.py` (three calls). Printed the actual frequency exponents to confirm the numerical divergence.
The height/width position encoding was essentially destroyed — nearby spatial positions were indistinguishable (max diff 0.958 for height, 0.998 for width vs reference).
**How Found:** Direct line-by-line comparison of `WanModel.__init__` freq construction between reference `wan/modules/model.py` and our `models/wan/model.py`. Numerical verification confirmed the three-call approach gives each axis a full [0, ~1) exponent range, while the single-call monotonically assigns low→high across axes.
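The numerical check described above can be approximated in plain NumPy. This is a minimal sketch, not the script that was actually run: `inv_freqs` is a hypothetical stand-in for the inverse-frequency part of `rope_params`, and the base `theta=10000` is assumed from the reference default.

```python
import numpy as np

THETA = 10000.0  # assumed RoPE base


def inv_freqs(dim: int) -> np.ndarray:
    """Inverse frequencies theta^(-2i/dim) for i in [0, dim // 2)."""
    return 1.0 / THETA ** (np.arange(0, dim, 2, dtype=np.float64) / dim)


d = 128
d_t, d_hw = d - 4 * (d // 6), 2 * (d // 6)  # 44 and 42

three_call = np.concatenate([inv_freqs(d_t), inv_freqs(d_hw), inv_freqs(d_hw)])
single_call = inv_freqs(d)

n_t, n_hw = d_t // 2, d_hw // 2  # 22 and 21 frequency pairs
axes = {
    "temporal": slice(0, n_t),
    "height": slice(n_t, n_t + n_hw),
    "width": slice(n_t + n_hw, None),
}

max_diff = {}
for name, sl in axes.items():
    max_diff[name] = float(np.abs(three_call[sl] - single_call[sl]).max())
    print(
        f"{name:8s} three-call starts at {three_call[sl][0]:.4f}, "
        f"single-call starts at {single_call[sl][0]:.4f}, "
        f"max diff {max_diff[name]:.3f}"
    )
```

The height and width rows reproduce the 0.958 and 0.998 differences quoted above, and both come from the first entry of each axis.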
**Fix:**
```python
# Single unified frequency table, split by rope_apply
self.freqs = rope_params(1024, dim // config.num_heads)
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
```
**Impact:** ~35% reduction in checkerboard metric, 55% reduction in FFT 8px-frequency power.
**Verification:** Max diff vs reference cos/sin: 0.00000000 (exact float32 match).
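A check of this kind can be approximated without MLX or PyTorch. The sketch below uses a hypothetical NumPy stand-in, `rope_params_np`, and assumes the table layout is `[max_seq_len, dim // 2, 2]` with cos and sin in the last axis, which is the layout the accompanying tests rely on. It compares the three-call table with the single-call one instead of comparing against the reference.

```python
import numpy as np


def rope_params_np(max_seq_len: int, dim: int, theta: float = 10000.0) -> np.ndarray:
    """NumPy stand-in (illustrative): cos/sin table of shape [max_seq_len, dim // 2, 2]."""
    inv = 1.0 / theta ** (np.arange(0, dim, 2, dtype=np.float64) / dim)
    angles = np.outer(np.arange(max_seq_len, dtype=np.float64), inv)
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1).astype(np.float32)


d = 128
table = np.concatenate(
    [
        rope_params_np(1024, d - 4 * (d // 6)),  # temporal, 22 pairs
        rope_params_np(1024, 2 * (d // 6)),      # height, 21 pairs
        rope_params_np(1024, 2 * (d // 6)),      # width, 21 pairs
    ],
    axis=1,
)
single = rope_params_np(1024, d)

# Largest element-wise gap between the two constructions.
max_abs_diff = float(np.abs(table - single).max())
print(table.shape, max_abs_diff)
```

With the three-call table, the first pair of every axis rotates by one radian per position, so `table[1, 0, 0]`, `table[1, 22, 0]`, and `table[1, 43, 0]` all equal cos(1).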
**File:** `mlx_video/models/wan/model.py` (lines 154-156)
**Commit:** `3da4a637`
**Impact:** Affects ALL Wan models (T2V, I2V, TI2V). Resolves the "Open Investigation: CFG Effectiveness" issue — the model could not produce meaningful cond/uncond differences because it couldn't encode spatial positions.
**File:** `mlx_video/models/wan/model.py` (line 155)
---
@@ -298,19 +322,11 @@ Applied alongside bug fixes to improve inference speed:
---
## Open Investigation: CFG Effectiveness
## Resolved: CFG Effectiveness (was Open Investigation)
**Current symptom:** After all bug fixes, generated video shows the input image in frames 0-3 (latent frame 0), then grey/flat frames for the rest.
**Symptom:** Generated video shows the input image in frames 0-3 (latent frame 0), then grey/flat frames for the rest. Cond and uncond predictions were nearly identical.
**Finding:** A single forward pass at t=1000 shows cond and uncond predictions are nearly identical (|diff| mean = 0.01-0.035). With `guide_scale=3.5`, the CFG guidance term barely changes anything.
**Possible causes under investigation:**
1. Cross-attention context flow — both cond and uncond may be receiving equivalent context
2. The model may genuinely produce small cond/uncond differences for I2V (since both share the same y conditioning)
3. The `embed_text` method or `prepare_cross_kv` may not properly separate B=2 batch elements
4. There may be an issue with how cross-attention K/V caches index into batch elements
**Diagnostic approach:** Compare cross-attention K/V cache values between cond (index 0) and uncond (index 1) to confirm they contain different embeddings.
**Resolution:** This was caused by Bug 6 (incorrect RoPE frequency distribution). The single `rope_params(1024, 128)` call gave height frequencies starting at 0.042 and width at 0.002 (instead of 1.0 for both), making the model unable to encode spatial positions. This caused the transformer to produce nearly identical outputs regardless of text conditioning, explaining the tiny cond/uncond differences.
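To see why nearly identical cond/uncond predictions leave guidance with almost nothing to amplify, consider a minimal sketch. It assumes the standard classifier-free guidance combination `uncond + guide_scale * (cond - uncond)`; the pipeline's actual code is not shown in this document, and `cfg_combine` and the synthetic arrays are illustrative only.

```python
import numpy as np


def cfg_combine(cond: np.ndarray, uncond: np.ndarray, guide_scale: float) -> np.ndarray:
    """Standard CFG combination (assumed form, not copied from the pipeline)."""
    return uncond + guide_scale * (cond - uncond)


rng = np.random.default_rng(0)
uncond = rng.standard_normal((16, 8, 8)).astype(np.float32)

# Simulate nearly identical predictions: every element differs by exactly 0.02,
# which sits inside the 0.01-0.035 range reported for the broken model.
cond = uncond + 0.02 * np.sign(rng.standard_normal(uncond.shape)).astype(np.float32)

guided = cfg_combine(cond, uncond, guide_scale=3.5)

# Guidance moves the prediction by (guide_scale - 1) * |cond - uncond|.
shift = float(np.abs(guided - cond).mean())
print(f"mean |guided - cond| = {shift:.3f}")
```

Under these assumptions the guided prediction moves by only about 0.05 on data with unit variance, so text conditioning has very little effect on the denoising trajectory.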
---


@@ -108,9 +108,15 @@ class WanModel(nn.Module):
# Output head
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
# Precompute RoPE frequencies — single table, split by rope_apply
# Reference computes one rope_params(head_dim) and splits into t/h/w.
self.freqs = rope_params(1024, dim // config.num_heads)
# Precompute RoPE frequencies — three separate tables concatenated.
# Reference computes three rope_params with different dim normalizations
# so each axis (temporal/height/width) gets its own full frequency range.
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
# Precompute sinusoidal inv_freq for time embedding
half = config.freq_dim // 2


@@ -0,0 +1,334 @@
"""Tests for Wan RoPE frequency construction (Bug 6 regression tests).

These tests verify that the RoPE frequency table is built correctly by
concatenating three separate rope_params calls with different dimension
normalizations, matching the reference implementation.

Background: The reference Wan model constructs RoPE frequencies as:

    d = dim // num_heads  (128 for all Wan models)
    freqs = cat([
        rope_params(1024, d - 4*(d//6)),  # temporal (dim=44, 22 freqs)
        rope_params(1024, 2*(d//6)),      # height   (dim=42, 21 freqs)
        rope_params(1024, 2*(d//6)),      # width    (dim=42, 21 freqs)
    ])

A previous incorrect fix used a single rope_params(1024, 128) call, which
gave the height/width axes only the lower slices of one shared frequency
range instead of a full range each (inverse frequencies starting at ~0.04
and ~0.002 rather than 1.0).
This destroyed spatial position encoding and caused grey/artifact output.
"""

import mlx.core as mx
import numpy as np
import pytest


class TestRoPEFrequencyConstruction:
    """Verify WanModel builds RoPE frequencies matching the reference."""

    def _get_model_freqs(self, dim=64, num_heads=4):
        """Instantiate a tiny WanModel and return its .freqs tensor."""
        from mlx_video.models.wan.config import WanModelConfig
        from mlx_video.models.wan.model import WanModel

        config = WanModelConfig()
        config.dim = dim
        config.ffn_dim = dim * 2
        config.num_heads = num_heads
        config.num_layers = 1
        config.in_dim = 4
        config.out_dim = 4
        config.freq_dim = 32
        config.text_dim = 32
        config.text_len = 8
        model = WanModel(config)
        mx.eval(model.freqs)
        return model.freqs, dim // num_heads

    def test_freqs_shape(self):
        """Freqs should be [1024, head_dim//2, 2] regardless of construction."""
        freqs, head_dim = self._get_model_freqs(dim=64, num_heads=4)
        assert freqs.shape == (1024, head_dim // 2, 2)

    def test_three_call_vs_single_call_differ(self):
        """Three separate rope_params calls must differ from single call."""
        from mlx_video.models.wan.rope import rope_params

        d = 128  # head_dim for all Wan models
        # Reference: three separate calls
        correct = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)
        # Wrong: single call
        wrong = rope_params(1024, d)
        mx.eval(correct, wrong)
        assert correct.shape == wrong.shape
        diff = np.abs(np.array(correct) - np.array(wrong)).max()
        assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"

    def test_each_axis_starts_at_frequency_one(self):
        """Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.

        This verifies each axis gets its own independent frequency range
        starting from theta^0 = 1.0 (i.e., exponent 0/dim).
        """
        from mlx_video.models.wan.rope import rope_params

        d = 128
        freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)
        mx.eval(freqs)
        f = np.array(freqs)
        half_d = d // 2  # 64
        d_t = half_d - 2 * (half_d // 3)  # 22
        d_h = half_d // 3  # 21
        # At position 0, cos=1 and sin=0 for ALL frequency components
        np.testing.assert_allclose(f[0, :, 0], 1.0, atol=1e-6, err_msg="cos at pos 0")
        np.testing.assert_allclose(f[0, :, 1], 0.0, atol=1e-6, err_msg="sin at pos 0")
        # At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
        # Temporal axis first freq
        np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
                                   err_msg="temporal[0] cos at pos 1")
        # Height axis first freq (starts at index d_t)
        np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
                                   err_msg="height[0] cos at pos 1")
        # Width axis first freq (starts at index d_t + d_h)
        np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
                                   err_msg="width[0] cos at pos 1")

    def test_height_width_frequencies_identical(self):
        """Height and width axes should have identical frequency tables.

        Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
        """
        from mlx_video.models.wan.rope import rope_params

        d = 128
        d_h_dim = 2 * (d // 6)  # 42
        freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, d_h_dim),
            rope_params(1024, d_h_dim),
        ], axis=1)
        mx.eval(freqs)
        f = np.array(freqs)
        half_d = d // 2
        d_t = half_d - 2 * (half_d // 3)
        d_h = half_d // 3
        height_freqs = f[:, d_t:d_t + d_h]
        width_freqs = f[:, d_t + d_h:]
        np.testing.assert_array_equal(height_freqs, width_freqs)

    def test_frequency_range_per_axis(self):
        """Each axis should span a full frequency range, not a slice of one range.

        With three-call construction, the inverse frequency at index 0 of each
        axis should be 1.0 (theta^0). A single-call approach would give height
        starting at ~0.04 and width at ~0.002 instead of 1.0.
        """
        from mlx_video.models.wan.rope import rope_params

        d = 128
        freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)
        mx.eval(freqs)
        f = np.array(freqs)
        half_d = d // 2
        d_t = half_d - 2 * (half_d // 3)
        d_h = half_d // 3
        # At position 1, the first frequency component of each axis should
        # have rotated by one radian (sin(1) ≈ 0.84). A single-call table
        # would give sin ≈ 0.04 (height) and ≈ 0.002 (width). The sin
        # component is checked because cos stays above 0.5 in both cases
        # (cos(1) ≈ 0.54, cos(0.04) ≈ 0.999), so cos cannot tell them apart.
        pos1_t = f[1, 0, 1]  # temporal first freq
        pos1_h = f[1, d_t, 1]  # height first freq
        pos1_w = f[1, d_t + d_h, 1]  # width first freq
        assert pos1_t > 0.5, f"Temporal first-freq sin at pos 1 should be >0.5, got {pos1_t}"
        assert pos1_h > 0.5, f"Height first-freq sin at pos 1 should be >0.5, got {pos1_h}"
        assert pos1_w > 0.5, f"Width first-freq sin at pos 1 should be >0.5, got {pos1_w}"

    def test_model_freqs_match_manual_construction(self):
        """WanModel.freqs should match manually constructed three-call freqs."""
        from mlx_video.models.wan.rope import rope_params

        freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
        d = head_dim  # 16
        freqs_manual = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)
        mx.eval(freqs_model, freqs_manual)
        np.testing.assert_array_equal(
            np.array(freqs_model), np.array(freqs_manual),
            err_msg="WanModel.freqs should use three-call construction"
        )

    def test_model_freqs_14b_dimensions(self):
        """Verify freq dimensions for 14B-scale head_dim=128."""
        from mlx_video.models.wan.rope import rope_params

        d = 128
        freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),  # dim=44 → 22 freq pairs
            rope_params(1024, 2 * (d // 6)),  # dim=42 → 21 freq pairs
            rope_params(1024, 2 * (d // 6)),  # dim=42 → 21 freq pairs
        ], axis=1)
        mx.eval(freqs)
        assert freqs.shape == (1024, 64, 2)
        # Verify the split dimensions used by rope_apply
        half_d = 64
        d_t = half_d - 2 * (half_d // 3)
        d_h = half_d // 3
        d_w = half_d // 3
        assert (d_t, d_h, d_w) == (22, 21, 21)
        assert d_t + d_h + d_w == half_d


class TestRoPEFrequencyMatchesReference:
    """Cross-validate MLX RoPE against PyTorch reference implementation."""

    @pytest.fixture
    def has_torch(self):
        try:
            import torch
            return True
        except ImportError:
            pytest.skip("PyTorch not installed")

    def test_freqs_match_pytorch_reference(self, has_torch):
        """Numerically compare MLX and PyTorch frequency tables."""
        import torch
        from mlx_video.models.wan.rope import rope_params

        d = 128

        # PyTorch reference (from wan/modules/model.py)
        def pt_rope_params(max_seq_len, dim, theta=10000):
            freqs = torch.outer(
                torch.arange(max_seq_len),
                1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
            freqs = torch.polar(torch.ones_like(freqs), freqs)
            return freqs

        ref = torch.cat([
            pt_rope_params(1024, d - 4 * (d // 6)),
            pt_rope_params(1024, 2 * (d // 6)),
            pt_rope_params(1024, 2 * (d // 6)),
        ], dim=1)
        # MLX
        ours = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)
        mx.eval(ours)
        our_cos = np.array(ours[:, :, 0])
        our_sin = np.array(ours[:, :, 1])
        ref_cos = ref.real.float().numpy()
        ref_sin = ref.imag.float().numpy()
        np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
                                   err_msg="cos mismatch vs PyTorch reference")
        np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
                                   err_msg="sin mismatch vs PyTorch reference")


class TestRoPEApplyWithCorrectFreqs:
    """Test that rope_apply produces correct rotations with three-call freqs."""

    def test_different_spatial_positions_get_different_rotations(self):
        """Adjacent height/width positions must produce different RoPE rotations.

        This is the key property that was broken by the single-call bug:
        height/width frequencies were too low to distinguish nearby positions.
        """
        from mlx_video.models.wan.rope import rope_params, rope_apply

        d = 128
        freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)
        B, N = 1, 4
        F, H, W = 1, 4, 4
        L = F * H * W
        # Use a constant input so differences come purely from RoPE
        x = mx.ones((B, L, N, d))
        out = rope_apply(x, [(F, H, W)], freqs)
        mx.eval(out)
        out_np = np.array(out[0])
        # Position (0,0,0) vs (0,1,0): different height
        pos_00 = out_np[0 * H * W + 0 * W + 0]  # (f=0, h=0, w=0)
        pos_10 = out_np[0 * H * W + 1 * W + 0]  # (f=0, h=1, w=0)
        height_diff = np.abs(pos_00 - pos_10).max()
        # Position (0,0,0) vs (0,0,1): different width
        pos_01 = out_np[0 * H * W + 0 * W + 1]  # (f=0, h=0, w=1)
        width_diff = np.abs(pos_00 - pos_01).max()
        # Max diff should be >0.5 for both axes. With the bug, height was ~0.04
        # and width was ~0.002. With correct freqs, both are ~1.3.
        assert height_diff > 0.5, (
            f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
        )
        assert width_diff > 0.5, (
            f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
        )
        # Height and width should have identical frequency tables → same diffs
        np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
                                   err_msg="Height and width should use identical frequency tables")

    def test_precomputed_matches_online(self):
        """rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
        from mlx_video.models.wan.rope import (
            rope_apply,
            rope_params,
            rope_precompute_cos_sin,
        )

        d = 128
        freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)
        B, N = 2, 4
        F, H, W = 2, 3, 4
        L = F * H * W
        grids = [(F, H, W), (F, H, W)]
        x = mx.random.normal((B, L, N, d))
        # Online (no precomputed)
        out_online = rope_apply(x, grids, freqs)
        # Precomputed
        cos_sin = rope_precompute_cos_sin(grids, freqs)
        out_precomp = rope_apply(x, grids, freqs, precomputed_cos_sin=cos_sin)
        mx.eval(out_online, out_precomp)
        np.testing.assert_allclose(
            np.array(out_online), np.array(out_precomp), atol=1e-5,
            err_msg="Precomputed and online RoPE should match"
        )