This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -55,18 +55,23 @@ class TestRoPEFrequencyConstruction:
d = 128 # head_dim for all Wan models
# Reference: three separate calls
correct = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
correct = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Wrong: single call
wrong = rope_params(1024, d)
mx.eval(correct, wrong)
assert correct.shape == wrong.shape
diff = np.abs(np.array(correct) - np.array(wrong)).max()
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
assert (
diff > 0.1
), f"Three-call and single-call should differ significantly, got max diff {diff}"
def test_each_axis_starts_at_frequency_one(self):
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
@@ -77,11 +82,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -95,14 +103,17 @@ class TestRoPEFrequencyConstruction:
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
# Temporal axis first freq
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
err_msg="temporal[0] cos at pos 1")
np.testing.assert_allclose(
f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
)
# Height axis first freq (starts at index d_t)
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
err_msg="height[0] cos at pos 1")
np.testing.assert_allclose(
f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
)
# Width axis first freq (starts at index d_t + d_h)
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
err_msg="width[0] cos at pos 1")
np.testing.assert_allclose(
f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
)
def test_height_width_frequencies_identical(self):
"""Height and width axes should have identical frequency tables.
@@ -113,11 +124,14 @@ class TestRoPEFrequencyConstruction:
d = 128
d_h_dim = 2 * (d // 6) # 42
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim),
rope_params(1024, d_h_dim),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim),
rope_params(1024, d_h_dim),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction:
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
height_freqs = f[:, d_t:d_t + d_h]
width_freqs = f[:, d_t + d_h:]
height_freqs = f[:, d_t : d_t + d_h]
width_freqs = f[:, d_t + d_h :]
np.testing.assert_array_equal(height_freqs, width_freqs)
def test_frequency_range_per_axis(self):
@@ -139,11 +153,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -157,7 +174,9 @@ class TestRoPEFrequencyConstruction:
pos1_h = f[1, d_t, 0] # height first freq
pos1_w = f[1, d_t + d_h, 0] # width first freq
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert (
pos1_t > 0.5
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
@@ -167,15 +186,19 @@ class TestRoPEFrequencyConstruction:
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16
freqs_manual = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs_manual = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs_model, freqs_manual)
np.testing.assert_array_equal(
np.array(freqs_model), np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction"
np.array(freqs_model),
np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction",
)
def test_model_freqs_14b_dimensions(self):
@@ -183,11 +206,14 @@ class TestRoPEFrequencyConstruction:
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
],
axis=1,
)
mx.eval(freqs)
assert freqs.shape == (1024, 64, 2)
@@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference:
@pytest.fixture
def has_torch(self):
    """Skip the requesting test when PyTorch cannot be imported.

    Returns:
        True when ``import torch`` succeeds; otherwise the test is skipped
        via ``pytest.skip`` and nothing is returned.
    """
    try:
        # The import itself is the availability probe. The noqa marker keeps
        # unused-import cleanup from swapping it for ``pass``, which would
        # leave the ImportError branch unreachable and turn a skip into a
        # hard failure inside the test body.
        import torch  # noqa: F401
    except ImportError:
        pytest.skip("PyTorch not installed")
    return True
@@ -214,6 +241,7 @@ class TestRoPEFrequencyMatchesReference:
def test_freqs_match_pytorch_reference(self, has_torch):
"""Numerically compare MLX and PyTorch frequency tables."""
import torch
from mlx_video.models.wan.rope import rope_params
d = 128
@@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference:
def pt_rope_params(max_seq_len, dim, theta=10000):
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
1.0
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
ref = torch.cat([
pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
], dim=1)
ref = torch.cat(
[
pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
],
dim=1,
)
# MLX
ours = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
ours = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(ours)
our_cos = np.array(ours[:, :, 0])
@@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference:
ref_cos = ref.real.float().numpy()
ref_sin = ref.imag.float().numpy()
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
err_msg="cos mismatch vs PyTorch reference")
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
err_msg="sin mismatch vs PyTorch reference")
np.testing.assert_allclose(
our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
)
np.testing.assert_allclose(
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
)
class TestRoPEApplyWithCorrectFreqs:
@@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions.
"""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan.rope import rope_apply, rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 1, 4
F, H, W = 1, 4, 4
@@ -289,15 +330,19 @@ class TestRoPEApplyWithCorrectFreqs:
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
# and width was ~0.002. With correct freqs, both are ~1.3.
assert height_diff > 0.5, (
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
)
assert width_diff > 0.5, (
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
)
assert (
height_diff > 0.5
), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
assert (
width_diff > 0.5
), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
# Height and width should have identical frequency tables → same diffs
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
err_msg="Height and width should use identical frequency tables")
np.testing.assert_allclose(
height_diff,
width_diff,
rtol=1e-5,
err_msg="Height and width should use identical frequency tables",
)
def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
@@ -308,11 +353,14 @@ class TestRoPEApplyWithCorrectFreqs:
)
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 2, 4
F, H, W = 2, 3, 4
@@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs:
mx.eval(out_online, out_precomp)
np.testing.assert_allclose(
np.array(out_online), np.array(out_precomp), atol=1e-5,
err_msg="Precomputed and online RoPE should match"
np.array(out_online),
np.array(out_precomp),
atol=1e-5,
err_msg="Precomputed and online RoPE should match",
)