diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index a00d019..4852942 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -430,7 +430,19 @@ def _precompute_freqs_cis_double_precision( rope_type: LTXRopeType, ) -> Tuple[mx.array, mx.array]: + # Warn if positions are bfloat16 - this causes quality degradation + if indices_grid.dtype == mx.bfloat16: + import warnings + warnings.warn( + "Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. " + "Use float32 for position grids to avoid quality degradation. " + "See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss", + UserWarning, + stacklevel=2 + ) + # Convert to numpy float64 (first to float32 for numpy compatibility) + # Note: If input is bfloat16, precision is already lost at this step indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) # Generate frequency indices in float64 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_rope.py b/tests/test_rope.py new file mode 100644 index 0000000..cef8d6f --- /dev/null +++ b/tests/test_rope.py @@ -0,0 +1,280 @@ +import pytest +import mlx.core as mx +import numpy as np + +from mlx_video.models.ltx.rope import ( + precompute_freqs_cis, +) +from mlx_video.models.ltx.config import LTXRopeType + + +def create_video_position_grid( + batch_size: int, + num_frames: int, + height: int, + width: int, + dtype: mx.Dtype = mx.float32, +) -> mx.array: + """Create a simple video position grid for testing.""" + t_coords = np.arange(0, num_frames) + h_coords = np.arange(0, height) + w_coords = np.arange(0, width) + + t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') + patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) + patch_ends = patch_starts + 1 + + latent_coords = np.stack([patch_starts, patch_ends], axis=-1) + num_patches = num_frames * height * width + latent_coords = latent_coords.reshape(3, num_patches, 2) + latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) + + # Scale to pixel space + scale_factors = np.array([8, 32, 32]).reshape(1, 3, 1, 1) + pixel_coords = (latent_coords * scale_factors).astype(np.float32) + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0 # Convert to seconds + + return mx.array(pixel_coords, dtype=dtype) + +class TestRoPEPositionPrecision: + """Test suite for RoPE position precision requirements.""" + + def test_float32_positions_produce_consistent_output(self): + """Float32 position grids should produce stable RoPE frequencies.""" + positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Verify output dtype is float32 + assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}" + assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}" + + # Verify no NaN or Inf values + assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN" + assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN" + assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf" + assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf" + + # Verify cos/sin are in valid range [-1, 1] + assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \ + "cos_freq values out of [-1, 1] range" + assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \ + "sin_freq values out of [-1, 1] range" + + def test_bfloat16_positions_cause_precision_loss(self): + """bfloat16 positions should produce different (less precise) results than float32. + + This test documents the known issue: bfloat16 has only 7 bits of mantissa + vs 23 bits for float32, causing quantization errors that get amplified + by sin/cos calculations in RoPE. + """ + # Create identical position grids in different dtypes + positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) + + # Compute RoPE frequencies + cos_f32, sin_f32 = precompute_freqs_cis( + indices_grid=positions_f32, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + cos_bf16, sin_bf16 = precompute_freqs_cis( + indices_grid=positions_bf16, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Calculate the difference + cos_diff = mx.abs(cos_f32 - cos_bf16) + sin_diff = mx.abs(sin_f32 - sin_bf16) + + max_cos_diff = mx.max(cos_diff).item() + max_sin_diff = mx.max(sin_diff).item() + + # bfloat16 positions WILL cause measurable differences + # This test documents this known behavior + # The threshold here is intentionally low to catch the issue + precision_threshold = 1e-6 + + has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold + + # Document the precision loss (this is expected behavior) + if has_precision_loss: + print(f"\nPrecision loss detected (expected):") + print(f" Max cos difference: {max_cos_diff:.6e}") + print(f" Max sin difference: {max_sin_diff:.6e}") + + # This assertion documents the issue - bfloat16 positions cause precision loss + assert has_precision_loss, \ + "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed" + + def test_double_precision_converts_to_float32_internally(self): + """Verify that double_precision mode converts bfloat16 to float32 first.""" + positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) + + # The double precision path in rope.py line 434: + # indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) + # This means bfloat16 -> float32 -> float64 + # The precision is already lost at the bfloat16 -> float32 step + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions_bf16, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Output should still be float32 + assert cos_freq.dtype == mx.float32 + assert sin_freq.dtype == mx.float32 + + def test_position_grid_should_be_float32_recommendation(self): + """Test that validates the recommended practice: positions should be float32. + + This test serves as documentation that position grids MUST be float32 + to avoid quality degradation in generated videos/audio. + """ + # Recommended: create positions in float32 + positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + assert positions.dtype == mx.float32, \ + "Position grids should be created in float32 for RoPE precision" + + # Verify the position values are reasonable + # Temporal positions should be small (seconds) + temporal_positions = positions[:, 0, :, :] + assert mx.max(temporal_positions).item() < 100, \ + "Temporal positions should be in seconds (small values)" + + # Spatial positions should be larger (pixels) + spatial_h = positions[:, 1, :, :] + spatial_w = positions[:, 2, :, :] + assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive" + assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive" + + +class TestRoPEInterleaved: + """Tests for interleaved RoPE mode.""" + + def test_interleaved_rope_with_float32_positions(self): + """Interleaved RoPE should work correctly with float32 positions.""" + positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.INTERLEAVED, + double_precision=False, + ) + + assert cos_freq.dtype == mx.float32 + assert sin_freq.dtype == mx.float32 + assert not mx.any(mx.isnan(cos_freq)).item() + assert not mx.any(mx.isnan(sin_freq)).item() + + +class TestRoPEWarnings: + """Tests for RoPE warnings.""" + + def test_bfloat16_positions_trigger_warning(self): + """Verify that bfloat16 positions trigger a UserWarning.""" + positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) + + with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"): + precompute_freqs_cis( + indices_grid=positions_bf16, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + def test_float32_positions_no_warning(self): + """Verify that float32 positions do NOT trigger a warning.""" + positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + # This should not raise any warnings + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + precompute_freqs_cis( + indices_grid=positions_f32, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + +class TestRoPESplit: + """Tests for split RoPE mode (used by LTX-2).""" + + def test_split_rope_output_shape(self): + """Verify split RoPE output has correct shape (B, H, T, dim_per_head//2).""" + batch_size = 1 + num_frames = 4 + height = 4 + width = 4 + num_heads = 32 + dim = 128 + + positions = create_video_position_grid(batch_size, num_frames, height, width) + num_tokens = num_frames * height * width + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions, + dim=dim, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=num_heads, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Shape should be (B, H, T, dim_per_head//2) + # dim=128, num_heads=32, so dim_per_head=4, and split uses half=2 + dim_per_head = dim // num_heads + expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2) + assert cos_freq.shape == expected_shape, \ + f"Expected shape {expected_shape}, got {cos_freq.shape}" + assert sin_freq.shape == expected_shape, \ + f"Expected shape {expected_shape}, got {sin_freq.shape}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])