format
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||
"""
|
||||
# x: (..., dim) where dim is even
|
||||
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
||||
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
||||
return mx.reshape(rotated, x.shape)
|
||||
|
||||
|
||||
def apply_rotary_emb_1d(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
@@ -228,9 +228,9 @@ def get_fractional_positions(
|
||||
Fractional positions in range [-1, 1] after scaling
|
||||
"""
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), (
|
||||
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
)
|
||||
assert n_pos_dims == len(
|
||||
max_pos
|
||||
), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
|
||||
# Divide each dimension by its max position
|
||||
fractional_positions = []
|
||||
@@ -392,11 +392,15 @@ def precompute_freqs_cis(
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
num_attention_heads, rope_type
|
||||
indices_grid,
|
||||
dim,
|
||||
theta,
|
||||
max_pos,
|
||||
use_middle_indices_grid,
|
||||
num_attention_heads,
|
||||
rope_type,
|
||||
)
|
||||
|
||||
# Keep positions in float32 for RoPE computation.
|
||||
@@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision(
|
||||
# Compute frequencies: outer product
|
||||
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
||||
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
|
||||
freq_indices, (1, 1, 1, -1)
|
||||
)
|
||||
# freqs: (B, T, n_dims, num_indices)
|
||||
|
||||
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
||||
|
||||
Reference in New Issue
Block a user