use numpy for improved float64 precision and performance

This commit is contained in:
Prince Canuma
2026-01-14 00:03:00 +01:00
parent 74af04718d
commit 957093c29b

View File

@@ -311,25 +311,26 @@ class Embeddings1DConnector(nn.Module):
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim). Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
""" """
import math
import numpy as np
dim = self.num_heads * self.head_dim # inner_dim = 3840 dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta theta = self.positional_embedding_theta
max_pos = [1] # Default for connector max_pos = [1] # Default for connector
n_elem = 2 * len(max_pos) # = 2 n_elem = 2 * len(max_pos) # = 2
# Generate frequency indices (matches generate_freq_grid_pytorch)
start = 1.0 start = 1.0
end = theta end = theta
num_indices = dim // n_elem # 1920 num_indices = dim // n_elem # 1920
log_start = math.log(start) / math.log(theta) # = 0 # Use numpy float64 for precision
log_end = math.log(end) / math.log(theta) # = 1 log_start = np.log(start) / np.log(theta) # = 0
lin_space = mx.linspace(log_start, log_end, num_indices) log_end = np.log(end) / np.log(theta) # = 1
indices = (theta ** lin_space) * (math.pi / 2) lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
indices = np.power(theta, lin_space) * (np.pi / 2)
# Generate positions and compute freqs (matches generate_freqs) # Generate positions and compute freqs (matches generate_freqs)
positions = mx.arange(seq_len).astype(mx.float32) positions = np.arange(seq_len, dtype=np.float64)
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1) # fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1 # scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,) scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
@@ -339,17 +340,17 @@ class Embeddings1DConnector(nn.Module):
freqs = scaled_positions[:, None] * indices[None, :] freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis) # Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = mx.cos(freqs) cos_freq = np.cos(freqs)
sin_freq = mx.sin(freqs) sin_freq = np.sin(freqs)
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim) # repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
# Pattern: [c0, c0, c1, c1, c2, c2, ...] # Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = mx.repeat(cos_freq, 2, axis=-1) cos_full = np.repeat(cos_freq, 2, axis=-1)
sin_full = mx.repeat(sin_freq, 2, axis=-1) sin_full = np.repeat(sin_freq, 2, axis=-1)
# Add batch dimension: (1, seq_len, dim) # Add batch dimension and convert to MLX: (1, seq_len, dim)
cos_full = cos_full[None, :, :] cos_full = mx.array(cos_full[None, :, :].astype(np.float32))
sin_full = sin_full[None, :, :] sin_full = mx.array(sin_full[None, :, :].astype(np.float32))
return cos_full.astype(dtype), sin_full.astype(dtype) return cos_full.astype(dtype), sin_full.astype(dtype)