adjust gelu and precision
This commit is contained in:
@@ -49,6 +49,12 @@ def apply_interleaved_rotary_emb(
|
||||
Returns:
|
||||
Tensor with interleaved rotary embeddings applied
|
||||
"""
|
||||
# Compute in float32 for better precision
|
||||
input_dtype = input_tensor.dtype
|
||||
input_tensor = input_tensor.astype(mx.float32)
|
||||
cos_freqs = cos_freqs.astype(mx.float32)
|
||||
sin_freqs = sin_freqs.astype(mx.float32)
|
||||
|
||||
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
||||
shape = input_tensor.shape
|
||||
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
||||
@@ -67,7 +73,7 @@ def apply_interleaved_rotary_emb(
|
||||
# Apply rotary embeddings
|
||||
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
||||
|
||||
return out
|
||||
return out.astype(input_dtype)
|
||||
|
||||
|
||||
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user