Add RoPE tests and warning for bfloat16 precision loss in RoPE calculations
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
280
tests/test_rope.py
Normal file
280
tests/test_rope.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.rope import (
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
|
||||
|
||||
def create_video_position_grid(
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> mx.array:
|
||||
"""Create a simple video position grid for testing."""
|
||||
t_coords = np.arange(0, num_frames)
|
||||
h_coords = np.arange(0, height)
|
||||
w_coords = np.arange(0, width)
|
||||
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||
patch_ends = patch_starts + 1
|
||||
|
||||
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
|
||||
num_patches = num_frames * height * width
|
||||
latent_coords = latent_coords.reshape(3, num_patches, 2)
|
||||
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
|
||||
|
||||
# Scale to pixel space
|
||||
scale_factors = np.array([8, 32, 32]).reshape(1, 3, 1, 1)
|
||||
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
|
||||
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0 # Convert to seconds
|
||||
|
||||
return mx.array(pixel_coords, dtype=dtype)
|
||||
|
||||
class TestRoPEPositionPrecision:
|
||||
"""Test suite for RoPE position precision requirements."""
|
||||
|
||||
def test_float32_positions_produce_consistent_output(self):
|
||||
"""Float32 position grids should produce stable RoPE frequencies."""
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Verify output dtype is float32
|
||||
assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}"
|
||||
assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}"
|
||||
|
||||
# Verify no NaN or Inf values
|
||||
assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN"
|
||||
assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN"
|
||||
assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf"
|
||||
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
||||
|
||||
# Verify cos/sin are in valid range [-1, 1]
|
||||
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
|
||||
"cos_freq values out of [-1, 1] range"
|
||||
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
|
||||
"sin_freq values out of [-1, 1] range"
|
||||
|
||||
def test_bfloat16_positions_cause_precision_loss(self):
|
||||
"""bfloat16 positions should produce different (less precise) results than float32.
|
||||
|
||||
This test documents the known issue: bfloat16 has only 7 bits of mantissa
|
||||
vs 23 bits for float32, causing quantization errors that get amplified
|
||||
by sin/cos calculations in RoPE.
|
||||
"""
|
||||
# Create identical position grids in different dtypes
|
||||
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
|
||||
# Compute RoPE frequencies
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(
|
||||
indices_grid=positions_f32,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
cos_bf16, sin_bf16 = precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Calculate the difference
|
||||
cos_diff = mx.abs(cos_f32 - cos_bf16)
|
||||
sin_diff = mx.abs(sin_f32 - sin_bf16)
|
||||
|
||||
max_cos_diff = mx.max(cos_diff).item()
|
||||
max_sin_diff = mx.max(sin_diff).item()
|
||||
|
||||
# bfloat16 positions WILL cause measurable differences
|
||||
# This test documents this known behavior
|
||||
# The threshold here is intentionally low to catch the issue
|
||||
precision_threshold = 1e-6
|
||||
|
||||
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
|
||||
# Document the precision loss (this is expected behavior)
|
||||
if has_precision_loss:
|
||||
print(f"\nPrecision loss detected (expected):")
|
||||
print(f" Max cos difference: {max_cos_diff:.6e}")
|
||||
print(f" Max sin difference: {max_sin_diff:.6e}")
|
||||
|
||||
# This assertion documents the issue - bfloat16 positions cause precision loss
|
||||
assert has_precision_loss, \
|
||||
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
|
||||
def test_double_precision_converts_to_float32_internally(self):
|
||||
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
|
||||
# The double precision path in rope.py line 434:
|
||||
# indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||
# This means bfloat16 -> float32 -> float64
|
||||
# The precision is already lost at the bfloat16 -> float32 step
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Output should still be float32
|
||||
assert cos_freq.dtype == mx.float32
|
||||
assert sin_freq.dtype == mx.float32
|
||||
|
||||
def test_position_grid_should_be_float32_recommendation(self):
|
||||
"""Test that validates the recommended practice: positions should be float32.
|
||||
|
||||
This test serves as documentation that position grids MUST be float32
|
||||
to avoid quality degradation in generated videos/audio.
|
||||
"""
|
||||
# Recommended: create positions in float32
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
assert positions.dtype == mx.float32, \
|
||||
"Position grids should be created in float32 for RoPE precision"
|
||||
|
||||
# Verify the position values are reasonable
|
||||
# Temporal positions should be small (seconds)
|
||||
temporal_positions = positions[:, 0, :, :]
|
||||
assert mx.max(temporal_positions).item() < 100, \
|
||||
"Temporal positions should be in seconds (small values)"
|
||||
|
||||
# Spatial positions should be larger (pixels)
|
||||
spatial_h = positions[:, 1, :, :]
|
||||
spatial_w = positions[:, 2, :, :]
|
||||
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
||||
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
||||
|
||||
|
||||
class TestRoPEInterleaved:
|
||||
"""Tests for interleaved RoPE mode."""
|
||||
|
||||
def test_interleaved_rope_with_float32_positions(self):
|
||||
"""Interleaved RoPE should work correctly with float32 positions."""
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.INTERLEAVED,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
assert cos_freq.dtype == mx.float32
|
||||
assert sin_freq.dtype == mx.float32
|
||||
assert not mx.any(mx.isnan(cos_freq)).item()
|
||||
assert not mx.any(mx.isnan(sin_freq)).item()
|
||||
|
||||
|
||||
class TestRoPEWarnings:
|
||||
"""Tests for RoPE warnings."""
|
||||
|
||||
def test_bfloat16_positions_trigger_warning(self):
|
||||
"""Verify that bfloat16 positions trigger a UserWarning."""
|
||||
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||
|
||||
with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"):
|
||||
precompute_freqs_cis(
|
||||
indices_grid=positions_bf16,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
def test_float32_positions_no_warning(self):
|
||||
"""Verify that float32 positions do NOT trigger a warning."""
|
||||
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
# This should not raise any warnings
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error") # Turn warnings into errors
|
||||
precompute_freqs_cis(
|
||||
indices_grid=positions_f32,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
|
||||
class TestRoPESplit:
|
||||
"""Tests for split RoPE mode (used by LTX-2)."""
|
||||
|
||||
def test_split_rope_output_shape(self):
|
||||
"""Verify split RoPE output has correct shape (B, H, T, dim_per_head//2)."""
|
||||
batch_size = 1
|
||||
num_frames = 4
|
||||
height = 4
|
||||
width = 4
|
||||
num_heads = 32
|
||||
dim = 128
|
||||
|
||||
positions = create_video_position_grid(batch_size, num_frames, height, width)
|
||||
num_tokens = num_frames * height * width
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions,
|
||||
dim=dim,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=num_heads,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
# Shape should be (B, H, T, dim_per_head//2)
|
||||
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
||||
dim_per_head = dim // num_heads
|
||||
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
||||
assert cos_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert sin_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user