adjust gelu and precision
This commit is contained in:
@@ -227,7 +227,7 @@ class GEGLU(nn.Module):
|
||||
self.proj = nn.Linear(in_dim, out_dim, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return nn.gelu_approx(self.proj(x))
|
||||
return nn.gelu(self.proj(x))
|
||||
|
||||
|
||||
class ConnectorFeedForward(nn.Module):
|
||||
@@ -308,7 +308,6 @@ class Embeddings1DConnector(nn.Module):
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
||||
|
||||
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||
"""
|
||||
|
||||
@@ -327,7 +326,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
log_start = np.log(start) / np.log(theta) # = 0
|
||||
log_end = np.log(end) / np.log(theta) # = 1
|
||||
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
||||
indices = np.power(theta, lin_space) * (np.pi / 2)
|
||||
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
|
||||
|
||||
# Generate positions and compute freqs (matches generate_freqs)
|
||||
positions = np.arange(seq_len, dtype=np.float64)
|
||||
|
||||
Reference in New Issue
Block a user