385 lines
13 KiB
Python
385 lines
13 KiB
Python
"""Tests for Wan RoPE frequency construction (Bug 6 regression tests).

These tests verify that the RoPE frequency table is built correctly by
concatenating three separate rope_params calls with different dimension
normalizations, matching the reference implementation.

Background: The reference Wan model constructs RoPE frequencies as:

    d = dim // num_heads  (128 for all Wan models)
    freqs = cat([
        rope_params(1024, d - 4*(d//6)),  # temporal (dim=44, 22 freqs)
        rope_params(1024, 2*(d//6)),      # height   (dim=42, 21 freqs)
        rope_params(1024, 2*(d//6)),      # width    (dim=42, 21 freqs)
    ])

A previous incorrect fix used a single rope_params(1024, 128) call, which
gave the height/width axes only the lower-frequency tail of one shared range
(first inverse frequency ~0.04 for height and ~0.002 for width) instead of a
full range starting at 1.0 for each axis.
This destroyed spatial position encoding and caused grey/artifact output.
"""
|
|
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
import pytest
|
|
|
|
|
|
class TestRoPEFrequencyConstruction:
    """Check that WanModel assembles its RoPE table like the reference does."""

    @staticmethod
    def _three_call_freqs(head_dim, max_len=1024):
        """Return the table built from three per-axis ``rope_params`` calls.

        Axis dims follow the reference formula: the temporal axis receives
        ``head_dim - 4 * (head_dim // 6)`` and height/width each receive
        ``2 * (head_dim // 6)``. The pieces are joined along axis 1.
        """
        from mlx_video.models.wan2.rope import rope_params

        spatial_dim = 2 * (head_dim // 6)
        temporal_dim = head_dim - 2 * spatial_dim
        pieces = [
            rope_params(max_len, axis_dim)
            for axis_dim in (temporal_dim, spatial_dim, spatial_dim)
        ]
        return mx.concatenate(pieces, axis=1)

    @staticmethod
    def _axis_offsets(head_dim):
        """Return ``(d_t, d_h)``: the temporal and height slot counts.

        The height axis starts at column ``d_t`` of the table and the width
        axis at column ``d_t + d_h``.
        """
        half_d = head_dim // 2
        d_h = half_d // 3
        return half_d - 2 * d_h, d_h

    def _get_model_freqs(self, dim=64, num_heads=4):
        """Build a minimal WanModel and hand back ``(model.freqs, head_dim)``."""
        from mlx_video.models.wan2.config import WanModelConfig
        from mlx_video.models.wan2.wan2 import WanModel

        # Shrink every size knob so construction stays cheap.
        overrides = {
            "dim": dim,
            "ffn_dim": dim * 2,
            "num_heads": num_heads,
            "num_layers": 1,
            "in_dim": 4,
            "out_dim": 4,
            "freq_dim": 32,
            "text_dim": 32,
            "text_len": 8,
        }
        config = WanModelConfig()
        for field_name, value in overrides.items():
            setattr(config, field_name, value)

        model = WanModel(config)
        mx.eval(model.freqs)
        return model.freqs, dim // num_heads

    def test_freqs_shape(self):
        """The table is [1024, head_dim//2, 2] however it was constructed."""
        freqs, head_dim = self._get_model_freqs(dim=64, num_heads=4)
        expected_shape = (1024, head_dim // 2, 2)
        assert freqs.shape == expected_shape

    def test_three_call_vs_single_call_differ(self):
        """The three-call table must not equal the single-call table."""
        from mlx_video.models.wan2.rope import rope_params

        d = 128  # head_dim for all Wan models
        correct = self._three_call_freqs(d)  # reference construction
        wrong = rope_params(1024, d)  # the buggy single-call construction
        mx.eval(correct, wrong)

        assert correct.shape == wrong.shape
        delta = np.array(correct) - np.array(wrong)
        diff = np.abs(delta).max()
        assert (
            diff > 0.1
        ), f"Three-call and single-call should differ significantly, got max diff {diff}"

    def test_each_axis_starts_at_frequency_one(self):
        """Every axis (temporal/height/width) has cos=1, sin=0 at position 0.

        Together with the position-1 checks this shows each axis owns an
        independent frequency range starting at theta^0 = 1.0 (exponent
        0/dim).
        """
        d = 128
        freqs = self._three_call_freqs(d)
        mx.eval(freqs)
        f = np.array(freqs)
        d_t, d_h = self._axis_offsets(d)  # (22, 21) for d = 128

        # Position 0: cos=1 and sin=0 for ALL frequency components.
        np.testing.assert_allclose(f[0, :, 0], 1.0, atol=1e-6, err_msg="cos at pos 0")
        np.testing.assert_allclose(f[0, :, 1], 0.0, atol=1e-6, err_msg="sin at pos 0")

        # Position 1: the FIRST component of each axis should sit at cos(1).
        first_slots = (
            (0, "temporal[0] cos at pos 1"),
            (d_t, "height[0] cos at pos 1"),
            (d_t + d_h, "width[0] cos at pos 1"),
        )
        for column, message in first_slots:
            np.testing.assert_allclose(
                f[1, column, 0], np.cos(1.0), atol=1e-5, err_msg=message
            )

    def test_height_width_frequencies_identical(self):
        """The height and width slices of the table are element-wise equal.

        Both come from rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
        """
        d = 128
        freqs = self._three_call_freqs(d)
        mx.eval(freqs)
        f = np.array(freqs)
        d_t, d_h = self._axis_offsets(d)

        # Cut the table at the two axis boundaries: temporal | height | width.
        _, height_freqs, width_freqs = np.split(f, [d_t, d_t + d_h], axis=1)
        np.testing.assert_array_equal(height_freqs, width_freqs)

    def test_frequency_range_per_axis(self):
        """Each axis spans a full frequency range, not a slice of one range.

        With three-call construction the inverse frequency at index 0 of
        every axis is 1.0 (theta^0). A single-call approach would give height
        starting at ~0.04 and width at ~0.002 instead of 1.0.
        """
        d = 128
        freqs = self._three_call_freqs(d)
        mx.eval(freqs)
        f = np.array(freqs)
        d_t, d_h = self._axis_offsets(d)

        # At position 1 the first component of each axis should have a
        # significant magnitude (cos ≈ 0.54), not a near-zero one.
        first_cos = (
            ("Temporal", f[1, 0, 0]),
            ("Height", f[1, d_t, 0]),
            ("Width", f[1, d_t + d_h, 0]),
        )
        for axis_name, value in first_cos:
            assert (
                value > 0.5
            ), f"{axis_name} first freq at pos 1 should be >0.5, got {value}"

    def test_model_freqs_match_manual_construction(self):
        """WanModel.freqs equals the hand-built three-call table exactly."""
        freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
        freqs_manual = self._three_call_freqs(head_dim)  # head_dim == 16 here
        mx.eval(freqs_model, freqs_manual)
        np.testing.assert_array_equal(
            np.array(freqs_model),
            np.array(freqs_manual),
            err_msg="WanModel.freqs should use three-call construction",
        )

    def test_model_freqs_14b_dimensions(self):
        """Pin the table shape and axis split for 14B-scale head_dim=128."""
        # dims 44 / 42 / 42 → 22 / 21 / 21 frequency pairs
        freqs = self._three_call_freqs(128)
        mx.eval(freqs)
        assert freqs.shape == (1024, 64, 2)

        # The split sizes the tests assume rope_apply uses.
        half_d = 64
        d_h = d_w = half_d // 3
        d_t = half_d - 2 * (half_d // 3)
        assert (d_t, d_h, d_w) == (22, 21, 21)
        assert d_t + d_h + d_w == half_d
|
|
|
|
|
|
class TestRoPEFrequencyMatchesReference:
    """Cross-validate MLX RoPE against PyTorch reference implementation."""

    @pytest.fixture
    def has_torch(self):
        """Skip the requesting test when PyTorch cannot be imported.

        Returns:
            True when ``import torch`` succeeds.
        """
        try:
            # The import is the whole point of this fixture: without it the
            # ``except`` branch is unreachable and the test errors out with an
            # ImportError instead of being skipped. ``noqa`` keeps linters from
            # stripping it as an unused import.
            import torch  # noqa: F401
        except ImportError:
            pytest.skip("PyTorch not installed")
        return True

    def test_freqs_match_pytorch_reference(self, has_torch):
        """Numerically compare MLX and PyTorch frequency tables."""
        import torch

        from mlx_video.models.wan2.rope import rope_params

        d = 128
        # Per-axis dims passed to rope_params: temporal, height, width.
        axis_dims = (d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6))

        # PyTorch reference (from wan/modules/model.py); returns a complex
        # tensor whose angle is position * inverse frequency.
        def pt_rope_params(max_seq_len, dim, theta=10000):
            freqs = torch.outer(
                torch.arange(max_seq_len),
                1.0
                / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
            )
            freqs = torch.polar(torch.ones_like(freqs), freqs)
            return freqs

        ref = torch.cat([pt_rope_params(1024, n) for n in axis_dims], dim=1)

        # MLX
        ours = mx.concatenate([rope_params(1024, n) for n in axis_dims], axis=1)
        mx.eval(ours)

        # The last axis of the MLX table is compared as (cos, sin) against
        # the real / imaginary parts of the complex reference.
        our_cos = np.array(ours[:, :, 0])
        our_sin = np.array(ours[:, :, 1])
        ref_cos = ref.real.float().numpy()
        ref_sin = ref.imag.float().numpy()

        np.testing.assert_allclose(
            our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
        )
        np.testing.assert_allclose(
            our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
        )
|
|
|
|
|
|
class TestRoPEApplyWithCorrectFreqs:
    """Exercise rope_apply on a table built with the three-call construction."""

    @staticmethod
    def _three_call_freqs(head_dim, max_len=1024):
        """Return the per-axis ``rope_params`` tables joined along axis 1.

        Axis dims are ``head_dim - 4 * (head_dim // 6)`` for temporal and
        ``2 * (head_dim // 6)`` for each of height and width.
        """
        from mlx_video.models.wan2.rope import rope_params

        spatial_dim = 2 * (head_dim // 6)
        temporal_dim = head_dim - 2 * spatial_dim
        pieces = [
            rope_params(max_len, axis_dim)
            for axis_dim in (temporal_dim, spatial_dim, spatial_dim)
        ]
        return mx.concatenate(pieces, axis=1)

    def test_different_spatial_positions_get_different_rotations(self):
        """Neighbouring height/width positions must be rotated differently.

        This is the key property the single-call bug broke: the height/width
        frequencies were too low to tell nearby positions apart.
        """
        from mlx_video.models.wan2.rope import rope_apply

        d = 128
        freqs = self._three_call_freqs(d)

        batch, heads = 1, 4
        grid = (1, 4, 4)  # (F, H, W)
        n_frames, n_rows, n_cols = grid
        seq_len = n_frames * n_rows * n_cols

        def flat_index(frame, row, col):
            # Index used by this test for grid cell (f, h, w) in the sequence.
            return frame * n_rows * n_cols + row * n_cols + col

        # Constant input, so any difference comes purely from RoPE.
        x = mx.ones((batch, seq_len, heads, d))
        out = rope_apply(x, [grid], freqs)
        mx.eval(out)
        out_np = np.array(out[0])

        origin = out_np[flat_index(0, 0, 0)]
        # (0,0,0) vs (0,1,0): one step along height.
        height_diff = np.abs(origin - out_np[flat_index(0, 1, 0)]).max()
        # (0,0,0) vs (0,0,1): one step along width.
        width_diff = np.abs(origin - out_np[flat_index(0, 0, 1)]).max()

        # Max diff should be >0.5 for both axes. With the bug, height was
        # ~0.04 and width was ~0.002. With correct freqs, both are ~1.3.
        assert (
            height_diff > 0.5
        ), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
        assert (
            width_diff > 0.5
        ), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
        # Identical height/width frequency tables → identical diffs.
        np.testing.assert_allclose(
            height_diff,
            width_diff,
            rtol=1e-5,
            err_msg="Height and width should use identical frequency tables",
        )

    def test_precomputed_matches_online(self):
        """rope_apply gives the same result with or without precomputed cos/sin."""
        from mlx_video.models.wan2.rope import rope_apply, rope_precompute_cos_sin

        d = 128
        freqs = self._three_call_freqs(d)

        batch, heads = 2, 4
        grid = (2, 3, 4)  # (F, H, W)
        seq_len = grid[0] * grid[1] * grid[2]
        grids = [grid] * batch  # one grid entry per batch element

        x = mx.random.normal((batch, seq_len, heads, d))

        # Online path: rope_apply works from the frequency table directly.
        out_online = rope_apply(x, grids, freqs)
        # Precomputed path: cos/sin prepared up front and passed in.
        cos_sin = rope_precompute_cos_sin(grids, freqs)
        out_precomp = rope_apply(x, grids, freqs, precomputed_cos_sin=cos_sin)
        mx.eval(out_online, out_precomp)

        np.testing.assert_allclose(
            np.array(out_online),
            np.array(out_precomp),
            atol=1e-5,
            err_msg="Precomputed and online RoPE should match",
        )
|