fix tiling, rope precision and weights

This commit is contained in:
Prince Canuma
2026-03-15 22:58:55 +01:00
parent ebcd5dd4e4
commit cecd68197c
5 changed files with 86 additions and 149 deletions

View File

@@ -399,13 +399,13 @@ def precompute_freqs_cis(
num_attention_heads, rope_type
)
# Cast positions to bfloat16 to match PyTorch's behavior.
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
# generate_freqs computation — fractional positions, scaling, etc. are all
# computed in bfloat16. The multiplication with float32 freq_indices then
# upcasts to float32. This precision behavior is what the model was trained
# with, so we must replicate it.
indices_grid = indices_grid.astype(mx.bfloat16)
# Keep positions in float32 for RoPE computation.
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
# empirical comparison shows float32 positions produce RoPE values matching
# PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
# position computation that gets amplified by high-frequency indices
# (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
indices_grid = indices_grid.astype(mx.float32)
# Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
@@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision(
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
computation (log-spaced values), then converts to float32 for the final tensor.
This provides better numerical precision in the frequency generation phase.
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
frequency grid computation (log-spaced values), then converts to float32.
Position grid is cast to float32 for the RoPE computation: bfloat16 positions
lose precision that is amplified by the high-frequency indices.
"""
import numpy as np
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
"Use float32 for position grids to avoid quality degradation.",
UserWarning,
stacklevel=2
)
# Cast to float32 for position computation
# Keep positions in float32 — same reasoning as the non-double-precision path.
indices_grid_f32 = indices_grid.astype(mx.float32)
n_pos_dims = indices_grid_f32.shape[1]