Add RoPE tests and warning for bfloat16 precision loss in RoPE calculations

This commit is contained in:
Prince Canuma
2026-01-17 19:28:05 +01:00
parent 78244a2d66
commit 61c56cd989
3 changed files with 292 additions and 0 deletions

View File

@@ -430,7 +430,19 @@ def _precompute_freqs_cis_double_precision(
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
"Use float32 for position grids to avoid quality degradation. "
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
UserWarning,
stacklevel=2
)
# Convert to numpy float64 (first to float32 for numpy compatibility)
# Note: If input is bfloat16, precision is already lost at this step
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# Generate frequency indices in float64