adjust gelu and precision

This commit is contained in:
Prince Canuma
2026-01-15 12:49:21 +01:00
parent 349a82f763
commit f5134fa172
2 changed files with 9 additions and 4 deletions

View File

@@ -227,7 +227,7 @@ class GEGLU(nn.Module):
self.proj = nn.Linear(in_dim, out_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
- return nn.gelu_approx(self.proj(x))
+ return nn.gelu(self.proj(x))
class ConnectorFeedForward(nn.Module):
@@ -308,7 +308,6 @@ class Embeddings1DConnector(nn.Module):
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (INTERLEAVED type).
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
"""
@@ -327,7 +326,7 @@ class Embeddings1DConnector(nn.Module):
log_start = np.log(start) / np.log(theta) # = 0
log_end = np.log(end) / np.log(theta) # = 1
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
- indices = np.power(theta, lin_space) * (np.pi / 2)
+ indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
# Generate positions and compute freqs (matches generate_freqs)
positions = np.arange(seq_len, dtype=np.float64)