This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,4 +1,3 @@
import math
from typing import List, Optional, Tuple
@@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
"""
# x: (..., dim) where dim is even
x_even = x[..., 0::2] # [x0, x2, x4, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
rotated = mx.stack([-x_odd, x_even], axis=-1)
return mx.reshape(rotated, x.shape)
def apply_rotary_emb_1d(
q: mx.array,
k: mx.array,
@@ -228,9 +228,9 @@ def get_fractional_positions(
Fractional positions in range [-1, 1] after scaling
"""
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), (
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
)
assert n_pos_dims == len(
max_pos
), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
# Divide each dimension by its max position
fractional_positions = []
@@ -392,11 +392,15 @@ def precompute_freqs_cis(
if max_pos is None:
max_pos = [20, 2048, 2048]
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
num_attention_heads, rope_type
indices_grid,
dim,
theta,
max_pos,
use_middle_indices_grid,
num_attention_heads,
rope_type,
)
# Keep positions in float32 for RoPE computation.
@@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision(
# Compute frequencies: outer product
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
freq_indices, (1, 1, 1, -1)
)
# freqs: (B, T, n_dims, num_indices)
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)