This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module):
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = max_exact + (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
rel_buckets = rel_buckets + mx.where(
is_small, rel_pos.astype(mx.int32), rel_pos_large
)
return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array:
@@ -115,7 +114,7 @@ class T5Attention(nn.Module):
v = v.transpose(0, 2, 1, 3)
# QK^T (no scaling) — compute in float32 for precision
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
# Add position bias
if pos_bias is not None: