Add RoPE tests and warning for bfloat16 precision loss in RoPE calculations
This commit is contained in:
@@ -430,7 +430,19 @@ def _precompute_freqs_cis_double_precision(
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
|
||||
"Use float32 for position grids to avoid quality degradation. "
|
||||
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
||||
# Note: If input is bfloat16, precision is already lost at this step
|
||||
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||
|
||||
# Generate frequency indices in float64
|
||||
|
||||
Reference in New Issue
Block a user