adjust gelu and precision
This commit is contained in:
@@ -49,6 +49,12 @@ def apply_interleaved_rotary_emb(
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor with interleaved rotary embeddings applied
|
Tensor with interleaved rotary embeddings applied
|
||||||
"""
|
"""
|
||||||
|
# Compute in float32 for better precision
|
||||||
|
input_dtype = input_tensor.dtype
|
||||||
|
input_tensor = input_tensor.astype(mx.float32)
|
||||||
|
cos_freqs = cos_freqs.astype(mx.float32)
|
||||||
|
sin_freqs = sin_freqs.astype(mx.float32)
|
||||||
|
|
||||||
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
||||||
shape = input_tensor.shape
|
shape = input_tensor.shape
|
||||||
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
||||||
@@ -67,7 +73,7 @@ def apply_interleaved_rotary_emb(
|
|||||||
# Apply rotary embeddings
|
# Apply rotary embeddings
|
||||||
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
||||||
|
|
||||||
return out
|
return out.astype(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ class GEGLU(nn.Module):
|
|||||||
self.proj = nn.Linear(in_dim, out_dim, bias=True)
|
self.proj = nn.Linear(in_dim, out_dim, bias=True)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
return nn.gelu_approx(self.proj(x))
|
return nn.gelu(self.proj(x))
|
||||||
|
|
||||||
|
|
||||||
class ConnectorFeedForward(nn.Module):
|
class ConnectorFeedForward(nn.Module):
|
||||||
@@ -308,7 +308,6 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||||
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
||||||
|
|
||||||
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
|
||||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -327,7 +326,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
log_start = np.log(start) / np.log(theta) # = 0
|
log_start = np.log(start) / np.log(theta) # = 0
|
||||||
log_end = np.log(end) / np.log(theta) # = 1
|
log_end = np.log(end) / np.log(theta) # = 1
|
||||||
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
||||||
indices = np.power(theta, lin_space) * (np.pi / 2)
|
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
|
||||||
|
|
||||||
# Generate positions and compute freqs (matches generate_freqs)
|
# Generate positions and compute freqs (matches generate_freqs)
|
||||||
positions = np.arange(seq_len, dtype=np.float64)
|
positions = np.arange(seq_len, dtype=np.float64)
|
||||||
|
|||||||
Reference in New Issue
Block a user