Implement regression tests for RoPE position precision using NumPy float64 reference. Add a new function to compute reference values and validate that float32 results closely match expected outputs, addressing high-frequency amplification issues. Update imports to include LTXModelConfig for enhanced configuration management.

This commit is contained in:
Prince Canuma
2026-03-15 23:00:38 +01:00
parent cecd68197c
commit 38d46a6eda

View File

@@ -5,7 +5,7 @@ import numpy as np
from mlx_video.models.ltx.rope import ( from mlx_video.models.ltx.rope import (
precompute_freqs_cis, precompute_freqs_cis,
) )
from mlx_video.models.ltx.config import LTXRopeType from mlx_video.models.ltx.config import LTXModelConfig, LTXRopeType
def create_video_position_grid( def create_video_position_grid(
@@ -36,6 +36,65 @@ def create_video_position_grid(
return mx.array(pixel_coords, dtype=dtype) return mx.array(pixel_coords, dtype=dtype)
def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
    """Compute split-RoPE cos/sin tables in NumPy float64 as a ground-truth reference.

    Intended to mirror the regular (non-double-precision) path in rope.py, but
    with every intermediate held in float64, so tests can verify that the
    float32 MLX path stays close to the true values.

    Args:
        positions_np: Array of shape (B, n_pos_dims, T, 2). The last axis holds
            the two interval bounds whose midpoint is used as the position.
            Any real dtype is accepted; it is promoted to float64 up front so
            the reference never inherits a lower input precision.
        dim: Total rotary dimension.
        theta: Base of the frequency grid.
        max_pos: Per-axis maximum position, indexed by axis 0..n_pos_dims-1.
        num_heads: Number of attention heads the tables are split across.

    Returns:
        Tuple ``(cos_ref, sin_ref)`` of float64 arrays with shape
        ``(B, num_heads, T, dim // num_heads // 2)``.
    """
    # Enforce float64 here: otherwise a float32 input would make the midpoint
    # and fraction below run in float32 and the "reference" would be lossy.
    positions_np = np.asarray(positions_np, dtype=np.float64)

    n_pos_dims = positions_np.shape[1]
    n_elem = 2 * n_pos_dims

    # Middle-of-interval positions -> (B, n_pos_dims, T)
    mid = (positions_np[..., 0] + positions_np[..., 1]) / 2.0

    # Frequency grid — matches generate_freq_grid() in rope.py:
    #   log_start = log(1)/log(theta) = 0
    #   log_end   = log(theta)/log(theta) = 1
    #   pow_indices = theta^linspace(0, 1, num_indices) * pi/2
    num_indices = dim // n_elem or 1  # a zero count falls back to one index
    lin_space = np.linspace(0.0, 1.0, num_indices, dtype=np.float64)
    freq_indices = np.power(theta, lin_space) * (np.pi / 2)  # (num_indices,)

    # Fractional positions and scaling — matches generate_freqs():
    # frac = pos / max_pos for each axis, then rescale to [-1, 1].
    fractional = np.stack(
        [mid[:, d, :] / max_pos[d] for d in range(n_pos_dims)],
        axis=-1,
    )  # (B, T, n_pos_dims)
    scaled = fractional * 2 - 1

    # Outer product: (B, T, n_pos_dims, 1) * (1, 1, 1, num_indices)
    freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
    # (B, T, n_pos_dims, num_indices) -> (B, T, num_indices, n_pos_dims) -> flatten
    freqs = np.swapaxes(freqs, -1, -2)
    freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1)

    cos_ref = np.cos(freqs)
    sin_ref = np.sin(freqs)

    # Split RoPE: pad up to dim // 2 features, then split across heads.
    expected = dim // 2
    pad_size = expected - cos_ref.shape[-1]
    if pad_size > 0:
        # Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
        cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
        sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)

    batch, tokens = cos_ref.shape[:2]
    dim_per_head = dim // num_heads
    # (B, T, F) -> (B, T, H, dim_per_head // 2) -> (B, H, T, dim_per_head // 2)
    cos_ref = cos_ref.reshape(batch, tokens, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
    sin_ref = sin_ref.reshape(batch, tokens, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
    return cos_ref, sin_ref
class TestRoPEPositionPrecision: class TestRoPEPositionPrecision:
"""Test suite for RoPE position precision requirements.""" """Test suite for RoPE position precision requirements."""
@@ -132,11 +191,6 @@ class TestRoPEPositionPrecision:
"""Verify that double_precision mode converts bfloat16 to float32 first.""" """Verify that double_precision mode converts bfloat16 to float32 first."""
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
# The double precision path in rope.py line 434:
# indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# This means bfloat16 -> float32 -> float64
# The precision is already lost at the bfloat16 -> float32 step
cos_freq, sin_freq = precompute_freqs_cis( cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions_bf16, indices_grid=positions_bf16,
dim=128, dim=128,
@@ -176,6 +230,96 @@ class TestRoPEPositionPrecision:
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive" assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive" assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
def test_float32_positions_match_numpy_float64_reference(self):
    """Regression test: float32 RoPE must closely match a NumPy float64 reference.

    This is the key correctness test. RoPE is computed with NumPy in float64
    (ground truth) and the MLX float32 path must produce nearly identical
    results. The max allowed diff (0.01) is well below the error seen with
    bfloat16 positions (~2.0 max diff, cosine sim 0.88).
    """
    positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
    positions_np = np.array(positions).astype(np.float64)

    dim = 128
    theta = 10000.0
    max_pos = [20, 2048, 2048]
    num_heads = 32

    # MLX result (float32 path, non-double-precision)
    cos_mlx, sin_mlx = precompute_freqs_cis(
        indices_grid=positions,
        dim=dim,
        theta=theta,
        max_pos=max_pos,
        use_middle_indices_grid=True,
        num_attention_heads=num_heads,
        rope_type=LTXRopeType.SPLIT,
        double_precision=False,
    )

    # NumPy float64 reference
    cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)

    cos_mlx_np = np.array(cos_mlx)
    sin_mlx_np = np.array(sin_mlx)
    max_cos_diff = np.max(np.abs(cos_mlx_np - cos_ref))
    max_sin_diff = np.max(np.abs(sin_mlx_np - sin_ref))

    # Cosine similarity of the cos tables (flattened to get a single scalar)
    cos_flat = cos_mlx_np.flatten()
    ref_flat = cos_ref.flatten()
    cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))

    # float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
    # Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
    assert max_cos_diff < 0.01, \
        f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
    assert max_sin_diff < 0.01, \
        f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
    assert cosine_sim > 0.9999, \
        f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
def test_high_frequency_amplification_regression(self):
    """Guard against high-frequency index amplification of position error.

    On a production-sized grid (5x16x16 = 1280 tokens), fractional positions
    such as 0.000391 are multiplied by frequency indices up to ~15708. With
    bfloat16 the fraction is quantized, giving raw frequency errors of ~6.14
    and cos/sin sign flips (max_diff ~2.0). The float32 path has to keep
    max_diff below 0.01.
    """
    # Production-like grid size
    grid = create_video_position_grid(1, 5, 16, 16, dtype=mx.float32)
    grid_f64 = np.array(grid).astype(np.float64)

    dim, theta, num_heads = 128, 10000.0, 32
    max_pos = [20, 2048, 2048]

    cos_mlx, sin_mlx = precompute_freqs_cis(
        indices_grid=grid,
        dim=dim,
        theta=theta,
        max_pos=max_pos,
        use_middle_indices_grid=True,
        num_attention_heads=num_heads,
        rope_type=LTXRopeType.SPLIT,
        double_precision=False,
    )
    cos_ref, sin_ref = _numpy_reference_rope(grid_f64, dim, theta, max_pos, num_heads)

    cos_err = np.abs(np.array(cos_mlx) - cos_ref)
    max_cos_diff = np.max(cos_err)
    sin_err = np.abs(np.array(sin_mlx) - sin_ref)
    max_sin_diff = np.max(sin_err)

    # Float32 errors must sit far below the ~2.0 seen in the bfloat16 failure
    assert max_cos_diff < 0.01, (
        f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
    )
    assert max_sin_diff < 0.01, (
        f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
    )
class TestRoPEInterleaved: class TestRoPEInterleaved:
"""Tests for interleaved RoPE mode.""" """Tests for interleaved RoPE mode."""
@@ -201,44 +345,126 @@ class TestRoPEInterleaved:
assert not mx.any(mx.isnan(sin_freq)).item() assert not mx.any(mx.isnan(sin_freq)).item()
class TestRoPEWarnings: class TestRoPEInputCasting:
"""Tests for RoPE warnings.""" """Tests that precompute_freqs_cis casts positions to float32 internally.
def test_bfloat16_positions_trigger_warning(self): The fix in rope.py ensures that regardless of the input dtype, positions are
"""Verify that bfloat16 positions trigger a UserWarning.""" cast to float32 before any computation. This class verifies that behavior
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) for both the regular and double-precision paths.
"""
with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"): def test_regular_path_outputs_float32(self):
precompute_freqs_cis( """Regular path: both float32 and bfloat16 inputs produce float32 output."""
indices_grid=positions_bf16,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
def test_float32_positions_no_warning(self):
"""Verify that float32 positions do NOT trigger a warning."""
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
positions_bf16 = positions_f32.astype(mx.bfloat16)
# This should not raise any warnings kwargs = dict(
import warnings dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
with warnings.catch_warnings(): use_middle_indices_grid=True, num_attention_heads=32,
warnings.simplefilter("error") # Turn warnings into errors rope_type=LTXRopeType.SPLIT, double_precision=False,
precompute_freqs_cis(
indices_grid=positions_f32,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
) )
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=positions_bf16, **kwargs)
# Both produce float32 output regardless of input dtype
assert cos_f32.dtype == mx.float32
assert cos_bf16.dtype == mx.float32
assert sin_f32.dtype == mx.float32
assert sin_bf16.dtype == mx.float32
# No NaN/Inf in either
assert not mx.any(mx.isnan(cos_bf16)).item()
assert not mx.any(mx.isinf(cos_bf16)).item()
def test_double_precision_path_outputs_float32(self):
    """With double_precision=True, float32 and bfloat16 grids both yield float32 tables."""
    grid_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
    grid_bf16 = grid_f32.astype(mx.bfloat16)

    rope_args = {
        "dim": 128,
        "theta": 10000.0,
        "max_pos": [20, 2048, 2048],
        "use_middle_indices_grid": True,
        "num_attention_heads": 32,
        "rope_type": LTXRopeType.SPLIT,
        "double_precision": True,
    }

    cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=grid_f32, **rope_args)
    cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=grid_bf16, **rope_args)

    # Output dtype is float32 whatever the input dtype was
    for table in (cos_f32, cos_bf16, sin_f32, sin_bf16):
        assert table.dtype == mx.float32

    # The bfloat16-input cos table must be free of NaN and Inf
    has_nan = mx.any(mx.isnan(cos_bf16)).item()
    assert not has_nan
    has_inf = mx.any(mx.isinf(cos_bf16)).item()
    assert not has_inf
def test_float16_input_also_cast_to_float32(self):
    """A float16 position grid must also produce float32, NaN-free tables."""
    grid_f16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float16)

    tables = precompute_freqs_cis(
        indices_grid=grid_f16,
        dim=128,
        theta=10000.0,
        max_pos=[20, 2048, 2048],
        use_middle_indices_grid=True,
        num_attention_heads=32,
        rope_type=LTXRopeType.SPLIT,
        double_precision=False,
    )
    cos_freq, sin_freq = tables

    assert cos_freq.dtype == mx.float32
    assert sin_freq.dtype == mx.float32

    cos_has_nan = mx.any(mx.isnan(cos_freq)).item()
    assert not cos_has_nan
class TestDoublePrecisionRopeConfig:
    """Checks how LTXModelConfig resolves double_precision_rope against has_prompt_adaln."""

    def test_ltx2_forces_double_precision_rope_false(self):
        """has_prompt_adaln=False (LTX-2): a requested True must come back False."""
        cfg = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
        assert cfg.double_precision_rope is False, (
            "LTX-2 should force double_precision_rope=False regardless of input"
        )

    def test_ltx23_preserves_double_precision_rope_true(self):
        """has_prompt_adaln=True (LTX-2.3): a requested True is kept."""
        cfg = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
        assert cfg.double_precision_rope is True, (
            "LTX-2.3 should preserve double_precision_rope=True"
        )

    def test_ltx23_preserves_double_precision_rope_false(self):
        """has_prompt_adaln=True (LTX-2.3): an explicit False is kept."""
        cfg = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
        assert cfg.double_precision_rope is False, (
            "LTX-2.3 should respect double_precision_rope=False when explicitly set"
        )

    def test_ltx2_default_double_precision_rope(self):
        """has_prompt_adaln=False with the flag left unset resolves to False."""
        cfg = LTXModelConfig(has_prompt_adaln=False)
        assert cfg.double_precision_rope is False

    def test_ltx23_default_double_precision_rope(self):
        """has_prompt_adaln=True with the flag left unset resolves to False."""
        cfg = LTXModelConfig(has_prompt_adaln=True)
        # Expected to be the field default; presumably __post_init__ leaves it
        # alone for LTX-2.3 (config implementation is not visible here).
        assert cfg.double_precision_rope is False

    def test_config_from_dict_ltx2(self):
        """from_dict for LTX-2 must also come back with double_precision_rope False."""
        raw = {
            "has_prompt_adaln": False,
            "double_precision_rope": True,
            "rope_type": "split",
        }
        cfg = LTXModelConfig.from_dict(raw)
        assert cfg.double_precision_rope is False

    def test_config_from_dict_ltx23(self):
        """from_dict for LTX-2.3 must keep double_precision_rope True."""
        raw = {
            "has_prompt_adaln": True,
            "double_precision_rope": True,
            "rope_type": "split",
        }
        cfg = LTXModelConfig.from_dict(raw)
        assert cfg.double_precision_rope is True
class TestRoPESplit: class TestRoPESplit:
"""Tests for split RoPE mode (used by LTX-2).""" """Tests for split RoPE mode (used by LTX-2)."""