format
This commit is contained in:
@@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module):
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
rel_pos_f = rel_pos.astype(mx.float32)
|
||||
rel_pos_large = (
|
||||
max_exact
|
||||
+ (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
)
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
rel_pos_large = mx.minimum(
|
||||
rel_pos_large,
|
||||
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
||||
)
|
||||
|
||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
|
||||
rel_buckets = rel_buckets + mx.where(
|
||||
is_small, rel_pos.astype(mx.int32), rel_pos_large
|
||||
)
|
||||
return rel_buckets
|
||||
|
||||
def __call__(self, lq: int, lk: int) -> mx.array:
|
||||
@@ -115,7 +114,7 @@ class T5Attention(nn.Module):
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# QK^T (no scaling) — compute in float32 for precision
|
||||
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
|
||||
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
|
||||
|
||||
# Add position bias
|
||||
if pos_bias is not None:
|
||||
|
||||
Reference in New Issue
Block a user