Add RoPE tests and warning for bfloat16 precision loss in RoPE calculations
This commit is contained in:
@@ -430,7 +430,19 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
rope_type: LTXRopeType,
|
rope_type: LTXRopeType,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
|
# Warn if positions are bfloat16 - this causes quality degradation
|
||||||
|
if indices_grid.dtype == mx.bfloat16:
|
||||||
|
import warnings
|
||||||
|
warnings.warn(
|
||||||
|
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
|
||||||
|
"Use float32 for position grids to avoid quality degradation. "
|
||||||
|
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
|
||||||
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
||||||
|
# Note: If input is bfloat16, precision is already lost at this step
|
||||||
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||||
|
|
||||||
# Generate frequency indices in float64
|
# Generate frequency indices in float64
|
||||||
|
|||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
280
tests/test_rope.py
Normal file
280
tests/test_rope.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
import pytest
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.rope import (
|
||||||
|
precompute_freqs_cis,
|
||||||
|
)
|
||||||
|
from mlx_video.models.ltx.config import LTXRopeType
|
||||||
|
|
||||||
|
|
||||||
|
def create_video_position_grid(
|
||||||
|
batch_size: int,
|
||||||
|
num_frames: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
dtype: mx.Dtype = mx.float32,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Create a simple video position grid for testing."""
|
||||||
|
t_coords = np.arange(0, num_frames)
|
||||||
|
h_coords = np.arange(0, height)
|
||||||
|
w_coords = np.arange(0, width)
|
||||||
|
|
||||||
|
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||||
|
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||||
|
patch_ends = patch_starts + 1
|
||||||
|
|
||||||
|
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
|
||||||
|
num_patches = num_frames * height * width
|
||||||
|
latent_coords = latent_coords.reshape(3, num_patches, 2)
|
||||||
|
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
|
||||||
|
|
||||||
|
# Scale to pixel space
|
||||||
|
scale_factors = np.array([8, 32, 32]).reshape(1, 3, 1, 1)
|
||||||
|
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
|
||||||
|
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0 # Convert to seconds
|
||||||
|
|
||||||
|
return mx.array(pixel_coords, dtype=dtype)
|
||||||
|
|
||||||
|
class TestRoPEPositionPrecision:
|
||||||
|
"""Test suite for RoPE position precision requirements."""
|
||||||
|
|
||||||
|
def test_float32_positions_produce_consistent_output(self):
|
||||||
|
"""Float32 position grids should produce stable RoPE frequencies."""
|
||||||
|
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||||
|
|
||||||
|
cos_freq, sin_freq = precompute_freqs_cis(
|
||||||
|
indices_grid=positions,
|
||||||
|
dim=128,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify output dtype is float32
|
||||||
|
assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}"
|
||||||
|
assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}"
|
||||||
|
|
||||||
|
# Verify no NaN or Inf values
|
||||||
|
assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN"
|
||||||
|
assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN"
|
||||||
|
assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf"
|
||||||
|
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
||||||
|
|
||||||
|
# Verify cos/sin are in valid range [-1, 1]
|
||||||
|
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
|
||||||
|
"cos_freq values out of [-1, 1] range"
|
||||||
|
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
|
||||||
|
"sin_freq values out of [-1, 1] range"
|
||||||
|
|
||||||
|
def test_bfloat16_positions_cause_precision_loss(self):
|
||||||
|
"""bfloat16 positions should produce different (less precise) results than float32.
|
||||||
|
|
||||||
|
This test documents the known issue: bfloat16 has only 7 bits of mantissa
|
||||||
|
vs 23 bits for float32, causing quantization errors that get amplified
|
||||||
|
by sin/cos calculations in RoPE.
|
||||||
|
"""
|
||||||
|
# Create identical position grids in different dtypes
|
||||||
|
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||||
|
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||||
|
|
||||||
|
# Compute RoPE frequencies
|
||||||
|
cos_f32, sin_f32 = precompute_freqs_cis(
|
||||||
|
indices_grid=positions_f32,
|
||||||
|
dim=128,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
cos_bf16, sin_bf16 = precompute_freqs_cis(
|
||||||
|
indices_grid=positions_bf16,
|
||||||
|
dim=128,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate the difference
|
||||||
|
cos_diff = mx.abs(cos_f32 - cos_bf16)
|
||||||
|
sin_diff = mx.abs(sin_f32 - sin_bf16)
|
||||||
|
|
||||||
|
max_cos_diff = mx.max(cos_diff).item()
|
||||||
|
max_sin_diff = mx.max(sin_diff).item()
|
||||||
|
|
||||||
|
# bfloat16 positions WILL cause measurable differences
|
||||||
|
# This test documents this known behavior
|
||||||
|
# The threshold here is intentionally low to catch the issue
|
||||||
|
precision_threshold = 1e-6
|
||||||
|
|
||||||
|
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||||
|
|
||||||
|
# Document the precision loss (this is expected behavior)
|
||||||
|
if has_precision_loss:
|
||||||
|
print(f"\nPrecision loss detected (expected):")
|
||||||
|
print(f" Max cos difference: {max_cos_diff:.6e}")
|
||||||
|
print(f" Max sin difference: {max_sin_diff:.6e}")
|
||||||
|
|
||||||
|
# This assertion documents the issue - bfloat16 positions cause precision loss
|
||||||
|
assert has_precision_loss, \
|
||||||
|
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||||
|
|
||||||
|
def test_double_precision_converts_to_float32_internally(self):
|
||||||
|
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||||
|
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||||
|
|
||||||
|
# The double precision path in rope.py line 434:
|
||||||
|
# indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||||
|
# This means bfloat16 -> float32 -> float64
|
||||||
|
# The precision is already lost at the bfloat16 -> float32 step
|
||||||
|
|
||||||
|
cos_freq, sin_freq = precompute_freqs_cis(
|
||||||
|
indices_grid=positions_bf16,
|
||||||
|
dim=128,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output should still be float32
|
||||||
|
assert cos_freq.dtype == mx.float32
|
||||||
|
assert sin_freq.dtype == mx.float32
|
||||||
|
|
||||||
|
def test_position_grid_should_be_float32_recommendation(self):
|
||||||
|
"""Test that validates the recommended practice: positions should be float32.
|
||||||
|
|
||||||
|
This test serves as documentation that position grids MUST be float32
|
||||||
|
to avoid quality degradation in generated videos/audio.
|
||||||
|
"""
|
||||||
|
# Recommended: create positions in float32
|
||||||
|
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||||
|
|
||||||
|
assert positions.dtype == mx.float32, \
|
||||||
|
"Position grids should be created in float32 for RoPE precision"
|
||||||
|
|
||||||
|
# Verify the position values are reasonable
|
||||||
|
# Temporal positions should be small (seconds)
|
||||||
|
temporal_positions = positions[:, 0, :, :]
|
||||||
|
assert mx.max(temporal_positions).item() < 100, \
|
||||||
|
"Temporal positions should be in seconds (small values)"
|
||||||
|
|
||||||
|
# Spatial positions should be larger (pixels)
|
||||||
|
spatial_h = positions[:, 1, :, :]
|
||||||
|
spatial_w = positions[:, 2, :, :]
|
||||||
|
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
||||||
|
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoPEInterleaved:
|
||||||
|
"""Tests for interleaved RoPE mode."""
|
||||||
|
|
||||||
|
def test_interleaved_rope_with_float32_positions(self):
|
||||||
|
"""Interleaved RoPE should work correctly with float32 positions."""
|
||||||
|
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||||
|
|
||||||
|
cos_freq, sin_freq = precompute_freqs_cis(
|
||||||
|
indices_grid=positions,
|
||||||
|
dim=128,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.INTERLEAVED,
|
||||||
|
double_precision=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cos_freq.dtype == mx.float32
|
||||||
|
assert sin_freq.dtype == mx.float32
|
||||||
|
assert not mx.any(mx.isnan(cos_freq)).item()
|
||||||
|
assert not mx.any(mx.isnan(sin_freq)).item()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoPEWarnings:
|
||||||
|
"""Tests for RoPE warnings."""
|
||||||
|
|
||||||
|
def test_bfloat16_positions_trigger_warning(self):
|
||||||
|
"""Verify that bfloat16 positions trigger a UserWarning."""
|
||||||
|
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
|
||||||
|
|
||||||
|
with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"):
|
||||||
|
precompute_freqs_cis(
|
||||||
|
indices_grid=positions_bf16,
|
||||||
|
dim=128,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_float32_positions_no_warning(self):
|
||||||
|
"""Verify that float32 positions do NOT trigger a warning."""
|
||||||
|
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||||
|
|
||||||
|
# This should not raise any warnings
|
||||||
|
import warnings
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("error") # Turn warnings into errors
|
||||||
|
precompute_freqs_cis(
|
||||||
|
indices_grid=positions_f32,
|
||||||
|
dim=128,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=32,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoPESplit:
|
||||||
|
"""Tests for split RoPE mode (used by LTX-2)."""
|
||||||
|
|
||||||
|
def test_split_rope_output_shape(self):
|
||||||
|
"""Verify split RoPE output has correct shape (B, H, T, dim_per_head//2)."""
|
||||||
|
batch_size = 1
|
||||||
|
num_frames = 4
|
||||||
|
height = 4
|
||||||
|
width = 4
|
||||||
|
num_heads = 32
|
||||||
|
dim = 128
|
||||||
|
|
||||||
|
positions = create_video_position_grid(batch_size, num_frames, height, width)
|
||||||
|
num_tokens = num_frames * height * width
|
||||||
|
|
||||||
|
cos_freq, sin_freq = precompute_freqs_cis(
|
||||||
|
indices_grid=positions,
|
||||||
|
dim=dim,
|
||||||
|
theta=10000.0,
|
||||||
|
max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=num_heads,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shape should be (B, H, T, dim_per_head//2)
|
||||||
|
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
||||||
|
dim_per_head = dim // num_heads
|
||||||
|
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
||||||
|
assert cos_freq.shape == expected_shape, \
|
||||||
|
f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||||
|
assert sin_freq.shape == expected_shape, \
|
||||||
|
f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user