format
This commit is contained in:
@@ -55,18 +55,23 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
d = 128 # head_dim for all Wan models
|
||||
# Reference: three separate calls
|
||||
correct = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
correct = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
# Wrong: single call
|
||||
wrong = rope_params(1024, d)
|
||||
mx.eval(correct, wrong)
|
||||
|
||||
assert correct.shape == wrong.shape
|
||||
diff = np.abs(np.array(correct) - np.array(wrong)).max()
|
||||
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
|
||||
assert (
|
||||
diff > 0.1
|
||||
), f"Three-call and single-call should differ significantly, got max diff {diff}"
|
||||
|
||||
def test_each_axis_starts_at_frequency_one(self):
|
||||
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
|
||||
@@ -77,11 +82,14 @@ class TestRoPEFrequencyConstruction:
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
@@ -95,14 +103,17 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
|
||||
# Temporal axis first freq
|
||||
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="temporal[0] cos at pos 1")
|
||||
np.testing.assert_allclose(
|
||||
f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
|
||||
)
|
||||
# Height axis first freq (starts at index d_t)
|
||||
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="height[0] cos at pos 1")
|
||||
np.testing.assert_allclose(
|
||||
f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
|
||||
)
|
||||
# Width axis first freq (starts at index d_t + d_h)
|
||||
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
|
||||
err_msg="width[0] cos at pos 1")
|
||||
np.testing.assert_allclose(
|
||||
f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
|
||||
)
|
||||
|
||||
def test_height_width_frequencies_identical(self):
|
||||
"""Height and width axes should have identical frequency tables.
|
||||
@@ -113,11 +124,14 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
d = 128
|
||||
d_h_dim = 2 * (d // 6) # 42
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, d_h_dim),
|
||||
rope_params(1024, d_h_dim),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, d_h_dim),
|
||||
rope_params(1024, d_h_dim),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
@@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction:
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
|
||||
height_freqs = f[:, d_t:d_t + d_h]
|
||||
width_freqs = f[:, d_t + d_h:]
|
||||
height_freqs = f[:, d_t : d_t + d_h]
|
||||
width_freqs = f[:, d_t + d_h :]
|
||||
np.testing.assert_array_equal(height_freqs, width_freqs)
|
||||
|
||||
def test_frequency_range_per_axis(self):
|
||||
@@ -139,11 +153,14 @@ class TestRoPEFrequencyConstruction:
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
@@ -157,7 +174,9 @@ class TestRoPEFrequencyConstruction:
|
||||
pos1_h = f[1, d_t, 0] # height first freq
|
||||
pos1_w = f[1, d_t + d_h, 0] # width first freq
|
||||
|
||||
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
||||
assert (
|
||||
pos1_t > 0.5
|
||||
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
||||
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
|
||||
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
|
||||
|
||||
@@ -167,15 +186,19 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
d = head_dim # 16
|
||||
freqs_manual = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs_manual = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs_model, freqs_manual)
|
||||
np.testing.assert_array_equal(
|
||||
np.array(freqs_model), np.array(freqs_manual),
|
||||
err_msg="WanModel.freqs should use three-call construction"
|
||||
np.array(freqs_model),
|
||||
np.array(freqs_manual),
|
||||
err_msg="WanModel.freqs should use three-call construction",
|
||||
)
|
||||
|
||||
def test_model_freqs_14b_dimensions(self):
|
||||
@@ -183,11 +206,14 @@ class TestRoPEFrequencyConstruction:
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
|
||||
assert freqs.shape == (1024, 64, 2)
|
||||
@@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference:
|
||||
@pytest.fixture
|
||||
def has_torch(self):
|
||||
try:
|
||||
import torch
|
||||
pass
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
pytest.skip("PyTorch not installed")
|
||||
@@ -214,6 +241,7 @@ class TestRoPEFrequencyMatchesReference:
|
||||
def test_freqs_match_pytorch_reference(self, has_torch):
|
||||
"""Numerically compare MLX and PyTorch frequency tables."""
|
||||
import torch
|
||||
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
d = 128
|
||||
@@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference:
|
||||
def pt_rope_params(max_seq_len, dim, theta=10000):
|
||||
freqs = torch.outer(
|
||||
torch.arange(max_seq_len),
|
||||
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
||||
1.0
|
||||
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
ref = torch.cat([
|
||||
pt_rope_params(1024, d - 4 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
], dim=1)
|
||||
ref = torch.cat(
|
||||
[
|
||||
pt_rope_params(1024, d - 4 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# MLX
|
||||
ours = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
ours = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(ours)
|
||||
|
||||
our_cos = np.array(ours[:, :, 0])
|
||||
@@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference:
|
||||
ref_cos = ref.real.float().numpy()
|
||||
ref_sin = ref.imag.float().numpy()
|
||||
|
||||
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
|
||||
err_msg="cos mismatch vs PyTorch reference")
|
||||
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
|
||||
err_msg="sin mismatch vs PyTorch reference")
|
||||
np.testing.assert_allclose(
|
||||
our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
|
||||
)
|
||||
|
||||
|
||||
class TestRoPEApplyWithCorrectFreqs:
|
||||
@@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
This is the key property that was broken by the single-call bug:
|
||||
height/width frequencies were too low to distinguish nearby positions.
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
B, N = 1, 4
|
||||
F, H, W = 1, 4, 4
|
||||
@@ -289,15 +330,19 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
|
||||
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
|
||||
# and width was ~0.002. With correct freqs, both are ~1.3.
|
||||
assert height_diff > 0.5, (
|
||||
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
||||
)
|
||||
assert width_diff > 0.5, (
|
||||
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
||||
)
|
||||
assert (
|
||||
height_diff > 0.5
|
||||
), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
||||
assert (
|
||||
width_diff > 0.5
|
||||
), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
||||
# Height and width should have identical frequency tables → same diffs
|
||||
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
|
||||
err_msg="Height and width should use identical frequency tables")
|
||||
np.testing.assert_allclose(
|
||||
height_diff,
|
||||
width_diff,
|
||||
rtol=1e-5,
|
||||
err_msg="Height and width should use identical frequency tables",
|
||||
)
|
||||
|
||||
def test_precomputed_matches_online(self):
|
||||
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
||||
@@ -308,11 +353,14 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
)
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
B, N = 2, 4
|
||||
F, H, W = 2, 3, 4
|
||||
@@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
mx.eval(out_online, out_precomp)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_online), np.array(out_precomp), atol=1e-5,
|
||||
err_msg="Precomputed and online RoPE should match"
|
||||
np.array(out_online),
|
||||
np.array(out_precomp),
|
||||
atol=1e-5,
|
||||
err_msg="Precomputed and online RoPE should match",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user