use numpy for improved float64 precision and performance
This commit is contained in:
@@ -311,25 +311,26 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
||||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||||
"""
|
"""
|
||||||
import math
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
||||||
theta = self.positional_embedding_theta
|
theta = self.positional_embedding_theta
|
||||||
max_pos = [1] # Default for connector
|
max_pos = [1] # Default for connector
|
||||||
n_elem = 2 * len(max_pos) # = 2
|
n_elem = 2 * len(max_pos) # = 2
|
||||||
|
|
||||||
# Generate frequency indices (matches generate_freq_grid_pytorch)
|
|
||||||
start = 1.0
|
start = 1.0
|
||||||
end = theta
|
end = theta
|
||||||
num_indices = dim // n_elem # 1920
|
num_indices = dim // n_elem # 1920
|
||||||
|
|
||||||
log_start = math.log(start) / math.log(theta) # = 0
|
# Use numpy float64 for precision
|
||||||
log_end = math.log(end) / math.log(theta) # = 1
|
log_start = np.log(start) / np.log(theta) # = 0
|
||||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
log_end = np.log(end) / np.log(theta) # = 1
|
||||||
indices = (theta ** lin_space) * (math.pi / 2)
|
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
||||||
|
indices = np.power(theta, lin_space) * (np.pi / 2)
|
||||||
|
|
||||||
# Generate positions and compute freqs (matches generate_freqs)
|
# Generate positions and compute freqs (matches generate_freqs)
|
||||||
positions = mx.arange(seq_len).astype(mx.float32)
|
positions = np.arange(seq_len, dtype=np.float64)
|
||||||
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
||||||
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
||||||
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
||||||
@@ -339,17 +340,17 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
freqs = scaled_positions[:, None] * indices[None, :]
|
freqs = scaled_positions[:, None] * indices[None, :]
|
||||||
|
|
||||||
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
||||||
cos_freq = mx.cos(freqs)
|
cos_freq = np.cos(freqs)
|
||||||
sin_freq = mx.sin(freqs)
|
sin_freq = np.sin(freqs)
|
||||||
|
|
||||||
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
||||||
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
||||||
cos_full = mx.repeat(cos_freq, 2, axis=-1)
|
cos_full = np.repeat(cos_freq, 2, axis=-1)
|
||||||
sin_full = mx.repeat(sin_freq, 2, axis=-1)
|
sin_full = np.repeat(sin_freq, 2, axis=-1)
|
||||||
|
|
||||||
# Add batch dimension: (1, seq_len, dim)
|
# Add batch dimension and convert to MLX: (1, seq_len, dim)
|
||||||
cos_full = cos_full[None, :, :]
|
cos_full = mx.array(cos_full[None, :, :].astype(np.float32))
|
||||||
sin_full = sin_full[None, :, :]
|
sin_full = mx.array(sin_full[None, :, :].astype(np.float32))
|
||||||
|
|
||||||
return cos_full.astype(dtype), sin_full.astype(dtype)
|
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user