format
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@@ -11,13 +10,16 @@ def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
|
||||
Complex frequency tensor of shape [max_seq_len, dim // 2].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
|
||||
1.0
|
||||
/ np.power(
|
||||
theta,
|
||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||
)
|
||||
)[None, :]
|
||||
freqs = (
|
||||
np.arange(max_seq_len, dtype=np.float64)[:, None]
|
||||
* (
|
||||
1.0
|
||||
/ np.power(
|
||||
theta,
|
||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||
)
|
||||
)[None, :]
|
||||
)
|
||||
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
|
||||
cos_freqs = np.cos(freqs).astype(np.float32)
|
||||
sin_freqs = np.sin(freqs).astype(np.float32)
|
||||
@@ -46,9 +48,9 @@ def rope_apply(
|
||||
# Check if all batch elements have the same grid (common for CFG B=2)
|
||||
f0, h0, w0 = grid_sizes[0]
|
||||
seq_len = f0 * h0 * w0
|
||||
all_same_grid = all(
|
||||
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
|
||||
) if b > 1 else True
|
||||
all_same_grid = (
|
||||
all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
|
||||
)
|
||||
|
||||
if all_same_grid:
|
||||
# Vectorized path: apply RoPE to all batch elements at once
|
||||
@@ -57,7 +59,9 @@ def rope_apply(
|
||||
x_imag = x_seq[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
|
||||
b, seq_len, n, d
|
||||
)
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
||||
return x_rotated
|
||||
@@ -102,17 +106,11 @@ def rope_apply(
|
||||
|
||||
# Build per-position frequencies by expanding along grid dims
|
||||
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
|
||||
ft = mx.broadcast_to(
|
||||
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
|
||||
)
|
||||
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
|
||||
fh = mx.broadcast_to(
|
||||
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
|
||||
)
|
||||
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
|
||||
fw = mx.broadcast_to(
|
||||
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
|
||||
)
|
||||
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||
|
||||
# Concatenate: [f*h*w, half_d, 2]
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
|
||||
Reference in New Issue
Block a user