initial commit (LTX-2)
This commit is contained in:
508
mlx_video/models/ltx/rope.py
Normal file
508
mlx_video/models/ltx/rope.py
Normal file
@@ -0,0 +1,508 @@
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
|
||||
|
||||
def apply_rotary_emb(
    input_tensor: mx.array,
    freqs_cis: Tuple[mx.array, mx.array],
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> mx.array:
    """Rotate ``input_tensor`` using the RoPE layout selected by ``rope_type``.

    Args:
        input_tensor: Tensor to rotate.
        freqs_cis: Pair whose element 0 holds the cosine table and element 1
            the sine table.
        rope_type: ``LTXRopeType.INTERLEAVED`` (adjacent channels form a pair)
            or ``LTXRopeType.SPLIT`` (first half paired with second half).

    Returns:
        The rotated tensor produced by the selected implementation.

    Raises:
        ValueError: If ``rope_type`` is neither INTERLEAVED nor SPLIT.
    """
    # Pick the implementation first; the frequency pair is only indexed once
    # a valid layout has been selected.
    if rope_type == LTXRopeType.INTERLEAVED:
        rotate = apply_interleaved_rotary_emb
    elif rope_type == LTXRopeType.SPLIT:
        rotate = apply_split_rotary_emb
    else:
        raise ValueError(f"Invalid rope type: {rope_type}")

    return rotate(input_tensor, freqs_cis[0], freqs_cis[1])
|
||||
|
||||
|
||||
def apply_interleaved_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply RoPE where adjacent channels ``(x0, x1), (x2, x3), ...`` form pairs.

    Computes ``x * cos + rot(x) * sin`` with
    ``rot([x0, x1, x2, x3, ...]) = [-x1, x0, -x3, x2, ...]``.

    Args:
        input_tensor: Tensor of shape ``(..., dim)``; ``dim`` must be even.
        cos_freqs: Cosine table, broadcastable against ``input_tensor``.
        sin_freqs: Sine table, broadcastable against ``input_tensor``.

    Returns:
        Rotated tensor with the same shape as ``input_tensor``.
    """
    # De-interleave the last axis into its even- and odd-indexed channels.
    evens = input_tensor[..., 0::2]
    odds = input_tensor[..., 1::2]

    # Re-interleave as (-odd, even) pairs and flatten back to (..., dim).
    rotated = mx.reshape(mx.stack([-odds, evens], axis=-1), input_tensor.shape)

    return input_tensor * cos_freqs + rotated * sin_freqs
|
||||
|
||||
|
||||
def rotate_half_interleaved(x: mx.array) -> mx.array:
    """Return the 90-degree pair rotation used by interleaved RoPE.

    Maps ``[x0, x1, x2, x3, ...]`` to ``[-x1, x0, -x3, x2, ...]``.

    Args:
        x: Tensor of shape ``(..., dim)`` with an even ``dim``.

    Returns:
        Tensor of the same shape with every adjacent pair ``(a, b)``
        replaced by ``(-b, a)``.
    """
    # View the last axis as (dim // 2) pairs of two channels.
    pairs = mx.reshape(x, x.shape[:-1] + (x.shape[-1] // 2, 2))
    first = pairs[..., 0]
    second = pairs[..., 1]

    # Swap within each pair, negate the new leading element, then flatten.
    swapped = mx.stack([-second, first], axis=-1)
    return mx.reshape(swapped, x.shape)
|
||||
|
||||
def apply_rotary_emb_1d(
    q: mx.array,
    k: mx.array,
    freqs_cis: mx.array,
) -> Tuple[mx.array, mx.array]:
    """Apply interleaved 1D RoPE to a query/key pair with one shared table.

    Args:
        q: Query tensor.
        k: Key tensor.
        freqs_cis: Precomputed table whose last axis holds ``[cos, sin]``.
            Assumed layout (from the original comments, not enforced here):
            ``(1, seq_len, num_heads, head_dim, 2)`` with ``q``/``k`` shaped
            ``(batch, seq_len, num_heads, head_dim)``.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos_table = freqs_cis[..., 0]
    sin_table = freqs_cis[..., 1]

    # Same rotation for both tensors: x * cos + rot(x) * sin.
    q_rotated, k_rotated = (
        tensor * cos_table + rotate_half_interleaved(tensor) * sin_table
        for tensor in (q, k)
    )
    return q_rotated, k_rotated
|
||||
|
||||
|
||||
def apply_split_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply RoPE where channel ``i`` is paired with channel ``i + dim // 2``.

    With the last axis split into ``first`` and ``second`` halves the result
    is ``[first * cos - second * sin, second * cos + first * sin]``.

    Args:
        input_tensor: Either a 4D tensor ``(B, H, T, D)`` or, when
            ``cos_freqs`` is 4D, a packed tensor ``(B, T, H * D)`` that is
            unpacked to ``(B, H, T, D)`` for the rotation and packed again
            afterwards.
        cos_freqs: Cosine frequencies of shape ``(B, H, T, D // 2)``. A batch
            axis of size 1 is broadcast over the input batch.
        sin_freqs: Sine frequencies, same shape as ``cos_freqs``.

    Returns:
        Rotated tensor: ``(B, H, T, D)`` for 4D input, ``(B, T, H * D)`` for
        packed input.
    """
    needs_reshape = input_tensor.ndim != 4 and cos_freqs.ndim == 4

    if needs_reshape:
        # Only the head count and sequence length come from the frequency
        # table. The batch size is inferred from the input, so batch-1 tables
        # broadcast here exactly as they do for 4D inputs.
        _, h, t, _ = cos_freqs.shape
        head_dim = input_tensor.shape[-1] // h
        # (B, T, H*D) -> (B, T, H, D) -> (B, H, T, D)
        input_tensor = mx.reshape(input_tensor, (-1, t, h, head_dim))
        input_tensor = mx.swapaxes(input_tensor, 1, 2)

    # Split the channel axis into two halves: (..., D) -> (..., 2, D // 2).
    half = input_tensor.shape[-1] // 2
    split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, half))
    first_half = split_input[..., 0, :]
    second_half = split_input[..., 1, :]

    # 2D rotation of each (first, second) channel pair.
    output_first = first_half * cos_freqs - sin_freqs * second_half
    output_second = second_half * cos_freqs + sin_freqs * first_half

    # Re-join the halves: (..., 2, D // 2) -> (..., D). The shape is taken
    # from the result so a broadcast batch axis is kept.
    output = mx.stack([output_first, output_second], axis=-2)
    output = mx.reshape(output, output.shape[:-2] + (-1,))

    if needs_reshape:
        # (B, H, T, D) -> (B, T, H, D) -> (B, T, H*D)
        b, h, t, d = output.shape
        output = mx.swapaxes(output, 1, 2)
        output = mx.reshape(output, (b, t, h * d))

    return output
|
||||
|
||||
|
||||
def generate_freq_grid(
    positional_embedding_theta: float,
    positional_embedding_max_pos_count: int,
    inner_dim: int,
) -> mx.array:
    """Build the log-spaced RoPE frequency multipliers.

    The result is ``theta ** e * (pi / 2)`` for exponents ``e`` evenly spaced
    from 0 to 1, i.e. values running geometrically from ``pi / 2`` up to
    ``theta * pi / 2``.

    Args:
        positional_embedding_theta: Base ``theta`` of the geometric progression.
        positional_embedding_max_pos_count: Number of positional axes; each
            axis consumes two channels per frequency.
        inner_dim: Channel count to distribute across the positional axes.

    Returns:
        1D array with ``inner_dim // (2 * positional_embedding_max_pos_count)``
        entries (a single entry when that quotient is zero).
    """
    base = positional_embedding_theta
    channels_per_freq = 2 * positional_embedding_max_pos_count

    # Exponent range expressed in log-base-theta: log_theta(1) .. log_theta(theta).
    log_base = math.log(base)
    exponent_lo = math.log(1.0) / log_base
    exponent_hi = math.log(base) / log_base

    # Always keep at least one frequency, even when inner_dim is tiny.
    count = inner_dim // channels_per_freq or 1

    exponents = mx.linspace(exponent_lo, exponent_hi, count)
    return mx.power(base, exponents) * (math.pi / 2)
|
||||
|
||||
|
||||
def get_fractional_positions(
    indices_grid: mx.array,
    max_pos: List[int],
) -> mx.array:
    """Normalise each positional axis by its maximum position.

    Args:
        indices_grid: Position indices of shape ``(B, n_pos_dims, ...)``.
        max_pos: One divisor per positional axis (``len == n_pos_dims``).

    Returns:
        ``indices_grid[:, i] / max_pos[i]`` for every axis ``i``, stacked on a
        new trailing axis: shape ``(B, ..., n_pos_dims)``. No rescaling to
        ``[-1, 1]`` is done here.
    """
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), (
        f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
    )

    per_axis = [indices_grid[:, axis] / max_pos[axis] for axis in range(n_pos_dims)]
    return mx.stack(per_axis, axis=-1)
|
||||
|
||||
|
||||
def generate_freqs(
    indices: mx.array,
    indices_grid: mx.array,
    max_pos: List[int],
    use_middle_indices_grid: bool,
) -> mx.array:
    """Combine per-axis positions with the frequency multipliers.

    Args:
        indices: 1D frequency multipliers (see ``generate_freq_grid``).
        indices_grid: Position grid, either ``(B, n_dims, T)`` or
            ``(B, n_dims, T, 2)`` where the last axis is ``[start, end]``.
        max_pos: Maximum position per positional axis.
        use_middle_indices_grid: If True, use the midpoint of ``[start, end]``
            (requires the 4D layout); otherwise a 4D grid contributes only
            its ``start`` entry.

    Returns:
        Angles of shape ``(B, T, n_indices * n_dims)``, ordered
        frequency-major (all axes for frequency 0, then frequency 1, ...).
    """
    if use_middle_indices_grid:
        # Grid must carry [start, end] on its last axis.
        assert len(indices_grid.shape) == 4
        assert indices_grid.shape[-1] == 2
        indices_grid = (indices_grid[..., 0] + indices_grid[..., 1]) / 2.0
    elif len(indices_grid.shape) == 4:
        indices_grid = indices_grid[..., 0]

    # Fractions of max_pos, then affine-mapped with x -> 2x - 1: (B, T, n_dims).
    positions = get_fractional_positions(indices_grid, max_pos) * 2 - 1

    # Outer product: (B, T, n_dims, 1) * (1, 1, 1, n_indices).
    angles = mx.expand_dims(positions, axis=-1) * mx.expand_dims(
        indices, axis=(0, 1, 2)
    )

    # (B, T, n_dims, n_indices) -> (B, T, n_indices, n_dims) -> (B, T, n_indices * n_dims)
    angles = mx.swapaxes(angles, -1, -2)
    return mx.reshape(angles, angles.shape[:-2] + (-1,))
|
||||
|
||||
|
||||
def split_freqs_cis(
    freqs: mx.array,
    pad_size: int,
    num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
    """Prepare per-head cos/sin tables for split RoPE.

    Args:
        freqs: Angle tensor of shape ``(B, T, F)``.
        pad_size: Number of identity entries (cos=1, sin=0) to prepend on the
            last axis so the padded width divides across the heads. Values
            ``<= 0`` add no padding, matching the float64 code path.
        num_attention_heads: Number of heads ``H`` the last axis is split into.

    Returns:
        ``(cos_freq, sin_freq)``, each of shape ``(B, H, T, (F + pad) // H)``.
    """
    cos_freq = mx.cos(freqs)
    sin_freq = mx.sin(freqs)

    if pad_size > 0:
        # Build the padding with an explicit shape. Slicing the existing table
        # would silently cap the padding at the table's own width.
        pad_shape = tuple(cos_freq.shape[:-1]) + (pad_size,)
        cos_padding = mx.ones(pad_shape, dtype=cos_freq.dtype)
        sin_padding = mx.zeros(pad_shape, dtype=sin_freq.dtype)
        cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
        sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)

    # (B, T, F') -> (B, T, H, F' // H) -> (B, H, T, F' // H)
    b, t = cos_freq.shape[0], cos_freq.shape[1]
    cos_freq = mx.swapaxes(mx.reshape(cos_freq, (b, t, num_attention_heads, -1)), 1, 2)
    sin_freq = mx.swapaxes(mx.reshape(sin_freq, (b, t, num_attention_heads, -1)), 1, 2)

    return cos_freq, sin_freq
|
||||
|
||||
|
||||
def interleaved_freqs_cis(
    freqs: mx.array,
    pad_size: int,
) -> Tuple[mx.array, mx.array]:
    """Prepare cos/sin tables for interleaved RoPE.

    Args:
        freqs: Angle tensor of shape ``(B, T, dim // 2)``.
        pad_size: Number of identity entries (cos=1, sin=0) to prepend on the
            last axis. Values ``<= 0`` add no padding, matching the float64
            code path.

    Returns:
        ``(cos_freq, sin_freq)``, each of shape ``(B, T, pad + 2 * F)`` where
        ``F`` is the last-axis size of ``freqs``.
    """
    # Duplicate every entry so both channels of a pair share one angle:
    # [c0, c1, ...] -> [c0, c0, c1, c1, ...].
    cos_freq = mx.repeat(mx.cos(freqs), 2, axis=-1)
    sin_freq = mx.repeat(mx.sin(freqs), 2, axis=-1)

    if pad_size > 0:
        # Build the padding with an explicit shape. Slicing the existing table
        # would silently cap the padding at the table's own width.
        pad_shape = tuple(cos_freq.shape[:-1]) + (pad_size,)
        cos_padding = mx.ones(pad_shape, dtype=cos_freq.dtype)
        sin_padding = mx.zeros(pad_shape, dtype=sin_freq.dtype)
        cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
        sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)

    return cos_freq, sin_freq
|
||||
|
||||
|
||||
def precompute_freqs_cis(
    indices_grid: mx.array,
    dim: int,
    theta: float = 10000.0,
    max_pos: Optional[List[int]] = None,
    use_middle_indices_grid: bool = False,
    num_attention_heads: int = 32,
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    double_precision: bool = False,
) -> Tuple[mx.array, mx.array]:
    """Precompute the RoPE cosine/sine tables for a position grid.

    Args:
        indices_grid: Position indices; axis 1 enumerates the positional axes.
        dim: Channel count the tables must cover.
        theta: Base of the geometric frequency progression.
        max_pos: Maximum position per positional axis. Defaults to
            ``[20, 2048, 2048]`` when ``None``.
        use_middle_indices_grid: Use the midpoint of ``[start, end]`` index
            ranges instead of the start.
        num_attention_heads: Head count; only used by the SPLIT layout.
        rope_type: ``LTXRopeType.SPLIT`` selects the split layout; any other
            value takes the interleaved path.
        double_precision: Compute in float64 via numpy and convert the result
            back to MLX, instead of computing with MLX ops.

    Returns:
        ``(cos_freq, sin_freq)`` laid out for the chosen ``rope_type``.
    """
    resolved_max_pos = [20, 2048, 2048] if max_pos is None else max_pos

    if double_precision:
        # MLX GPU doesn't support float64, so the high-precision path is a
        # separate numpy implementation.
        return _precompute_freqs_cis_double_precision(
            indices_grid,
            dim,
            theta,
            resolved_max_pos,
            use_middle_indices_grid,
            num_attention_heads,
            rope_type,
        )

    n_pos_dims = indices_grid.shape[1]
    freq_indices = generate_freq_grid(theta, n_pos_dims, dim)
    freqs = generate_freqs(
        freq_indices, indices_grid, resolved_max_pos, use_middle_indices_grid
    )

    if rope_type == LTXRopeType.SPLIT:
        # Split layout needs dim // 2 angles; pad whatever is missing.
        missing = dim // 2 - freqs.shape[-1]
        cos_freq, sin_freq = split_freqs_cis(freqs, missing, num_attention_heads)
    else:
        # Interleaved layout: channels left over after dividing dim by
        # (2 * n_pos_dims) become padding.
        leftover = dim % (2 * n_pos_dims)
        cos_freq, sin_freq = interleaved_freqs_cis(freqs, leftover)

    return cos_freq, sin_freq
|
||||
|
||||
|
||||
def _precompute_freqs_cis_double_precision(
    indices_grid: mx.array,
    dim: int,
    theta: float,
    max_pos: List[int],
    use_middle_indices_grid: bool,
    num_attention_heads: int,
    rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
    """Float64 (numpy) variant of the RoPE table computation.

    MLX GPU doesn't support float64, so every step runs in numpy and only the
    final tables are converted back to MLX as float32.

    Args:
        indices_grid: Position grid, ``(B, n_dims, T)`` or ``(B, n_dims, T, 2)``
            with ``[start, end]`` on the last axis.
        dim: Channel count the tables must cover.
        theta: Base of the geometric frequency progression.
        max_pos: Maximum position per positional axis.
        use_middle_indices_grid: Use the midpoint of ``[start, end]`` instead
            of the start (requires the 4D layout).
        num_attention_heads: Head count; only used by the SPLIT layout.
        rope_type: ``LTXRopeType.SPLIT`` selects the split layout; any other
            value takes the interleaved path.

    Returns:
        ``(cos_freq, sin_freq)`` as float32 MLX arrays.
    """

    def _prepend_identity(cos_part, sin_part, width):
        # cos=1 / sin=0 leaves the padded channels unrotated.
        if width > 0:
            ones = np.ones((*cos_part.shape[:-1], width), dtype=np.float64)
            zeros = np.zeros((*sin_part.shape[:-1], width), dtype=np.float64)
            cos_part = np.concatenate([ones, cos_part], axis=-1)
            sin_part = np.concatenate([zeros, sin_part], axis=-1)
        return cos_part, sin_part

    grid = np.array(indices_grid).astype(np.float64)

    n_pos_dims = grid.shape[1]
    n_elem = 2 * n_pos_dims

    # Frequency multipliers: theta ** linspace(0, 1) scaled by pi / 2.
    log_theta = math.log(theta)
    exponents = np.linspace(
        math.log(1.0) / log_theta,
        math.log(theta) / log_theta,
        dim // n_elem or 1,  # keep at least one frequency
    )
    multipliers = np.power(theta, exponents) * (math.pi / 2)

    # Collapse an optional trailing [start, end] axis to one position value.
    if use_middle_indices_grid:
        assert len(grid.shape) == 4
        assert grid.shape[-1] == 2
        grid = (grid[..., 0] + grid[..., 1]) / 2.0
    elif len(grid.shape) == 4:
        grid = grid[..., 0]
    # grid is now (B, n_dims, T)

    # Fractions of max_pos per axis, moved to the trailing axis: (B, T, n_dims).
    fractional = np.concatenate(
        [
            np.expand_dims(grid[:, axis, :] / max_pos[axis], axis=-1)
            for axis in range(n_pos_dims)
        ],
        axis=-1,
    )

    # Affine map x -> 2x - 1 before multiplying by the frequencies.
    scaled = fractional * 2 - 1

    # Outer product, then flatten frequency-major:
    # (B, T, n_dims, K) -> (B, T, K, n_dims) -> (B, T, K * n_dims)
    angles = np.expand_dims(scaled, axis=-1) * multipliers.reshape(1, 1, 1, -1)
    angles = np.swapaxes(angles, -1, -2)
    angles = angles.reshape(angles.shape[:-2] + (-1,))

    cos_freq = np.cos(angles)
    sin_freq = np.sin(angles)

    if rope_type == LTXRopeType.SPLIT:
        # Pad up to dim // 2 angles, then split across heads:
        # (B, T, dim // 2) -> (B, H, T, dim // 2 // H)
        cos_freq, sin_freq = _prepend_identity(
            cos_freq, sin_freq, dim // 2 - cos_freq.shape[-1]
        )
        b, t = cos_freq.shape[0], cos_freq.shape[1]
        cos_freq = np.swapaxes(cos_freq.reshape(b, t, num_attention_heads, -1), 1, 2)
        sin_freq = np.swapaxes(sin_freq.reshape(b, t, num_attention_heads, -1), 1, 2)
    else:
        # Interleaved: duplicate each angle for its channel pair, then pad the
        # channels left over from dim / n_elem.
        cos_freq = np.repeat(cos_freq, 2, axis=-1)
        sin_freq = np.repeat(sin_freq, 2, axis=-1)
        cos_freq, sin_freq = _prepend_identity(cos_freq, sin_freq, dim % n_elem)

    # Convert back to MLX (float32 for GPU compatibility).
    return (
        mx.array(cos_freq.astype(np.float32)),
        mx.array(sin_freq.astype(np.float32)),
    )
|
||||
Reference in New Issue
Block a user