import pytest import mlx.core as mx import numpy as np from mlx_video.models.ltx.rope import ( precompute_freqs_cis, ) from mlx_video.models.ltx.config import LTXRopeType def create_video_position_grid( batch_size: int, num_frames: int, height: int, width: int, dtype: mx.Dtype = mx.float32, ) -> mx.array: """Create a simple video position grid for testing.""" t_coords = np.arange(0, num_frames) h_coords = np.arange(0, height) w_coords = np.arange(0, width) t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) patch_ends = patch_starts + 1 latent_coords = np.stack([patch_starts, patch_ends], axis=-1) num_patches = num_frames * height * width latent_coords = latent_coords.reshape(3, num_patches, 2) latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) # Scale to pixel space scale_factors = np.array([8, 32, 32]).reshape(1, 3, 1, 1) pixel_coords = (latent_coords * scale_factors).astype(np.float32) pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0 # Convert to seconds return mx.array(pixel_coords, dtype=dtype) class TestRoPEPositionPrecision: """Test suite for RoPE position precision requirements.""" def test_float32_positions_produce_consistent_output(self): """Float32 position grids should produce stable RoPE frequencies.""" positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) cos_freq, sin_freq = precompute_freqs_cis( indices_grid=positions, dim=128, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=32, rope_type=LTXRopeType.SPLIT, double_precision=True, ) # Verify output dtype is float32 assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}" assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}" # Verify no NaN or Inf values assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN" assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN" assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf" assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf" # Verify cos/sin are in valid range [-1, 1] assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \ "cos_freq values out of [-1, 1] range" assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \ "sin_freq values out of [-1, 1] range" def test_bfloat16_positions_cause_precision_loss(self): """bfloat16 positions should produce different (less precise) results than float32. This test documents the known issue: bfloat16 has only 7 bits of mantissa vs 23 bits for float32, causing quantization errors that get amplified by sin/cos calculations in RoPE. """ # Create identical position grids in different dtypes positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) # Compute RoPE frequencies cos_f32, sin_f32 = precompute_freqs_cis( indices_grid=positions_f32, dim=128, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=32, rope_type=LTXRopeType.SPLIT, double_precision=True, ) cos_bf16, sin_bf16 = precompute_freqs_cis( indices_grid=positions_bf16, dim=128, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=32, rope_type=LTXRopeType.SPLIT, double_precision=True, ) # Calculate the difference cos_diff = mx.abs(cos_f32 - cos_bf16) sin_diff = mx.abs(sin_f32 - sin_bf16) max_cos_diff = mx.max(cos_diff).item() max_sin_diff = mx.max(sin_diff).item() # bfloat16 positions WILL cause measurable differences # This test documents this known behavior # The threshold here is intentionally low to catch the issue precision_threshold = 1e-6 has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold # Document the precision loss (this is expected behavior) if has_precision_loss: print(f"\nPrecision loss detected (expected):") print(f" Max cos difference: {max_cos_diff:.6e}") print(f" Max sin difference: {max_sin_diff:.6e}") # This assertion documents the issue - bfloat16 positions cause precision loss assert has_precision_loss, \ "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed" def test_double_precision_converts_to_float32_internally(self): """Verify that double_precision mode converts bfloat16 to float32 first.""" positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) # The double precision path in rope.py line 434: # indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) # This means bfloat16 -> float32 -> float64 # The precision is already lost at the bfloat16 -> float32 step cos_freq, sin_freq = precompute_freqs_cis( indices_grid=positions_bf16, dim=128, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=32, rope_type=LTXRopeType.SPLIT, double_precision=True, ) # Output should still be float32 assert cos_freq.dtype == mx.float32 assert sin_freq.dtype == mx.float32 def test_position_grid_should_be_float32_recommendation(self): """Test that validates the recommended practice: positions should be float32. This test serves as documentation that position grids MUST be float32 to avoid quality degradation in generated videos/audio. """ # Recommended: create positions in float32 positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) assert positions.dtype == mx.float32, \ "Position grids should be created in float32 for RoPE precision" # Verify the position values are reasonable # Temporal positions should be small (seconds) temporal_positions = positions[:, 0, :, :] assert mx.max(temporal_positions).item() < 100, \ "Temporal positions should be in seconds (small values)" # Spatial positions should be larger (pixels) spatial_h = positions[:, 1, :, :] spatial_w = positions[:, 2, :, :] assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive" assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive" class TestRoPEInterleaved: """Tests for interleaved RoPE mode.""" def test_interleaved_rope_with_float32_positions(self): """Interleaved RoPE should work correctly with float32 positions.""" positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) cos_freq, sin_freq = precompute_freqs_cis( indices_grid=positions, dim=128, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=32, rope_type=LTXRopeType.INTERLEAVED, double_precision=False, ) assert cos_freq.dtype == mx.float32 assert sin_freq.dtype == mx.float32 assert not mx.any(mx.isnan(cos_freq)).item() assert not mx.any(mx.isnan(sin_freq)).item() class TestRoPEWarnings: """Tests for RoPE warnings.""" def test_bfloat16_positions_trigger_warning(self): """Verify that bfloat16 positions trigger a UserWarning.""" positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"): precompute_freqs_cis( indices_grid=positions_bf16, dim=128, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=32, rope_type=LTXRopeType.SPLIT, double_precision=True, ) def test_float32_positions_no_warning(self): """Verify that float32 positions do NOT trigger a warning.""" positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) # This should not raise any warnings import warnings with warnings.catch_warnings(): warnings.simplefilter("error") # Turn warnings into errors precompute_freqs_cis( indices_grid=positions_f32, dim=128, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=32, rope_type=LTXRopeType.SPLIT, double_precision=True, ) class TestRoPESplit: """Tests for split RoPE mode (used by LTX-2).""" def test_split_rope_output_shape(self): """Verify split RoPE output has correct shape (B, H, T, dim_per_head//2).""" batch_size = 1 num_frames = 4 height = 4 width = 4 num_heads = 32 dim = 128 positions = create_video_position_grid(batch_size, num_frames, height, width) num_tokens = num_frames * height * width cos_freq, sin_freq = precompute_freqs_cis( indices_grid=positions, dim=dim, theta=10000.0, max_pos=[20, 2048, 2048], use_middle_indices_grid=True, num_attention_heads=num_heads, rope_type=LTXRopeType.SPLIT, double_precision=True, ) # Shape should be (B, H, T, dim_per_head//2) # dim=128, num_heads=32, so dim_per_head=4, and split uses half=2 dim_per_head = dim // num_heads expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2) assert cos_freq.shape == expected_shape, \ f"Expected shape {expected_shape}, got {cos_freq.shape}" assert sin_freq.shape == expected_shape, \ f"Expected shape {expected_shape}, got {sin_freq.shape}" if __name__ == "__main__": pytest.main([__file__, "-v"])