From 38d46a6eda091527804abbaaf3b105f1481b2e0c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 15 Mar 2026 23:00:38 +0100 Subject: [PATCH] Implement regression tests for RoPE position precision using NumPy float64 reference. Add a new function to compute reference values and validate that float32 results closely match expected outputs, addressing high-frequency amplification issues. Update imports to include LTXModelConfig for enhanced configuration management. --- tests/test_rope.py | 304 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 265 insertions(+), 39 deletions(-) diff --git a/tests/test_rope.py b/tests/test_rope.py index cef8d6f..f64a0c2 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -5,7 +5,7 @@ import numpy as np from mlx_video.models.ltx.rope import ( precompute_freqs_cis, ) -from mlx_video.models.ltx.config import LTXRopeType +from mlx_video.models.ltx.config import LTXModelConfig, LTXRopeType def create_video_position_grid( @@ -36,6 +36,65 @@ def create_video_position_grid( return mx.array(pixel_coords, dtype=dtype) + +def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads): + """Compute RoPE cos/sin using NumPy float64 as ground truth reference. + + This mirrors the regular (non-double-precision) path in rope.py exactly, + but uses float64 throughout, so we can verify that the float32 MLX path + stays close to the true values. 
+ """ + # positions_np: (B, 3, T, 2) in float64 + n_pos_dims = positions_np.shape[1] + n_elem = 2 * n_pos_dims + + # Middle-of-interval positions + mid = (positions_np[..., 0] + positions_np[..., 1]) / 2.0 # (B, 3, T) + + # Frequency grid — matches generate_freq_grid() in rope.py: + # log_start = log(1)/log(theta) = 0 + # log_end = log(theta)/log(theta) = 1 + # pow_indices = theta^linspace(0, 1, num_indices) * pi/2 + num_indices = dim // n_elem + if num_indices == 0: + num_indices = 1 + lin_space = np.linspace(0.0, 1.0, num_indices, dtype=np.float64) + freq_indices = np.power(theta, lin_space) * (np.pi / 2) # (num_indices,) + + # Fractional positions and scaling — matches generate_freqs() + # frac = pos / max_pos for each dim, then scale to [-1, 1] + frac_list = [] + for d in range(n_pos_dims): + frac = mid[:, d, :] / max_pos[d] # (B, T) + frac_list.append(frac) + fractional = np.stack(frac_list, axis=-1) # (B, T, n_dims) + scaled = fractional * 2 - 1 # [-1, 1] + + # Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices) + freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :] + # (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten + freqs = np.swapaxes(freqs, -1, -2) + freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims) + + cos_ref = np.cos(freqs) + sin_ref = np.sin(freqs) + + # Split RoPE: pad to dim//2, reshape to (B, H, T, dim_per_head//2) + expected = dim // 2 + pad_size = expected - cos_ref.shape[-1] + if pad_size > 0: + # Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis() + cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1) + sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1) + + B, T, F = cos_ref.shape + dim_per_head = dim // num_heads + cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3) + sin_ref = sin_ref.reshape(B, T, 
+ num_heads, dim_per_head // 2).transpose(0, 2, 1, 3) + + return cos_ref, sin_ref + + class TestRoPEPositionPrecision: """Test suite for RoPE position precision requirements.""" @@ -132,11 +191,6 @@ class TestRoPEPositionPrecision: """Verify that double_precision mode converts bfloat16 to float32 first.""" positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) - # The double precision path in rope.py line 434: - # indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) - # This means bfloat16 -> float32 -> float64 - # The precision is already lost at the bfloat16 -> float32 step - cos_freq, sin_freq = precompute_freqs_cis( indices_grid=positions_bf16, dim=128, @@ -176,6 +230,96 @@ class TestRoPEPositionPrecision: assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive" assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive" + def test_float32_positions_match_numpy_float64_reference(self): + """Regression test: float32 RoPE must closely match a NumPy float64 reference. + + This is the key correctness test. We compute RoPE with NumPy in float64 + (ground truth) and verify that the MLX float32 path produces nearly + identical results. The max allowed diff (0.01) is well below the error + we saw with bfloat16 positions (~2.0 max diff, cosine sim 0.88). 
+ """ + positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + positions_np = np.array(positions).astype(np.float64) + + dim = 128 + theta = 10000.0 + max_pos = [20, 2048, 2048] + num_heads = 32 + + # MLX result (float32 path, non-double-precision) + cos_mlx, sin_mlx = precompute_freqs_cis( + indices_grid=positions, + dim=dim, + theta=theta, + max_pos=max_pos, + use_middle_indices_grid=True, + num_attention_heads=num_heads, + rope_type=LTXRopeType.SPLIT, + double_precision=False, + ) + + # NumPy float64 reference + cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads) + + cos_mlx_np = np.array(cos_mlx) + sin_mlx_np = np.array(sin_mlx) + + max_cos_diff = np.max(np.abs(cos_mlx_np - cos_ref)) + max_sin_diff = np.max(np.abs(sin_mlx_np - sin_ref)) + + # Cosine similarity (flatten for single scalar) + cos_flat = cos_mlx_np.flatten() + ref_flat = cos_ref.flatten() + cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)) + + # float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa. + # Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff). + assert max_cos_diff < 0.01, \ + f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" + assert max_sin_diff < 0.01, \ + f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved" + assert cosine_sim > 0.9999, \ + f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999" + + def test_high_frequency_amplification_regression(self): + """Regression test for the specific failure mode: high-frequency index amplification. + + With production-sized grids (5x16x16 = 1280 tokens), fractional positions + like 0.000391 get multiplied by frequency indices up to ~15708. In bfloat16 + the fractional part is quantized, producing raw freq errors of ~6.14 and + cos/sin sign flips (max_diff ~2.0). Float32 must keep max_diff < 0.01. 
+ """ + # Use a production-like grid size + positions = create_video_position_grid(1, 5, 16, 16, dtype=mx.float32) + positions_np = np.array(positions).astype(np.float64) + + dim = 128 + theta = 10000.0 + max_pos = [20, 2048, 2048] + num_heads = 32 + + cos_mlx, sin_mlx = precompute_freqs_cis( + indices_grid=positions, + dim=dim, + theta=theta, + max_pos=max_pos, + use_middle_indices_grid=True, + num_attention_heads=num_heads, + rope_type=LTXRopeType.SPLIT, + double_precision=False, + ) + + cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads) + + max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref)) + max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref)) + + # Float32 should keep errors well below the bfloat16 failure threshold of ~2.0 + assert max_cos_diff < 0.01, \ + f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected" + assert max_sin_diff < 0.01, \ + f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected" + class TestRoPEInterleaved: """Tests for interleaved RoPE mode.""" @@ -201,43 +345,125 @@ class TestRoPEInterleaved: assert not mx.any(mx.isnan(sin_freq)).item() -class TestRoPEWarnings: - """Tests for RoPE warnings.""" +class TestRoPEInputCasting: + """Tests that precompute_freqs_cis casts positions to float32 internally. - def test_bfloat16_positions_trigger_warning(self): - """Verify that bfloat16 positions trigger a UserWarning.""" - positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) + The fix in rope.py ensures that regardless of the input dtype, positions are + cast to float32 before any computation. This class verifies that behavior + for both the regular and double-precision paths. 
+ """ - with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"): - precompute_freqs_cis( - indices_grid=positions_bf16, - dim=128, - theta=10000.0, - max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, - num_attention_heads=32, - rope_type=LTXRopeType.SPLIT, - double_precision=True, - ) - - def test_float32_positions_no_warning(self): - """Verify that float32 positions do NOT trigger a warning.""" + def test_regular_path_outputs_float32(self): + """Regular path: both float32 and bfloat16 inputs produce float32 output.""" positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + positions_bf16 = positions_f32.astype(mx.bfloat16) - # This should not raise any warnings - import warnings - with warnings.catch_warnings(): - warnings.simplefilter("error") # Turn warnings into errors - precompute_freqs_cis( - indices_grid=positions_f32, - dim=128, - theta=10000.0, - max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, - num_attention_heads=32, - rope_type=LTXRopeType.SPLIT, - double_precision=True, - ) + kwargs = dict( + dim=128, theta=10000.0, max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, double_precision=False, + ) + + cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs) + cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=positions_bf16, **kwargs) + + # Both produce float32 output regardless of input dtype + assert cos_f32.dtype == mx.float32 + assert cos_bf16.dtype == mx.float32 + assert sin_f32.dtype == mx.float32 + assert sin_bf16.dtype == mx.float32 + + # No NaN/Inf in either + assert not mx.any(mx.isnan(cos_bf16)).item() + assert not mx.any(mx.isinf(cos_bf16)).item() + + def test_double_precision_path_outputs_float32(self): + """Double-precision path: both float32 and bfloat16 inputs produce float32 output.""" + positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + positions_bf16 = 
positions_f32.astype(mx.bfloat16) + + kwargs = dict( + dim=128, theta=10000.0, max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, double_precision=True, + ) + + cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs) + cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=positions_bf16, **kwargs) + + assert cos_f32.dtype == mx.float32 + assert cos_bf16.dtype == mx.float32 + assert sin_f32.dtype == mx.float32 + assert sin_bf16.dtype == mx.float32 + + assert not mx.any(mx.isnan(cos_bf16)).item() + assert not mx.any(mx.isinf(cos_bf16)).item() + + def test_float16_input_also_cast_to_float32(self): + """Float16 input should also be handled correctly.""" + positions_f16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float16) + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions_f16, + dim=128, theta=10000.0, max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, double_precision=False, + ) + + assert cos_freq.dtype == mx.float32 + assert sin_freq.dtype == mx.float32 + assert not mx.any(mx.isnan(cos_freq)).item() + + +class TestDoublePrecisionRopeConfig: + """Tests for the conditional double_precision_rope logic in LTXModelConfig.""" + + def test_ltx2_forces_double_precision_rope_false(self): + """LTX-2 (no prompt adaln) must have double_precision_rope=False.""" + config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True) + assert config.double_precision_rope is False, \ + "LTX-2 should force double_precision_rope=False regardless of input" + + def test_ltx23_preserves_double_precision_rope_true(self): + """LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True.""" + config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True) + assert config.double_precision_rope is True, \ + "LTX-2.3 should preserve double_precision_rope=True" + + def 
test_ltx23_preserves_double_precision_rope_false(self): + """LTX-2.3 with double_precision_rope=False should stay False.""" + config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False) + assert config.double_precision_rope is False, \ + "LTX-2.3 should respect double_precision_rope=False when explicitly set" + + def test_ltx2_default_double_precision_rope(self): + """LTX-2 default (double_precision_rope not set) should be False.""" + config = LTXModelConfig(has_prompt_adaln=False) + assert config.double_precision_rope is False + + def test_ltx23_default_double_precision_rope(self): + """LTX-2.3 default (double_precision_rope not set) should be False (field default).""" + config = LTXModelConfig(has_prompt_adaln=True) + # The field default is False and __post_init__ doesn't override for LTX-2.3 + assert config.double_precision_rope is False + + def test_config_from_dict_ltx2(self): + """Config created from dict for LTX-2 should force double_precision_rope=False.""" + config = LTXModelConfig.from_dict({ + "has_prompt_adaln": False, + "double_precision_rope": True, + "rope_type": "split", + }) + assert config.double_precision_rope is False + + def test_config_from_dict_ltx23(self): + """Config created from dict for LTX-2.3 should preserve double_precision_rope.""" + config = LTXModelConfig.from_dict({ + "has_prompt_adaln": True, + "double_precision_rope": True, + "rope_type": "split", + }) + assert config.double_precision_rope is True class TestRoPESplit: