fix tiling, rope precision and weights
This commit is contained in:
@@ -399,13 +399,13 @@ def precompute_freqs_cis(
|
||||
num_attention_heads, rope_type
|
||||
)
|
||||
|
||||
# Cast positions to bfloat16 to match PyTorch's behavior.
|
||||
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
|
||||
# generate_freqs computation — fractional positions, scaling, etc. are all
|
||||
# computed in bfloat16. The multiplication with float32 freq_indices then
|
||||
# upcasts to float32. This precision behavior is what the model was trained
|
||||
# with, so we must replicate it.
|
||||
indices_grid = indices_grid.astype(mx.bfloat16)
|
||||
# Keep positions in float32 for RoPE computation.
|
||||
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
|
||||
# empirical comparison shows float32 positions produce RoPE values matching
|
||||
# PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
|
||||
# position computation that gets amplified by high-frequency indices
|
||||
# (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
|
||||
indices_grid = indices_grid.astype(mx.float32)
|
||||
|
||||
# Generate frequency indices
|
||||
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||
@@ -438,23 +438,14 @@ def _precompute_freqs_cis_double_precision(
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
|
||||
|
||||
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
|
||||
computation (log-spaced values), then converts to float32 for the final tensor.
|
||||
This provides better numerical precision in the frequency generation phase.
|
||||
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
|
||||
frequency grid computation (log-spaced values), then converts to float32.
|
||||
Position grid stays in bfloat16 to match PyTorch behavior (positions are in
|
||||
model dtype throughout generate_freqs).
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
|
||||
"Use float32 for position grids to avoid quality degradation.",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Cast to float32 for position computation
|
||||
# Keep positions in float32 — same reasoning as the non-double-precision path.
|
||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||
|
||||
n_pos_dims = indices_grid_f32.shape[1]
|
||||
|
||||
Reference in New Issue
Block a user