adjust gelu and precision

This commit is contained in:
Prince Canuma
2026-01-15 12:49:21 +01:00
parent 349a82f763
commit f5134fa172
2 changed files with 9 additions and 4 deletions

View File

@@ -49,6 +49,12 @@ def apply_interleaved_rotary_emb(
Returns:
Tensor with interleaved rotary embeddings applied
"""
+ # Compute in float32 for better precision
+ input_dtype = input_tensor.dtype
+ input_tensor = input_tensor.astype(mx.float32)
+ cos_freqs = cos_freqs.astype(mx.float32)
+ sin_freqs = sin_freqs.astype(mx.float32)
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
shape = input_tensor.shape
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
@@ -67,7 +73,7 @@ def apply_interleaved_rotary_emb(
# Apply rotary embeddings
out = input_tensor * cos_freqs + t_rot * sin_freqs
- return out
+ return out.astype(input_dtype)
def rotate_half_interleaved(x: mx.array) -> mx.array: