import math
from typing import List, Optional, Tuple

import mlx.core as mx

from mlx_video.models.ltx_2.config import LTXRopeType


def apply_rotary_emb(
    input_tensor: mx.array,
    freqs_cis: Tuple[mx.array, mx.array],
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> mx.array:
    """Apply rotary position embeddings to input tensor.

    Dispatches to the interleaved or split implementation based on
    ``rope_type``.

    Args:
        input_tensor: Input tensor to apply RoPE to
        freqs_cis: Tuple of (cos_freqs, sin_freqs)
        rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT)

    Returns:
        Tensor with rotary embeddings applied

    Raises:
        ValueError: If ``rope_type`` is neither INTERLEAVED nor SPLIT.
    """
    if rope_type == LTXRopeType.INTERLEAVED:
        return apply_interleaved_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
    elif rope_type == LTXRopeType.SPLIT:
        return apply_split_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
    else:
        raise ValueError(f"Invalid rope type: {rope_type}")


def apply_interleaved_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply interleaved rotary embeddings.

    Pairs adjacent dimensions and applies rotation.
    Pattern: [x0, x1, x2, x3, ...] -> rotate pairs (x0,x1), (x2,x3), ...

    Args:
        input_tensor: Input tensor of shape (..., dim); dim must be even
        cos_freqs: Cosine frequencies (broadcast against input_tensor)
        sin_freqs: Sine frequencies (broadcast against input_tensor)

    Returns:
        Tensor with interleaved rotary embeddings applied, cast back to the
        dtype of ``input_tensor``
    """
    # Compute in float32 for better precision
    input_dtype = input_tensor.dtype
    input_tensor = input_tensor.astype(mx.float32)
    cos_freqs = cos_freqs.astype(mx.float32)
    sin_freqs = sin_freqs.astype(mx.float32)

    # The pairwise (-odd, even) rotation is exactly rotate_half_interleaved,
    # so reuse it instead of repeating the reshape/stack dance here.
    out = input_tensor * cos_freqs + rotate_half_interleaved(input_tensor) * sin_freqs
    return out.astype(input_dtype)


def rotate_half_interleaved(x: mx.array) -> mx.array:
    """Rotate for interleaved RoPE: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].

    PyTorch equivalent:
        t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
        t1, t2 = t_dup.unbind(dim=-1)
        t_dup = torch.stack((-t2, t1), dim=-1)
        return rearrange(t_dup, "... d r -> ... (d r)")
    """
    # x: (..., dim) where dim is even
    x_even = x[..., 0::2]  # [x0, x2, x4, ...]
    x_odd = x[..., 1::2]  # [x1, x3, x5, ...]
    # Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
    rotated = mx.stack([-x_odd, x_even], axis=-1)
    return mx.reshape(rotated, x.shape)


def apply_rotary_emb_1d(
    q: mx.array,
    k: mx.array,
    freqs_cis: mx.array,
) -> Tuple[mx.array, mx.array]:
    """Apply 1D rotary embeddings using precomputed frequencies (interleaved).

    Args:
        q: Query tensor
        k: Key tensor
        freqs_cis: Precomputed frequencies whose last axis holds
            [cos, sin]

    Returns:
        Tuple of (rotated q, rotated k). No dtype casting is performed here.
    """
    # freqs_cis: (1, seq_len, num_heads, head_dim, 2) where [..., 0] = cos, [..., 1] = sin
    cos = freqs_cis[..., 0]  # (1, seq_len, num_heads, head_dim)
    sin = freqs_cis[..., 1]
    # q, k: (batch, seq_len, num_heads, head_dim)
    # Interleaved RoPE: pairs of adjacent dims rotate together
    q_r = q * cos + rotate_half_interleaved(q) * sin
    k_r = k * cos + rotate_half_interleaved(k) * sin
    return q_r, k_r


def apply_split_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply split rotary embeddings.

    Splits the last dimension into a first half and a second half and rotates
    element i of the first half together with element i of the second half.

    Args:
        input_tensor: Input tensor. If it is not 4-D while ``cos_freqs`` is,
            it is reshaped to (B, H, T, D) using B, H, T taken from
            ``cos_freqs`` and restored to (B, T, H*D) before returning.
        cos_freqs: Cosine frequencies of shape (B, H, T, D//2)
        sin_freqs: Sine frequencies of shape (B, H, T, D//2)

    Returns:
        Tensor with split rotary embeddings applied, cast back to the dtype
        of ``input_tensor``
    """
    input_dtype = input_tensor.dtype
    needs_reshape = False

    # Handle dimension mismatch
    if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
        b, h, t, _ = cos_freqs.shape
        # Reshape from (B, T, H*D) to (B, H, T, D)
        input_tensor = mx.reshape(input_tensor, (b, t, h, -1))
        input_tensor = mx.swapaxes(input_tensor, 1, 2)
        needs_reshape = True

    # Cast to float32 for computation precision
    input_tensor = input_tensor.astype(mx.float32)
    cos_freqs = cos_freqs.astype(mx.float32)
    sin_freqs = sin_freqs.astype(mx.float32)

    # Split into two halves: (..., dim) -> (..., 2, dim//2)
    dim = input_tensor.shape[-1]
    split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
    first_half = split_input[..., 0, :]  # (..., dim//2)
    second_half = split_input[..., 1, :]  # (..., dim//2)

    # 2-D rotation of each (first, second) pair:
    #   first'  = first * cos - second * sin
    #   second' = second * cos + first * sin
    output_first = first_half * cos_freqs - sin_freqs * second_half
    output_second = second_half * cos_freqs + sin_freqs * first_half

    # Stack back together and flatten: (..., 2, dim//2) -> (..., dim)
    output = mx.stack([output_first, output_second], axis=-2)
    output = mx.reshape(output, input_tensor.shape)

    if needs_reshape:
        # Reshape back: (B, H, T, D) -> (B, T, H*D)
        b, h, t, d = output.shape
        output = mx.swapaxes(output, 1, 2)
        output = mx.reshape(output, (b, t, h * d))

    return output.astype(input_dtype)


def generate_freq_grid(
    positional_embedding_theta: float,
    positional_embedding_max_pos_count: int,
    inner_dim: int,
) -> mx.array:
    """Generate frequency grid for RoPE.

    Produces ``theta ** linspace(0, 1, num_indices) * pi / 2`` where
    ``num_indices = inner_dim // (2 * positional_embedding_max_pos_count)``,
    clamped to 1 when that quotient is 0.

    Args:
        positional_embedding_theta: Base theta value
        positional_embedding_max_pos_count: Number of position dimensions
        inner_dim: Inner dimension of the model

    Returns:
        Frequency indices tensor of shape (num_indices,)
    """
    theta = positional_embedding_theta
    start = 1.0
    end = theta
    n_elem = 2 * positional_embedding_max_pos_count

    # Compute logarithmic spacing (exponents relative to base theta)
    log_start = math.log(start) / math.log(theta)
    log_end = math.log(end) / math.log(theta)

    num_indices = inner_dim // n_elem
    if num_indices == 0:
        num_indices = 1

    # Create linearly spaced exponents, then raise theta to them
    lin_space = mx.linspace(log_start, log_end, num_indices)
    pow_indices = mx.power(theta, lin_space)

    # Scale by pi/2
    return pow_indices * (math.pi / 2)


def get_fractional_positions(
    indices_grid: mx.array,
    max_pos: List[int],
) -> mx.array:
    """Convert indices to fractional positions.

    Args:
        indices_grid: Grid of position indices of shape (B, n_pos_dims, ...)
        max_pos: Maximum position for each dimension

    Returns:
        ``indices_grid[:, i] / max_pos[i]`` for every position dimension,
        stacked along a new last axis. No rescaling to [-1, 1] is done here;
        callers apply ``* 2 - 1`` themselves.
    """
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), (
        f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
    )

    # Divide each dimension by its max position
    fractional_positions = [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)]
    return mx.stack(fractional_positions, axis=-1)


def _collapse_index_ranges(
    indices_grid: mx.array,
    use_middle_indices_grid: bool,
) -> mx.array:
    """Reduce a trailing [start, end] axis of the index grid to one value.

    With ``use_middle_indices_grid`` the grid must be 4-D with a last axis of
    size 2 and the midpoint of start/end is returned. Otherwise a 4-D grid
    keeps only element 0 of its last axis and any other grid is returned
    unchanged.
    """
    if use_middle_indices_grid:
        # indices_grid shape: (B, n_dims, T, 2) where last dim is [start, end]
        assert len(indices_grid.shape) == 4
        assert indices_grid.shape[-1] == 2
        return (indices_grid[..., 0] + indices_grid[..., 1]) / 2.0
    if len(indices_grid.shape) == 4:
        return indices_grid[..., 0]
    return indices_grid


def generate_freqs(
    indices: mx.array,
    indices_grid: mx.array,
    max_pos: List[int],
    use_middle_indices_grid: bool,
) -> mx.array:
    """Generate frequencies from indices and position grid.

    Args:
        indices: Frequency indices
        indices_grid: Position indices grid
        max_pos: Maximum positions per dimension
        use_middle_indices_grid: Whether to use middle of index ranges

    Returns:
        Frequency tensor of shape (B, T, n_indices * n_dims)
    """
    indices_grid = _collapse_index_ranges(indices_grid, use_middle_indices_grid)

    # fractional_positions: (B, T, n_dims)
    fractional_positions = get_fractional_positions(indices_grid, max_pos)

    # Scale fractional positions to [-1, 1]
    scaled_positions = fractional_positions * 2 - 1  # (B, T, n_dims)

    # Outer product with indices
    # (B, T, n_dims, 1) * (1, 1, 1, n_indices) -> (B, T, n_dims, n_indices)
    freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.expand_dims(
        mx.expand_dims(mx.expand_dims(indices, axis=0), axis=0), axis=0
    )

    # Transpose and flatten: (B, T, n_dims, n_indices) -> (B, T, n_indices * n_dims)
    freqs = mx.swapaxes(freqs, -1, -2)  # (B, T, n_indices, n_dims)
    freqs = mx.reshape(freqs, freqs.shape[:-2] + (-1,))
    return freqs


def _prepend_identity_padding(
    cos_freq: mx.array,
    sin_freq: mx.array,
    pad_size: int,
) -> Tuple[mx.array, mx.array]:
    """Prepend ``pad_size`` identity-rotation entries (cos=1, sin=0) on the last axis.

    The padding is allocated with an explicit shape rather than by slicing the
    existing arrays, so it has exactly ``pad_size`` entries even when
    ``pad_size`` exceeds the current width. A non-positive ``pad_size`` leaves
    the inputs untouched.
    """
    if pad_size <= 0:
        return cos_freq, sin_freq

    cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=cos_freq.dtype)
    sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=sin_freq.dtype)
    return (
        mx.concatenate([cos_padding, cos_freq], axis=-1),
        mx.concatenate([sin_padding, sin_freq], axis=-1),
    )


def split_freqs_cis(
    freqs: mx.array,
    pad_size: int,
    num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
    """Prepare cos/sin frequencies for split RoPE.

    Args:
        freqs: Frequency tensor
        pad_size: Number of identity entries (cos=1, sin=0) to prepend on the
            last axis for dimension alignment; values <= 0 add nothing
        num_attention_heads: Number of attention heads

    Returns:
        Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2)
    """
    cos_freq, sin_freq = _prepend_identity_padding(mx.cos(freqs), mx.sin(freqs), pad_size)

    # Reshape for multi-head attention: (B, T, H * d) -> (B, T, H, d)
    b, t = cos_freq.shape[0], cos_freq.shape[1]
    cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
    sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))

    # Swap axes: (B, T, H, D//2) -> (B, H, T, D//2)
    cos_freq = mx.swapaxes(cos_freq, 1, 2)
    sin_freq = mx.swapaxes(sin_freq, 1, 2)
    return cos_freq, sin_freq


def interleaved_freqs_cis(
    freqs: mx.array,
    pad_size: int,
) -> Tuple[mx.array, mx.array]:
    """Prepare cos/sin frequencies for interleaved RoPE.

    Args:
        freqs: Frequency tensor of shape (B, T, dim//2)
        pad_size: Number of identity entries (cos=1, sin=0) to prepend on the
            last axis for dimension alignment; values <= 0 add nothing

    Returns:
        Tuple of (cos_freq, sin_freq) with shape (B, T, dim)
    """
    # Repeat interleave: each element repeated twice
    # (B, T, D) -> (B, T, 2*D) with pattern [c0, c0, c1, c1, ...]
    cos_freq = mx.repeat(mx.cos(freqs), 2, axis=-1)
    sin_freq = mx.repeat(mx.sin(freqs), 2, axis=-1)

    # Padding is added after the repeat so pad entries are not duplicated
    return _prepend_identity_padding(cos_freq, sin_freq, pad_size)


def precompute_freqs_cis(
    indices_grid: mx.array,
    dim: int,
    theta: float = 10000.0,
    max_pos: Optional[List[int]] = None,
    use_middle_indices_grid: bool = False,
    num_attention_heads: int = 32,
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    double_precision: bool = False,
) -> Tuple[mx.array, mx.array]:
    """Precompute RoPE frequencies.

    Args:
        indices_grid: Position indices grid
        dim: Dimension for RoPE
        theta: Base theta value for frequency computation
        max_pos: Maximum position per dimension; defaults to [20, 2048, 2048]
        use_middle_indices_grid: Whether to use middle indices
        num_attention_heads: Number of attention heads
        rope_type: Type of RoPE (INTERLEAVED or SPLIT)
        double_precision: If True, compute the frequency grid in float64 for
            higher precision

    Returns:
        Tuple of (cos_freq, sin_freq) tensors
    """
    if max_pos is None:
        max_pos = [20, 2048, 2048]

    if double_precision:
        return _precompute_freqs_cis_double_precision(
            indices_grid,
            dim,
            theta,
            max_pos,
            use_middle_indices_grid,
            num_attention_heads,
            rope_type,
        )

    # Keep positions in float32 for RoPE computation.
    # Even though PyTorch nominally casts positions to model dtype (bfloat16),
    # empirical comparison shows float32 positions produce RoPE values matching
    # PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
    # position computation that gets amplified by high-frequency indices
    # (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
    indices_grid = indices_grid.astype(mx.float32)

    # Generate frequency indices
    indices = generate_freq_grid(theta, indices_grid.shape[1], dim)

    # Generate frequencies
    freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)

    # Prepare cos/sin based on rope type
    if rope_type == LTXRopeType.SPLIT:
        expected_freqs = dim // 2
        current_freqs = freqs.shape[-1]
        pad_size = expected_freqs - current_freqs
        cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
    else:
        # Interleaved
        n_elem = 2 * indices_grid.shape[1]
        cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)

    return cos_freq, sin_freq


def _precompute_freqs_cis_double_precision(
    indices_grid: mx.array,
    dim: int,
    theta: float,
    max_pos: List[int],
    use_middle_indices_grid: bool,
    num_attention_heads: int,
    rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
    """Compute RoPE frequencies with higher precision using float64 for frequency grid.

    Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
    frequency grid computation (log-spaced values), then converts to float32.
    The position grid is cast to float32, the same as in the
    non-double-precision path.

    Args:
        indices_grid: Position indices grid
        dim: Dimension for RoPE
        theta: Base theta value for frequency computation
        max_pos: Maximum position per dimension
        use_middle_indices_grid: Whether to use middle indices
        num_attention_heads: Number of attention heads (used for SPLIT only)
        rope_type: Type of RoPE (INTERLEAVED or SPLIT)

    Returns:
        Tuple of (cos_freq, sin_freq) tensors
    """
    import numpy as np

    # Keep positions in float32 — same reasoning as the non-double-precision path.
    indices_grid_f32 = indices_grid.astype(mx.float32)

    n_pos_dims = indices_grid_f32.shape[1]
    n_elem = 2 * n_pos_dims

    # Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
    # This is the critical precision step - PyTorch uses np.float64 here
    log_start = np.log(1.0) / np.log(theta)
    log_end = np.log(theta) / np.log(theta)  # = 1.0

    num_indices = dim // n_elem
    if num_indices == 0:
        num_indices = 1

    # Use numpy float64 for the linspace computation (matches PyTorch)
    pow_indices = np.power(
        theta,
        np.linspace(log_start, log_end, num_indices, dtype=np.float64),
    )
    # Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
    freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)

    # Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
    # After handling: indices_grid_f32 shape is (B, n_dims, T)
    indices_grid_f32 = _collapse_index_ranges(indices_grid_f32, use_middle_indices_grid)

    # Fractional positions per dimension, stacked: (B, n_dims, T) -> (B, T, n_dims)
    fractional_positions = mx.stack(
        [indices_grid_f32[:, i, :] / max_pos[i] for i in range(n_pos_dims)],
        axis=-1,
    )

    # Scale to [-1, 1]
    scaled_positions = fractional_positions * 2 - 1

    # Compute frequencies: outer product
    # scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
    # freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
    freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))

    # Transpose and flatten:
    # (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
    freqs = mx.swapaxes(freqs, -1, -2)
    freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))

    # cos/sin, padding and layout are shared with the standard path
    if rope_type == LTXRopeType.SPLIT:
        pad_size = dim // 2 - freqs.shape[-1]
        return split_freqs_cis(freqs, pad_size, num_attention_heads)

    # Interleaved
    return interleaved_freqs_cis(freqs, dim % n_elem)