Implement regression tests for RoPE position precision using NumPy float64 reference. Add a new function to compute reference values and validate that float32 results closely match expected outputs, addressing high-frequency amplification issues. Update imports to include LTXModelConfig for enhanced configuration management.
This commit is contained in:
@@ -5,7 +5,7 @@ import numpy as np
|
||||
from mlx_video.models.ltx.rope import (
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXRopeType
|
||||
|
||||
|
||||
def create_video_position_grid(
|
||||
@@ -36,6 +36,65 @@ def create_video_position_grid(
|
||||
|
||||
return mx.array(pixel_coords, dtype=dtype)
|
||||
|
||||
|
||||
def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
||||
"""Compute RoPE cos/sin using NumPy float64 as ground truth reference.
|
||||
|
||||
This mirrors the regular (non-double-precision) path in rope.py exactly,
|
||||
but uses float64 throughout, so we can verify that the float32 MLX path
|
||||
stays close to the true values.
|
||||
"""
|
||||
# positions_np: (B, 3, T, 2) in float64
|
||||
n_pos_dims = positions_np.shape[1]
|
||||
n_elem = 2 * n_pos_dims
|
||||
|
||||
# Middle-of-interval positions
|
||||
mid = (positions_np[..., 0] + positions_np[..., 1]) / 2.0 # (B, 3, T)
|
||||
|
||||
# Frequency grid — matches generate_freq_grid() in rope.py:
|
||||
# log_start = log(1)/log(theta) = 0
|
||||
# log_end = log(theta)/log(theta) = 1
|
||||
# pow_indices = theta^linspace(0, 1, num_indices) * pi/2
|
||||
num_indices = dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
lin_space = np.linspace(0.0, 1.0, num_indices, dtype=np.float64)
|
||||
freq_indices = np.power(theta, lin_space) * (np.pi / 2) # (num_indices,)
|
||||
|
||||
# Fractional positions and scaling — matches generate_freqs()
|
||||
# frac = pos / max_pos for each dim, then scale to [-1, 1]
|
||||
frac_list = []
|
||||
for d in range(n_pos_dims):
|
||||
frac = mid[:, d, :] / max_pos[d] # (B, T)
|
||||
frac_list.append(frac)
|
||||
fractional = np.stack(frac_list, axis=-1) # (B, T, n_dims)
|
||||
scaled = fractional * 2 - 1 # [-1, 1]
|
||||
|
||||
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
|
||||
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
|
||||
freqs = np.swapaxes(freqs, -1, -2)
|
||||
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
|
||||
|
||||
cos_ref = np.cos(freqs)
|
||||
sin_ref = np.sin(freqs)
|
||||
|
||||
# Split RoPE: pad to dim//2, reshape to (B, H, T, dim_per_head//2)
|
||||
expected = dim // 2
|
||||
pad_size = expected - cos_ref.shape[-1]
|
||||
if pad_size > 0:
|
||||
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
|
||||
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
|
||||
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
|
||||
|
||||
B, T, F = cos_ref.shape
|
||||
dim_per_head = dim // num_heads
|
||||
cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
|
||||
sin_ref = sin_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
|
||||
|
||||
return cos_ref, sin_ref
|
||||
|
||||
|
||||
class TestRoPEPositionPrecision:
|
||||
"""Test suite for RoPE position precision requirements."""
|
||||
|
||||
@@ -132,11 +191,6 @@ class TestRoPEPositionPrecision:
|
||||
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
|
||||
# The double precision path in rope.py line 434:
|
||||
# indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||
# This means bfloat16 -> float32 -> float64
|
||||
# The precision is already lost at the bfloat16 -> float32 step
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
@@ -176,6 +230,96 @@ class TestRoPEPositionPrecision:
|
||||
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
||||
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
||||
|
||||
def test_float32_positions_match_numpy_float64_reference(self):
|
||||
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
|
||||
|
||||
This is the key correctness test. We compute RoPE with NumPy in float64
|
||||
(ground truth) and verify that the MLX float32 path produces nearly
|
||||
identical results. The max allowed diff (1e-5) is well below the error
|
||||
we saw with bfloat16 positions (~2.0 max diff, cosine sim 0.88).
|
||||
"""
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
positions_np = np.array(positions).astype(np.float64)
|
||||
|
||||
dim = 128
|
||||
theta = 10000.0
|
||||
max_pos = [20, 2048, 2048]
|
||||
num_heads = 32
|
||||
|
||||
# MLX result (float32 path, non-double-precision)
|
||||
cos_mlx, sin_mlx = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=dim,
|
||||
theta=theta,
|
||||
max_pos=max_pos,
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=num_heads,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
# NumPy float64 reference
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
|
||||
cos_mlx_np = np.array(cos_mlx)
|
||||
sin_mlx_np = np.array(sin_mlx)
|
||||
|
||||
max_cos_diff = np.max(np.abs(cos_mlx_np - cos_ref))
|
||||
max_sin_diff = np.max(np.abs(sin_mlx_np - sin_ref))
|
||||
|
||||
# Cosine similarity (flatten for single scalar)
|
||||
cos_flat = cos_mlx_np.flatten()
|
||||
ref_flat = cos_ref.flatten()
|
||||
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
|
||||
|
||||
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
|
||||
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert cosine_sim > 0.9999, \
|
||||
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||
|
||||
def test_high_frequency_amplification_regression(self):
|
||||
"""Regression test for the specific failure mode: high-frequency index amplification.
|
||||
|
||||
With production-sized grids (5x16x16 = 1280 tokens), fractional positions
|
||||
like 0.000391 get multiplied by frequency indices up to ~15708. In bfloat16
|
||||
the fractional part is quantized, producing raw freq errors of ~6.14 and
|
||||
cos/sin sign flips (max_diff ~2.0). Float32 must keep max_diff < 0.01.
|
||||
"""
|
||||
# Use a production-like grid size
|
||||
positions = create_video_position_grid(1, 5, 16, 16, dtype=mx.float32)
|
||||
positions_np = np.array(positions).astype(np.float64)
|
||||
|
||||
dim = 128
|
||||
theta = 10000.0
|
||||
max_pos = [20, 2048, 2048]
|
||||
num_heads = 32
|
||||
|
||||
cos_mlx, sin_mlx = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=dim,
|
||||
theta=theta,
|
||||
max_pos=max_pos,
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=num_heads,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
|
||||
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
|
||||
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
|
||||
|
||||
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||
|
||||
|
||||
class TestRoPEInterleaved:
|
||||
"""Tests for interleaved RoPE mode."""
|
||||
@@ -201,43 +345,125 @@ class TestRoPEInterleaved:
|
||||
assert not mx.any(mx.isnan(sin_freq)).item()
|
||||
|
||||
|
||||
class TestRoPEWarnings:
|
||||
"""Tests for RoPE warnings."""
|
||||
class TestRoPEInputCasting:
|
||||
"""Tests that precompute_freqs_cis casts positions to float32 internally.
|
||||
|
||||
def test_bfloat16_positions_trigger_warning(self):
|
||||
"""Verify that bfloat16 positions trigger a UserWarning."""
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
The fix in rope.py ensures that regardless of the input dtype, positions are
|
||||
cast to float32 before any computation. This class verifies that behavior
|
||||
for both the regular and double-precision paths.
|
||||
"""
|
||||
|
||||
with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"):
|
||||
precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
def test_float32_positions_no_warning(self):
|
||||
"""Verify that float32 positions do NOT trigger a warning."""
|
||||
def test_regular_path_outputs_float32(self):
|
||||
"""Regular path: both float32 and bfloat16 inputs produce float32 output."""
|
||||
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
# This should not raise any warnings
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error") # Turn warnings into errors
|
||||
precompute_freqs_cis(
|
||||
indices_grid=positions_f32,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=positions_bf16, **kwargs)
|
||||
|
||||
# Both produce float32 output regardless of input dtype
|
||||
assert cos_f32.dtype == mx.float32
|
||||
assert cos_bf16.dtype == mx.float32
|
||||
assert sin_f32.dtype == mx.float32
|
||||
assert sin_bf16.dtype == mx.float32
|
||||
|
||||
# No NaN/Inf in either
|
||||
assert not mx.any(mx.isnan(cos_bf16)).item()
|
||||
assert not mx.any(mx.isinf(cos_bf16)).item()
|
||||
|
||||
def test_double_precision_path_outputs_float32(self):
|
||||
"""Double-precision path: both float32 and bfloat16 inputs produce float32 output."""
|
||||
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=True,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=positions_bf16, **kwargs)
|
||||
|
||||
assert cos_f32.dtype == mx.float32
|
||||
assert cos_bf16.dtype == mx.float32
|
||||
assert sin_f32.dtype == mx.float32
|
||||
assert sin_bf16.dtype == mx.float32
|
||||
|
||||
assert not mx.any(mx.isnan(cos_bf16)).item()
|
||||
assert not mx.any(mx.isinf(cos_bf16)).item()
|
||||
|
||||
def test_float16_input_also_cast_to_float32(self):
|
||||
"""Float16 input should also be handled correctly."""
|
||||
positions_f16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float16)
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions_f16,
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
)
|
||||
|
||||
assert cos_freq.dtype == mx.float32
|
||||
assert sin_freq.dtype == mx.float32
|
||||
assert not mx.any(mx.isnan(cos_freq)).item()
|
||||
|
||||
|
||||
class TestDoublePrecisionRopeConfig:
|
||||
"""Tests for the conditional double_precision_rope logic in LTXModelConfig."""
|
||||
|
||||
def test_ltx2_forces_double_precision_rope_false(self):
|
||||
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2 should force double_precision_rope=False regardless of input"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_true(self):
|
||||
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
|
||||
assert config.double_precision_rope is True, \
|
||||
"LTX-2.3 should preserve double_precision_rope=True"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_false(self):
|
||||
"""LTX-2.3 with double_precision_rope=False should stay False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||
|
||||
def test_ltx2_default_double_precision_rope(self):
|
||||
"""LTX-2 default (double_precision_rope not set) should be False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=False)
|
||||
assert config.double_precision_rope is False
|
||||
|
||||
def test_ltx23_default_double_precision_rope(self):
|
||||
"""LTX-2.3 default (double_precision_rope not set) should be False (field default)."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True)
|
||||
# The field default is False and __post_init__ doesn't override for LTX-2.3
|
||||
assert config.double_precision_rope is False
|
||||
|
||||
def test_config_from_dict_ltx2(self):
|
||||
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": False,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
assert config.double_precision_rope is False
|
||||
|
||||
def test_config_from_dict_ltx23(self):
|
||||
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": True,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
assert config.double_precision_rope is True
|
||||
|
||||
|
||||
class TestRoPESplit:
|
||||
|
||||
Reference in New Issue
Block a user