541 lines
18 KiB
Python
541 lines
18 KiB
Python
|
|
import math
|
|
from typing import List, Optional, Tuple
|
|
|
|
import mlx.core as mx
|
|
|
|
from mlx_video.models.ltx.config import LTXRopeType
|
|
|
|
|
|
def apply_rotary_emb(
    input_tensor: mx.array,
    freqs_cis: Tuple[mx.array, mx.array],
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> mx.array:
    """Rotate ``input_tensor`` using precomputed RoPE cos/sin tables.

    Selects the interleaved or split implementation based on ``rope_type``
    and forwards ``freqs_cis[0]`` as the cosine table and ``freqs_cis[1]`` as
    the sine table.

    Args:
        input_tensor: Tensor to rotate.
        freqs_cis: Pair ``(cos_freqs, sin_freqs)``.
        rope_type: Pairing layout, ``LTXRopeType.INTERLEAVED`` or
            ``LTXRopeType.SPLIT``.

    Returns:
        The rotated tensor produced by the selected implementation.

    Raises:
        ValueError: If ``rope_type`` is neither INTERLEAVED nor SPLIT.
    """
    if rope_type == LTXRopeType.INTERLEAVED:
        rotate = apply_interleaved_rotary_emb
    elif rope_type == LTXRopeType.SPLIT:
        rotate = apply_split_rotary_emb
    else:
        raise ValueError(f"Invalid rope type: {rope_type}")

    return rotate(input_tensor, freqs_cis[0], freqs_cis[1])
|
|
|
|
|
|
def apply_interleaved_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply RoPE where each adjacent channel pair forms one rotation plane.

    Channels are grouped as (x0, x1), (x2, x3), ... and the result is
    ``x * cos + rot(x) * sin`` with ``rot(x) = [-x1, x0, -x3, x2, ...]``.
    All arithmetic runs in float32; the result is cast back to the input
    dtype.

    Args:
        input_tensor: Tensor of shape (..., dim); ``dim`` must be even.
        cos_freqs: Cosine table, broadcastable against ``input_tensor``.
        sin_freqs: Sine table, broadcastable against ``input_tensor``.

    Returns:
        Rotated tensor, cast to the dtype of ``input_tensor``.
    """
    original_dtype = input_tensor.dtype

    # Upcast everything so the rotation is computed in float32.
    x = input_tensor.astype(mx.float32)
    cos32 = cos_freqs.astype(mx.float32)
    sin32 = sin_freqs.astype(mx.float32)

    evens = x[..., 0::2]  # x0, x2, x4, ...
    odds = x[..., 1::2]  # x1, x3, x5, ...

    # Stack pairs as (-odd, even) and flatten back to [-x1, x0, -x3, x2, ...].
    rotated = mx.reshape(mx.stack([-odds, evens], axis=-1), x.shape)

    return (x * cos32 + rotated * sin32).astype(original_dtype)
|
|
|
|
|
|
def rotate_half_interleaved(x: mx.array) -> mx.array:
    """Rotate adjacent channel pairs: [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...].

    Equivalent PyTorch reference:
        t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
        t1, t2 = t_dup.unbind(dim=-1)
        t_dup = torch.stack((-t2, t1), dim=-1)
        return rearrange(t_dup, "... d r -> ... (d r)")

    Args:
        x: Tensor of shape (..., dim) with an even ``dim``.

    Returns:
        Tensor of the same shape with each pair rotated by 90 degrees.
    """
    # View the last axis as (dim // 2) pairs of two channels.
    pairs = mx.reshape(x, x.shape[:-1] + (x.shape[-1] // 2, 2))
    first = pairs[..., 0]
    second = pairs[..., 1]
    return mx.reshape(mx.stack([-second, first], axis=-1), x.shape)
|
|
|
|
def apply_rotary_emb_1d(
    q: mx.array,
    k: mx.array,
    freqs_cis: mx.array,
) -> Tuple[mx.array, mx.array]:
    """Apply interleaved 1D RoPE to queries and keys from a packed cos/sin table.

    Args:
        q: Query tensor (expected (batch, seq_len, num_heads, head_dim)).
        k: Key tensor with the same layout as ``q``.
        freqs_cis: Table whose last axis holds ``[cos, sin]``; expected
            shape (1, seq_len, num_heads, head_dim, 2).

    Returns:
        ``(q_rotated, k_rotated)``. No dtype upcast is performed here.
    """
    cos, sin = freqs_cis[..., 0], freqs_cis[..., 1]

    def _rotate(t: mx.array) -> mx.array:
        # Adjacent-pair rotation: t * cos + rot(t) * sin.
        return t * cos + rotate_half_interleaved(t) * sin

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
def apply_split_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply split ("half") rotary embeddings.

    The last axis is cut into a first and a second half. Channel ``i`` of the
    first half is rotated together with channel ``i`` of the second half:

        out_first  = first  * cos - sin * second
        out_second = second * cos + sin * first

    Computation runs in float32 and the result is cast back to the input dtype.

    Args:
        input_tensor: A 4-D tensor (B, H, T, D). Alternatively, when the
            frequency tables are 4-D, a flattened tensor (B, T, H*D). That form
            is reshaped to (B, H, T, D) using B/H/T from ``cos_freqs`` and
            flattened back on return.
        cos_freqs: Cosine table of shape (B, H, T, D//2).
        sin_freqs: Sine table of shape (B, H, T, D//2).

    Returns:
        Tensor with split rotary embeddings applied, in the input dtype.
    """
    original_dtype = input_tensor.dtype
    heads_flattened = input_tensor.ndim != 4 and cos_freqs.ndim == 4

    x = input_tensor
    if heads_flattened:
        # (B, T, H*D) -> (B, T, H, D) -> (B, H, T, D)
        batch, heads, seq, _ = cos_freqs.shape
        x = mx.swapaxes(mx.reshape(x, (batch, seq, heads, -1)), 1, 2)

    x = x.astype(mx.float32)
    cos32 = cos_freqs.astype(mx.float32)
    sin32 = sin_freqs.astype(mx.float32)

    # (..., D) -> (..., 2, D//2): index 0 is the first half, index 1 the second.
    halves = mx.reshape(x, x.shape[:-1] + (2, x.shape[-1] // 2))
    lo = halves[..., 0, :]
    hi = halves[..., 1, :]

    rotated = mx.stack(
        [lo * cos32 - sin32 * hi, hi * cos32 + sin32 * lo],
        axis=-2,
    )
    out = mx.reshape(rotated, x.shape)

    if heads_flattened:
        # (B, H, T, D) -> (B, T, H, D) -> (B, T, H*D)
        batch, heads, seq, head_dim = out.shape
        out = mx.reshape(mx.swapaxes(out, 1, 2), (batch, seq, heads * head_dim))

    return out.astype(original_dtype)
|
|
|
|
|
|
def generate_freq_grid(
    positional_embedding_theta: float,
    positional_embedding_max_pos_count: int,
    inner_dim: int,
) -> mx.array:
    """Build the log-spaced RoPE frequency vector (float32 path).

    Returns ``theta ** linspace(log_theta(1), log_theta(theta), n) * pi / 2``.
    ``n = inner_dim // (2 * positional_embedding_max_pos_count)`` is used,
    falling back to 1 when that quotient is 0.

    Args:
        positional_embedding_theta: Base theta value.
        positional_embedding_max_pos_count: Number of positional dimensions.
        inner_dim: Inner (RoPE) dimension of the model.

    Returns:
        1-D array of ``n`` frequencies.
    """
    theta = positional_embedding_theta
    n_elem = 2 * positional_embedding_max_pos_count
    num_indices = inner_dim // n_elem or 1

    log_theta = math.log(theta)
    exponents = mx.linspace(
        math.log(1.0) / log_theta,  # log_theta(1) == 0
        math.log(theta) / log_theta,  # log_theta(theta) == 1
        num_indices,
    )
    return mx.power(theta, exponents) * (math.pi / 2)
|
|
|
|
|
|
def get_fractional_positions(
    indices_grid: mx.array,
    max_pos: List[int],
) -> mx.array:
    """Normalize every positional axis by its maximum position.

    No rescaling to [-1, 1] happens here. The output is just
    ``indices_grid[:, i] / max_pos[i]``, and callers apply any further
    mapping themselves.

    Args:
        indices_grid: Position indices of shape (B, n_pos_dims, ...); axis 1
            enumerates the positional dimensions.
        max_pos: One normalizer per positional dimension. Its length must
            equal ``indices_grid.shape[1]``.

    Returns:
        The per-dimension fractions stacked on a new trailing axis, giving
        shape (B, ..., n_pos_dims).

    Raises:
        AssertionError: If ``len(max_pos)`` differs from the number of
            positional dimensions.
    """
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), (
        f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
    )

    return mx.stack(
        [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
        axis=-1,
    )
|
|
|
|
|
|
def generate_freqs(
    indices: mx.array,
    indices_grid: mx.array,
    max_pos: List[int],
    use_middle_indices_grid: bool,
) -> mx.array:
    """Combine a frequency vector with a position grid into RoPE angles.

    Steps:
    1. Reduce the grid to (B, n_dims, T). Use the midpoint of [start, end]
       when ``use_middle_indices_grid`` is set; otherwise take the first entry
       of a trailing axis if the grid is 4-D.
    2. Normalize by ``max_pos`` and map to ``2 * x - 1``.
    3. Take the outer product with ``indices`` and flatten so that positional
       dimensions vary fastest.

    Args:
        indices: 1-D frequency vector of length n_indices.
        indices_grid: Position grid (B, n_dims, T) or (B, n_dims, T, k). With
            ``use_middle_indices_grid`` it must be 4-D with k == 2.
        max_pos: Normalizer per positional dimension.
        use_middle_indices_grid: Average start/end instead of using start.

    Returns:
        Angles of shape (B, T, n_indices * n_dims).
    """
    if use_middle_indices_grid:
        # Last axis holds [start, end] per position; use their midpoint.
        assert len(indices_grid.shape) == 4
        assert indices_grid.shape[-1] == 2
        positions = (indices_grid[..., 0] + indices_grid[..., 1]) / 2.0
    elif len(indices_grid.shape) == 4:
        positions = indices_grid[..., 0]
    else:
        positions = indices_grid

    # (B, T, n_dims), mapped from fractions to 2 * x - 1.
    scaled = get_fractional_positions(positions, max_pos) * 2 - 1

    # Outer product: (B, T, n_dims, 1) * (1, 1, 1, n_indices).
    angles = mx.expand_dims(scaled, -1) * mx.expand_dims(indices, (0, 1, 2))

    # (B, T, n_dims, n_indices) -> (B, T, n_indices, n_dims) -> (B, T, n_indices * n_dims)
    angles = mx.swapaxes(angles, -1, -2)
    return mx.reshape(angles, angles.shape[:-2] + (-1,))
|
|
|
|
|
|
def split_freqs_cis(
    freqs: mx.array,
    pad_size: int,
    num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
    """Prepare cos/sin tables for split RoPE.

    Optionally left-pads the angle table with ``pad_size`` identity-rotation
    channels (cos = 1, sin = 0). It then splits the channel axis across heads
    and moves heads before time.

    Args:
        freqs: Angle tensor of shape (B, T, F).
        pad_size: Number of identity channels to prepend (>= 0).
        num_attention_heads: Number of attention heads H; ``F + pad_size``
            must be divisible by H.

    Returns:
        ``(cos_freq, sin_freq)``, each of shape (B, H, T, (F + pad_size) // H).

    Raises:
        ValueError: If ``pad_size`` is negative.
    """
    if pad_size < 0:
        raise ValueError(f"pad_size must be non-negative, got {pad_size}")

    cos_freq = mx.cos(freqs)
    sin_freq = mx.sin(freqs)

    if pad_size > 0:
        # Build padding with an explicit width. Slicing cos_freq[..., :pad_size]
        # would silently produce fewer than pad_size columns whenever the
        # table is narrower than the pad.
        pad_shape = (*cos_freq.shape[:-1], pad_size)
        cos_freq = mx.concatenate(
            [mx.ones(pad_shape, dtype=cos_freq.dtype), cos_freq], axis=-1
        )
        sin_freq = mx.concatenate(
            [mx.zeros(pad_shape, dtype=sin_freq.dtype), sin_freq], axis=-1
        )

    # (B, T, F') -> (B, T, H, F'/H) -> (B, H, T, F'/H)
    b, t = cos_freq.shape[0], cos_freq.shape[1]
    cos_freq = mx.swapaxes(mx.reshape(cos_freq, (b, t, num_attention_heads, -1)), 1, 2)
    sin_freq = mx.swapaxes(mx.reshape(sin_freq, (b, t, num_attention_heads, -1)), 1, 2)

    return cos_freq, sin_freq
|
|
|
|
|
|
def interleaved_freqs_cis(
    freqs: mx.array,
    pad_size: int,
) -> Tuple[mx.array, mx.array]:
    """Prepare cos/sin tables for interleaved RoPE.

    Each angle is duplicated so that both channels of an adjacent pair share
    it ([c0, c0, c1, c1, ...]). When ``pad_size`` is non-zero, identity
    channels (cos = 1, sin = 0) are prepended, using a slice of the widened
    table as the template.

    Args:
        freqs: Angle tensor of shape (B, T, F).
        pad_size: Number of identity channels to prepend.

    Returns:
        ``(cos_freq, sin_freq)``, each of shape (B, T, 2 * F + pad_size) in
        the normal case where ``pad_size <= 2 * F``.
    """

    def _widen(table: mx.array, fill_like) -> mx.array:
        # Duplicate every angle, then optionally prepend identity channels.
        table = mx.repeat(table, 2, axis=-1)
        if pad_size == 0:
            return table
        return mx.concatenate([fill_like(table[:, :, :pad_size]), table], axis=-1)

    return _widen(mx.cos(freqs), mx.ones_like), _widen(mx.sin(freqs), mx.zeros_like)
|
|
|
|
|
|
def precompute_freqs_cis(
    indices_grid: mx.array,
    dim: int,
    theta: float = 10000.0,
    max_pos: Optional[List[int]] = None,
    use_middle_indices_grid: bool = False,
    num_attention_heads: int = 32,
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    double_precision: bool = False,
) -> Tuple[mx.array, mx.array]:
    """Precompute RoPE cos/sin tables for a position grid.

    Args:
        indices_grid: Position grid; axis 1 enumerates positional dimensions.
        dim: RoPE dimension.
        theta: Base of the log-spaced frequency grid.
        max_pos: Normalizer per positional dimension; defaults to
            ``[20, 2048, 2048]``.
        use_middle_indices_grid: Use the midpoint of [start, end] positions.
        num_attention_heads: Number of heads (used by the SPLIT layout).
        rope_type: ``LTXRopeType.SPLIT``, or anything else for interleaved.
        double_precision: Compute the frequency grid in float64 via NumPy.

    Returns:
        ``(cos_freq, sin_freq)``. SPLIT gives shape (B, H, T, dim // 2 // H);
        interleaved gives (B, T, dim) when ``dim >= 2 * n_pos_dims``.
    """
    max_pos = [20, 2048, 2048] if max_pos is None else max_pos

    if double_precision:
        return _precompute_freqs_cis_double_precision(
            indices_grid,
            dim,
            theta,
            max_pos,
            use_middle_indices_grid,
            num_attention_heads,
            rope_type,
        )

    # Positions are kept in float32 instead of the model dtype. Empirically,
    # float32 positions reproduce PyTorch's RoPE values exactly, while
    # bfloat16 positions lose precision that the high-frequency entries (up to
    # ~15708) amplify into cos/sin sign flips.
    positions = indices_grid.astype(mx.float32)
    n_pos_dims = positions.shape[1]

    freq_grid = generate_freq_grid(theta, n_pos_dims, dim)
    freqs = generate_freqs(freq_grid, positions, max_pos, use_middle_indices_grid)

    if rope_type == LTXRopeType.SPLIT:
        # Left-pad the table up to dim // 2 channels.
        return split_freqs_cis(freqs, dim // 2 - freqs.shape[-1], num_attention_heads)

    # Any other rope_type uses the interleaved layout.
    return interleaved_freqs_cis(freqs, dim % (2 * n_pos_dims))
|
|
|
|
|
|
def _precompute_freqs_cis_double_precision(
    indices_grid: mx.array,
    dim: int,
    theta: float,
    max_pos: List[int],
    use_middle_indices_grid: bool,
    num_attention_heads: int,
    rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
    """Compute RoPE cos/sin tables using a float64 frequency grid.

    Matches PyTorch's ``generate_freq_grid_np``: the log-spaced frequency grid
    is evaluated with NumPy in float64 and only then converted to float32.

    Positions are cast to float32, not kept in the model dtype such as
    bfloat16. Lower-precision fractional positions, once multiplied by the
    large high-frequency grid entries, can flip the sign of cos/sin.

    After the grid is built, position handling is delegated to
    ``generate_freqs``. That covers midpoint averaging, normalization by
    ``max_pos``, mapping to [-1, 1] and the outer product, so both precision
    paths share one implementation.

    Args:
        indices_grid: Position grid (B, n_pos_dims, T) or (B, n_pos_dims, T, k).
            With ``use_middle_indices_grid``, k must be 2 ([start, end]).
        dim: RoPE dimension.
        theta: Base of the log-spaced frequency grid.
        max_pos: Normalizer per positional dimension; needs n_pos_dims entries.
        use_middle_indices_grid: Use the midpoint of each [start, end] range.
        num_attention_heads: Number of heads (SPLIT layout only).
        rope_type: ``LTXRopeType.SPLIT``, or anything else for interleaved.

    Returns:
        ``(cos_freq, sin_freq)``. SPLIT gives shape (B, H, T, dim // 2 // H);
        interleaved gives (B, T, dim) when ``dim >= 2 * n_pos_dims``.

    Raises:
        AssertionError: Raised by ``generate_freqs`` when the grid shape does
            not fit ``use_middle_indices_grid``, or when ``len(max_pos)`` does
            not match the number of positional dimensions.
    """
    import numpy as np

    indices_grid_f32 = indices_grid.astype(mx.float32)
    n_elem = 2 * indices_grid_f32.shape[1]

    # Float64 frequency grid, computed the same way as generate_freq_grid_np.
    log_theta = np.log(theta)
    num_indices = dim // n_elem or 1
    pow_indices = np.power(
        theta,
        np.linspace(
            np.log(1.0) / log_theta,  # == 0.0
            np.log(theta) / log_theta,  # == 1.0
            num_indices,
            dtype=np.float64,
        ),
    )
    # Drop to float32 only now (PyTorch: torch.tensor(..., dtype=torch.float32)).
    freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)

    # (B, T, num_indices * n_pos_dims)
    freqs = generate_freqs(freq_indices, indices_grid_f32, max_pos, use_middle_indices_grid)

    cos_freq = mx.cos(freqs)
    sin_freq = mx.sin(freqs)

    if rope_type == LTXRopeType.SPLIT:
        pad_size = dim // 2 - cos_freq.shape[-1]

        # Prepend identity-rotation channels (cos = 1, sin = 0).
        if pad_size > 0:
            cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
            sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
            cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
            sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)

        # (B, T, dim//2) -> (B, T, H, dim//2//H) -> (B, H, T, dim//2//H)
        b, t = cos_freq.shape[0], cos_freq.shape[1]
        cos_freq = mx.swapaxes(mx.reshape(cos_freq, (b, t, num_attention_heads, -1)), 1, 2)
        sin_freq = mx.swapaxes(mx.reshape(sin_freq, (b, t, num_attention_heads, -1)), 1, 2)
    else:
        # Interleaved: both channels of an adjacent pair share one angle.
        cos_freq = mx.repeat(cos_freq, 2, axis=-1)
        sin_freq = mx.repeat(sin_freq, 2, axis=-1)

        pad_size = dim % n_elem
        if pad_size > 0:
            cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
            sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
            cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
            sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)

    return cos_freq, sin_freq
|