Use NumPy float64 for improved precision and performance

This commit is contained in:
Prince Canuma
2026-01-14 00:03:00 +01:00
parent 74af04718d
commit 957093c29b

View File

@@ -311,25 +311,26 @@ class Embeddings1DConnector(nn.Module):
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
"""
import math
import numpy as np
dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta
max_pos = [1] # Default for connector
n_elem = 2 * len(max_pos) # = 2
# Generate frequency indices (matches generate_freq_grid_pytorch)
start = 1.0
end = theta
num_indices = dim // n_elem # 1920
log_start = math.log(start) / math.log(theta) # = 0
log_end = math.log(end) / math.log(theta) # = 1
lin_space = mx.linspace(log_start, log_end, num_indices)
indices = (theta ** lin_space) * (math.pi / 2)
# Use numpy float64 for precision
log_start = np.log(start) / np.log(theta) # = 0
log_end = np.log(end) / np.log(theta) # = 1
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
indices = np.power(theta, lin_space) * (np.pi / 2)
# Generate positions and compute freqs (matches generate_freqs)
positions = mx.arange(seq_len).astype(mx.float32)
positions = np.arange(seq_len, dtype=np.float64)
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
@@ -339,17 +340,17 @@ class Embeddings1DConnector(nn.Module):
freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
cos_freq = np.cos(freqs)
sin_freq = np.sin(freqs)
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = mx.repeat(cos_freq, 2, axis=-1)
sin_full = mx.repeat(sin_freq, 2, axis=-1)
cos_full = np.repeat(cos_freq, 2, axis=-1)
sin_full = np.repeat(sin_freq, 2, axis=-1)
# Add batch dimension: (1, seq_len, dim)
cos_full = cos_full[None, :, :]
sin_full = sin_full[None, :, :]
# Add batch dimension and convert to MLX: (1, seq_len, dim)
cos_full = mx.array(cos_full[None, :, :].astype(np.float32))
sin_full = mx.array(sin_full[None, :, :].astype(np.float32))
return cos_full.astype(dtype), sin_full.astype(dtype)