Add RoPE tests and warning for bfloat16 precision loss in RoPE calculations

This commit is contained in:
Prince Canuma
2026-01-17 19:28:05 +01:00
parent 78244a2d66
commit 61c56cd989
3 changed files with 292 additions and 0 deletions

View File

@@ -430,7 +430,19 @@ def _precompute_freqs_cis_double_precision(
rope_type: LTXRopeType, rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
"Use float32 for position grids to avoid quality degradation. "
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
UserWarning,
stacklevel=2
)
# Convert to numpy float64 (first to float32 for numpy compatibility) # Convert to numpy float64 (first to float32 for numpy compatibility)
# Note: If input is bfloat16, precision is already lost at this step
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# Generate frequency indices in float64 # Generate frequency indices in float64

0
tests/__init__.py Normal file
View File

280
tests/test_rope.py Normal file
View File

@@ -0,0 +1,280 @@
import pytest
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.rope import (
precompute_freqs_cis,
)
from mlx_video.models.ltx.config import LTXRopeType
def create_video_position_grid(
batch_size: int,
num_frames: int,
height: int,
width: int,
dtype: mx.Dtype = mx.float32,
) -> mx.array:
"""Create a simple video position grid for testing."""
t_coords = np.arange(0, num_frames)
h_coords = np.arange(0, height)
w_coords = np.arange(0, width)
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
patch_ends = patch_starts + 1
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
num_patches = num_frames * height * width
latent_coords = latent_coords.reshape(3, num_patches, 2)
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
# Scale to pixel space
scale_factors = np.array([8, 32, 32]).reshape(1, 3, 1, 1)
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0 # Convert to seconds
return mx.array(pixel_coords, dtype=dtype)
class TestRoPEPositionPrecision:
"""Test suite for RoPE position precision requirements."""
def test_float32_positions_produce_consistent_output(self):
"""Float32 position grids should produce stable RoPE frequencies."""
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
# Verify output dtype is float32
assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}"
assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}"
# Verify no NaN or Inf values
assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN"
assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN"
assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf"
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
# Verify cos/sin are in valid range [-1, 1]
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
"cos_freq values out of [-1, 1] range"
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
"sin_freq values out of [-1, 1] range"
def test_bfloat16_positions_cause_precision_loss(self):
"""bfloat16 positions should produce different (less precise) results than float32.
This test documents the known issue: bfloat16 has only 7 bits of mantissa
vs 23 bits for float32, causing quantization errors that get amplified
by sin/cos calculations in RoPE.
"""
# Create identical position grids in different dtypes
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
# Compute RoPE frequencies
cos_f32, sin_f32 = precompute_freqs_cis(
indices_grid=positions_f32,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
cos_bf16, sin_bf16 = precompute_freqs_cis(
indices_grid=positions_bf16,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
# Calculate the difference
cos_diff = mx.abs(cos_f32 - cos_bf16)
sin_diff = mx.abs(sin_f32 - sin_bf16)
max_cos_diff = mx.max(cos_diff).item()
max_sin_diff = mx.max(sin_diff).item()
# bfloat16 positions WILL cause measurable differences
# This test documents this known behavior
# The threshold here is intentionally low to catch the issue
precision_threshold = 1e-6
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
# Document the precision loss (this is expected behavior)
if has_precision_loss:
print(f"\nPrecision loss detected (expected):")
print(f" Max cos difference: {max_cos_diff:.6e}")
print(f" Max sin difference: {max_sin_diff:.6e}")
# This assertion documents the issue - bfloat16 positions cause precision loss
assert has_precision_loss, \
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
def test_double_precision_converts_to_float32_internally(self):
"""Verify that double_precision mode converts bfloat16 to float32 first."""
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
# The double precision path in rope.py line 434:
# indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# This means bfloat16 -> float32 -> float64
# The precision is already lost at the bfloat16 -> float32 step
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions_bf16,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
# Output should still be float32
assert cos_freq.dtype == mx.float32
assert sin_freq.dtype == mx.float32
def test_position_grid_should_be_float32_recommendation(self):
"""Test that validates the recommended practice: positions should be float32.
This test serves as documentation that position grids MUST be float32
to avoid quality degradation in generated videos/audio.
"""
# Recommended: create positions in float32
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
assert positions.dtype == mx.float32, \
"Position grids should be created in float32 for RoPE precision"
# Verify the position values are reasonable
# Temporal positions should be small (seconds)
temporal_positions = positions[:, 0, :, :]
assert mx.max(temporal_positions).item() < 100, \
"Temporal positions should be in seconds (small values)"
# Spatial positions should be larger (pixels)
spatial_h = positions[:, 1, :, :]
spatial_w = positions[:, 2, :, :]
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
class TestRoPEInterleaved:
"""Tests for interleaved RoPE mode."""
def test_interleaved_rope_with_float32_positions(self):
"""Interleaved RoPE should work correctly with float32 positions."""
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.INTERLEAVED,
double_precision=False,
)
assert cos_freq.dtype == mx.float32
assert sin_freq.dtype == mx.float32
assert not mx.any(mx.isnan(cos_freq)).item()
assert not mx.any(mx.isnan(sin_freq)).item()
class TestRoPEWarnings:
"""Tests for RoPE warnings."""
def test_bfloat16_positions_trigger_warning(self):
"""Verify that bfloat16 positions trigger a UserWarning."""
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"):
precompute_freqs_cis(
indices_grid=positions_bf16,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
def test_float32_positions_no_warning(self):
"""Verify that float32 positions do NOT trigger a warning."""
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
# This should not raise any warnings
import warnings
with warnings.catch_warnings():
warnings.simplefilter("error") # Turn warnings into errors
precompute_freqs_cis(
indices_grid=positions_f32,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
class TestRoPESplit:
"""Tests for split RoPE mode (used by LTX-2)."""
def test_split_rope_output_shape(self):
"""Verify split RoPE output has correct shape (B, H, T, dim_per_head//2)."""
batch_size = 1
num_frames = 4
height = 4
width = 4
num_heads = 32
dim = 128
positions = create_video_position_grid(batch_size, num_frames, height, width)
num_tokens = num_frames * height * width
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions,
dim=dim,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=num_heads,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
# Shape should be (B, H, T, dim_per_head//2)
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
dim_per_head = dim // num_heads
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
assert cos_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {cos_freq.shape}"
assert sin_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {sin_freq.shape}"
if __name__ == "__main__":
pytest.main([__file__, "-v"])