import math from typing import Callable, List, Optional, Tuple import mlx.core as mx import numpy as np from mlx_video.models.ltx.config import LTXRopeType def apply_rotary_emb( input_tensor: mx.array, freqs_cis: Tuple[mx.array, mx.array], rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, ) -> mx.array: """Apply rotary position embeddings to input tensor. Args: input_tensor: Input tensor to apply RoPE to freqs_cis: Tuple of (cos_freqs, sin_freqs) rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT) Returns: Tensor with rotary embeddings applied """ if rope_type == LTXRopeType.INTERLEAVED: return apply_interleaved_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1]) elif rope_type == LTXRopeType.SPLIT: return apply_split_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1]) else: raise ValueError(f"Invalid rope type: {rope_type}") def apply_interleaved_rotary_emb( input_tensor: mx.array, cos_freqs: mx.array, sin_freqs: mx.array, ) -> mx.array: """Apply interleaved rotary embeddings. Pairs adjacent dimensions and applies rotation. Pattern: [x0, x1, x2, x3, ...] -> rotate pairs (x0,x1), (x2,x3), ... Args: input_tensor: Input tensor of shape (..., dim) cos_freqs: Cosine frequencies sin_freqs: Sine frequencies Returns: Tensor with interleaved rotary embeddings applied """ # Compute in float32 for better precision input_dtype = input_tensor.dtype input_tensor = input_tensor.astype(mx.float32) cos_freqs = cos_freqs.astype(mx.float32) sin_freqs = sin_freqs.astype(mx.float32) # Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2) shape = input_tensor.shape input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2)) # Extract pairs t1 = input_tensor[..., 0] # Even indices t2 = input_tensor[..., 1] # Odd indices # Apply rotation: (-t2, t1) pattern t_rot = mx.stack([-t2, t1], axis=-1) # Flatten back: (..., dim/2, 2) -> (..., dim) input_tensor = mx.reshape(input_tensor, shape) t_rot = mx.reshape(t_rot, shape) # Apply rotary embeddings out = input_tensor * cos_freqs + t_rot * sin_freqs return out.astype(input_dtype) def rotate_half_interleaved(x: mx.array) -> mx.array: """Rotate for interleaved RoPE: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]. PyTorch equivalent: t_dup = rearrange(x, "... (d r) -> ... d r", r=2) t1, t2 = t_dup.unbind(dim=-1) t_dup = torch.stack((-t2, t1), dim=-1) return rearrange(t_dup, "... d r -> ... (d r)") """ # x: (..., dim) where dim is even x_even = x[..., 0::2] # [x0, x2, x4, ...] x_odd = x[..., 1::2] # [x1, x3, x5, ...] # Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...] rotated = mx.stack([-x_odd, x_even], axis=-1) return mx.reshape(rotated, x.shape) def apply_rotary_emb_1d( q: mx.array, k: mx.array, freqs_cis: mx.array, ) -> Tuple[mx.array, mx.array]: """Apply 1D rotary embeddings using precomputed frequencies (interleaved).""" # freqs_cis: (1, seq_len, num_heads, head_dim, 2) where [..., 0] = cos, [..., 1] = sin cos = freqs_cis[..., 0] # (1, seq_len, num_heads, head_dim) sin = freqs_cis[..., 1] # q, k: (batch, seq_len, num_heads, head_dim) # Interleaved RoPE: pairs of adjacent dims rotate together q_r = q * cos + rotate_half_interleaved(q) * sin k_r = k * cos + rotate_half_interleaved(k) * sin return q_r, k_r def apply_split_rotary_emb( input_tensor: mx.array, cos_freqs: mx.array, sin_freqs: mx.array, ) -> mx.array: """Apply split rotary embeddings. Splits dimensions into two halves and applies rotation. Pattern: split into first half and second half Args: input_tensor: Input tensor cos_freqs: Cosine frequencies of shape (B, H, T, D//2) sin_freqs: Sine frequencies of shape (B, H, T, D//2) Returns: Tensor with split rotary embeddings applied """ needs_reshape = False original_shape = input_tensor.shape # Handle dimension mismatch if input_tensor.ndim != 4 and cos_freqs.ndim == 4: b, h, t, _ = cos_freqs.shape # Reshape from (B, T, H*D) to (B, H, T, D) input_tensor = mx.reshape(input_tensor, (b, t, h, -1)) input_tensor = mx.swapaxes(input_tensor, 1, 2) needs_reshape = True # Split into two halves: (..., dim) -> (..., 2, dim//2) dim = input_tensor.shape[-1] split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2)) # Get first and second halves first_half = split_input[..., 0, :] # (..., dim//2) second_half = split_input[..., 1, :] # (..., dim//2) # Apply cosine to both halves output_first = first_half * cos_freqs output_second = second_half * cos_freqs # Apply sine cross-terms (addcmul pattern) output_first = output_first - sin_freqs * second_half output_second = output_second + sin_freqs * first_half # Stack back together output = mx.stack([output_first, output_second], axis=-2) # Flatten: (..., 2, dim//2) -> (..., dim) output = mx.reshape(output, input_tensor.shape) if needs_reshape: # Reshape back: (B, H, T, D) -> (B, T, H*D) b, h, t, d = output.shape output = mx.swapaxes(output, 1, 2) output = mx.reshape(output, (b, t, h * d)) return output def generate_freq_grid( positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int, ) -> mx.array: """Generate frequency grid for RoPE. Args: positional_embedding_theta: Base theta value positional_embedding_max_pos_count: Number of position dimensions inner_dim: Inner dimension of the model Returns: Frequency indices tensor """ theta = positional_embedding_theta start = 1.0 end = theta n_elem = 2 * positional_embedding_max_pos_count # Compute logarithmic spacing log_start = math.log(start) / math.log(theta) log_end = math.log(end) / math.log(theta) num_indices = inner_dim // n_elem if num_indices == 0: num_indices = 1 # Create linearly spaced values in log space lin_space = mx.linspace(log_start, log_end, num_indices) # Compute power indices pow_indices = mx.power(theta, lin_space) # Scale by pi/2 return pow_indices * (math.pi / 2) def get_fractional_positions( indices_grid: mx.array, max_pos: List[int], ) -> mx.array: """Convert indices to fractional positions. Args: indices_grid: Grid of position indices of shape (B, n_pos_dims, ...) max_pos: Maximum position for each dimension Returns: Fractional positions in range [-1, 1] after scaling """ n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), ( f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" ) # Divide each dimension by its max position fractional_positions = [] for i in range(n_pos_dims): frac = indices_grid[:, i] / max_pos[i] fractional_positions.append(frac) return mx.stack(fractional_positions, axis=-1) def generate_freqs( indices: mx.array, indices_grid: mx.array, max_pos: List[int], use_middle_indices_grid: bool, ) -> mx.array: """Generate frequencies from indices and position grid. Args: indices: Frequency indices indices_grid: Position indices grid max_pos: Maximum positions per dimension use_middle_indices_grid: Whether to use middle of index ranges Returns: Frequency tensor """ # Handle middle indices grid if use_middle_indices_grid: # indices_grid shape: (B, n_dims, T, 2) where last dim is [start, end] assert len(indices_grid.shape) == 4 assert indices_grid.shape[-1] == 2 indices_grid_start = indices_grid[..., 0] indices_grid_end = indices_grid[..., 1] indices_grid = (indices_grid_start + indices_grid_end) / 2.0 elif len(indices_grid.shape) == 4: indices_grid = indices_grid[..., 0] # Get fractional positions fractional_positions = get_fractional_positions(indices_grid, max_pos) # Compute frequencies # fractional_positions: (B, T, n_dims) # indices: (inner_dim // n_elem,) # Result: (B, T, inner_dim // n_elem * n_dims) # Scale fractional positions to [-1, 1] scaled_positions = fractional_positions * 2 - 1 # (B, T, n_dims) # Outer product with indices # (B, T, n_dims, 1) * (1, 1, 1, n_indices) -> (B, T, n_dims, n_indices) freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.expand_dims( mx.expand_dims(mx.expand_dims(indices, axis=0), axis=0), axis=0 ) # Transpose and flatten: (B, T, n_dims, n_indices) -> (B, T, n_indices * n_dims) freqs = mx.swapaxes(freqs, -1, -2) # (B, T, n_indices, n_dims) freqs = mx.reshape(freqs, freqs.shape[:-2] + (-1,)) return freqs def split_freqs_cis( freqs: mx.array, pad_size: int, num_attention_heads: int, ) -> Tuple[mx.array, mx.array]: """Prepare cos/sin frequencies for split RoPE. Args: freqs: Frequency tensor pad_size: Padding size for dimension alignment num_attention_heads: Number of attention heads Returns: Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2) """ cos_freq = mx.cos(freqs) sin_freq = mx.sin(freqs) # Add padding if needed if pad_size != 0: cos_padding = mx.ones_like(cos_freq[:, :, :pad_size]) sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size]) cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1) sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1) # Reshape for multi-head attention b, t = cos_freq.shape[0], cos_freq.shape[1] cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1)) sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1)) # Swap axes: (B, T, H, D//2) -> (B, H, T, D//2) cos_freq = mx.swapaxes(cos_freq, 1, 2) sin_freq = mx.swapaxes(sin_freq, 1, 2) return cos_freq, sin_freq def interleaved_freqs_cis( freqs: mx.array, pad_size: int, ) -> Tuple[mx.array, mx.array]: """Prepare cos/sin frequencies for interleaved RoPE. Args: freqs: Frequency tensor of shape (B, T, dim//2) pad_size: Padding size for dimension alignment Returns: Tuple of (cos_freq, sin_freq) with shape (B, T, dim) """ # Compute cos and sin cos_freq = mx.cos(freqs) sin_freq = mx.sin(freqs) # Repeat interleave: each element repeated twice # (B, T, D) -> (B, T, 2*D) with pattern [c0, c0, c1, c1, ...] cos_freq = mx.repeat(cos_freq, 2, axis=-1) sin_freq = mx.repeat(sin_freq, 2, axis=-1) # Add padding if needed if pad_size != 0: cos_padding = mx.ones_like(cos_freq[:, :, :pad_size]) sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size]) cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1) sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1) return cos_freq, sin_freq def precompute_freqs_cis( indices_grid: mx.array, dim: int, theta: float = 10000.0, max_pos: Optional[List[int]] = None, use_middle_indices_grid: bool = False, num_attention_heads: int = 32, rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, double_precision: bool = False, ) -> Tuple[mx.array, mx.array]: """Precompute RoPE frequencies. Args: indices_grid: Position indices grid dim: Dimension for RoPE theta: Base theta value for frequency computation max_pos: Maximum position per dimension use_middle_indices_grid: Whether to use middle indices num_attention_heads: Number of attention heads rope_type: Type of RoPE (INTERLEAVED or SPLIT) double_precision: If True, compute frequencies in float64 for higher precision Returns: Tuple of (cos_freq, sin_freq) tensors """ if max_pos is None: max_pos = [20, 2048, 2048] if double_precision: return _precompute_freqs_cis_double_precision( indices_grid, dim, theta, max_pos, use_middle_indices_grid, num_attention_heads, rope_type ) # Generate frequency indices indices = generate_freq_grid(theta, indices_grid.shape[1], dim) # Generate frequencies freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) # Prepare cos/sin based on rope type if rope_type == LTXRopeType.SPLIT: expected_freqs = dim // 2 current_freqs = freqs.shape[-1] pad_size = expected_freqs - current_freqs cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) else: # Interleaved n_elem = 2 * indices_grid.shape[1] cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) return cos_freq, sin_freq def _precompute_freqs_cis_double_precision( indices_grid: mx.array, dim: int, theta: float, max_pos: List[int], use_middle_indices_grid: bool, num_attention_heads: int, rope_type: LTXRopeType, ) -> Tuple[mx.array, mx.array]: # Convert to numpy float64 indices_grid_np = np.array(indices_grid).astype(np.float64) # Generate frequency indices in float64 n_pos_dims = indices_grid_np.shape[1] n_elem = 2 * n_pos_dims # Compute log-spaced frequencies log_start = math.log(1.0) / math.log(theta) log_end = math.log(theta) / math.log(theta) num_indices = dim // n_elem if num_indices == 0: num_indices = 1 lin_space = np.linspace(log_start, log_end, num_indices) indices_np = np.power(theta, lin_space) * (math.pi / 2) # Handle middle indices grid # Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise if use_middle_indices_grid: assert len(indices_grid_np.shape) == 4 assert indices_grid_np.shape[-1] == 2 indices_grid_start = indices_grid_np[..., 0] indices_grid_end = indices_grid_np[..., 1] indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0 elif len(indices_grid_np.shape) == 4: indices_grid_np = indices_grid_np[..., 0] # After handling: indices_grid_np shape is (B, n_dims, T) # Get fractional positions: (B, n_dims, T) -> (B, T, n_dims) batch_size = indices_grid_np.shape[0] seq_len = indices_grid_np.shape[2] fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64) for i in range(n_pos_dims): # indices_grid_np[:, i, :] has shape (B, T) fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i] # Scale to [-1, 1] scaled_positions = fractional_positions * 2 - 1 # Compute frequencies: outer product freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1) freqs = np.swapaxes(freqs, -1, -2) freqs = freqs.reshape(freqs.shape[:-2] + (-1,)) # Compute cos/sin in float64 cos_freq = np.cos(freqs) sin_freq = np.sin(freqs) # Prepare based on rope type if rope_type == LTXRopeType.SPLIT: expected_freqs = dim // 2 current_freqs = cos_freq.shape[-1] pad_size = expected_freqs - current_freqs # Add padding if pad_size > 0: cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64) sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64) cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) # Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H) b, t = cos_freq.shape[0], cos_freq.shape[1] cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) cos_freq = np.swapaxes(cos_freq, 1, 2) sin_freq = np.swapaxes(sin_freq, 1, 2) else: # Interleaved cos_freq = np.repeat(cos_freq, 2, axis=-1) sin_freq = np.repeat(sin_freq, 2, axis=-1) pad_size = dim % n_elem if pad_size > 0: cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64) sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64) cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) # Convert back to MLX (float32 for GPU compatibility) cos_freq = mx.array(cos_freq.astype(np.float32)) sin_freq = mx.array(sin_freq.astype(np.float32)) return cos_freq, sin_freq