From e483eab0393f46e65115834fd5f33eb16849cb10 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 18 Jan 2026 11:13:32 +0100 Subject: [PATCH] Optimize positional embedding handling in TransformerArgsPreprocessor and improve RoPE frequency computation in _precompute_freqs_cis_double_precision for enhanced performance and precision. --- mlx_video/models/ltx/ltx.py | 19 +++-- mlx_video/models/ltx/rope.py | 104 +++++++++++++++------------- mlx_video/models/ltx/transformer.py | 10 +-- 3 files changed, 72 insertions(+), 61 deletions(-) diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index a3eef42..c7c51a2 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -121,13 +121,18 @@ class TransformerArgsPreprocessor: timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) - pe = self._prepare_positional_embeddings( - positions=modality.positions, - inner_dim=self.inner_dim, - max_pos=self.max_pos, - use_middle_indices_grid=self.use_middle_indices_grid, - num_attention_heads=self.num_attention_heads, - ) + + # Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation) + if modality.positional_embeddings is not None: + pe = modality.positional_embeddings + else: + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) return TransformerArgs( x=x, diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index 4852942..9e2db5f 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -1,9 +1,8 @@ import math -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import mlx.core as mx -import numpy as np from mlx_video.models.ltx.config import LTXRopeType @@ -429,66 +428,75 @@ def _precompute_freqs_cis_double_precision( num_attention_heads: int, rope_type: LTXRopeType, ) -> Tuple[mx.array, mx.array]: + """Compute RoPE frequencies with higher precision using float32. + This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips. + Uses float32 for computation precision (sufficient for RoPE). + """ # Warn if positions are bfloat16 - this causes quality degradation if indices_grid.dtype == mx.bfloat16: import warnings warnings.warn( - "Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. " - "Use float32 for position grids to avoid quality degradation. " - "See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss", + "Position grid has dtype bfloat16, which causes precision loss in RoPE. " + "Use float32 for position grids to avoid quality degradation.", UserWarning, stacklevel=2 ) - # Convert to numpy float64 (first to float32 for numpy compatibility) - # Note: If input is bfloat16, precision is already lost at this step - indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) + # Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion) + indices_grid_f32 = indices_grid.astype(mx.float32) - # Generate frequency indices in float64 - n_pos_dims = indices_grid_np.shape[1] + n_pos_dims = indices_grid_f32.shape[1] n_elem = 2 * n_pos_dims - # Compute log-spaced frequencies + # Compute log-spaced frequencies in float32 log_start = math.log(1.0) / math.log(theta) log_end = math.log(theta) / math.log(theta) num_indices = dim // n_elem if num_indices == 0: num_indices = 1 - lin_space = np.linspace(log_start, log_end, num_indices) - indices_np = np.power(theta, lin_space) * (math.pi / 2) + + lin_space = mx.linspace(log_start, log_end, num_indices) + freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2) # Handle middle indices grid # Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise if use_middle_indices_grid: - assert len(indices_grid_np.shape) == 4 - assert indices_grid_np.shape[-1] == 2 - indices_grid_start = indices_grid_np[..., 0] - indices_grid_end = indices_grid_np[..., 1] - indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0 - elif len(indices_grid_np.shape) == 4: - indices_grid_np = indices_grid_np[..., 0] - # After handling: indices_grid_np shape is (B, n_dims, T) + assert len(indices_grid_f32.shape) == 4 + assert indices_grid_f32.shape[-1] == 2 + indices_grid_start = indices_grid_f32[..., 0] + indices_grid_end = indices_grid_f32[..., 1] + indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid_f32.shape) == 4: + indices_grid_f32 = indices_grid_f32[..., 0] + # After handling: indices_grid_f32 shape is (B, n_dims, T) # Get fractional positions: (B, n_dims, T) -> (B, T, n_dims) - batch_size = indices_grid_np.shape[0] - seq_len = indices_grid_np.shape[2] - fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64) + # Compute fractional positions for each dimension + fractional_list = [] for i in range(n_pos_dims): - # indices_grid_np[:, i, :] has shape (B, T) - fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i] + frac = indices_grid_f32[:, i, :] / max_pos[i] # (B, T) + fractional_list.append(frac) + + # Stack: (B, T, n_dims) + fractional_positions = mx.stack(fractional_list, axis=-1) # Scale to [-1, 1] scaled_positions = fractional_positions * 2 - 1 # Compute frequencies: outer product - freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1) - freqs = np.swapaxes(freqs, -1, -2) - freqs = freqs.reshape(freqs.shape[:-2] + (-1,)) + # scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1) + # freq_indices: (num_indices,) -> (1, 1, 1, num_indices) + freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1)) + # freqs: (B, T, n_dims, num_indices) - # Compute cos/sin in float64 - cos_freq = np.cos(freqs) - sin_freq = np.sin(freqs) + # Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims) + freqs = mx.swapaxes(freqs, -1, -2) + freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1)) + + # Compute cos/sin + cos_freq = mx.cos(freqs) + sin_freq = mx.sin(freqs) # Prepare based on rope type if rope_type == LTXRopeType.SPLIT: @@ -498,31 +506,27 @@ def _precompute_freqs_cis_double_precision( # Add padding if pad_size > 0: - cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64) - sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64) - cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) + cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32) + sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32) + cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1) # Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H) b, t = cos_freq.shape[0], cos_freq.shape[1] - cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) - sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) - cos_freq = np.swapaxes(cos_freq, 1, 2) - sin_freq = np.swapaxes(sin_freq, 1, 2) + cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1)) + sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1)) + cos_freq = mx.swapaxes(cos_freq, 1, 2) + sin_freq = mx.swapaxes(sin_freq, 1, 2) else: # Interleaved - cos_freq = np.repeat(cos_freq, 2, axis=-1) - sin_freq = np.repeat(sin_freq, 2, axis=-1) + cos_freq = mx.repeat(cos_freq, 2, axis=-1) + sin_freq = mx.repeat(sin_freq, 2, axis=-1) pad_size = dim % n_elem if pad_size > 0: - cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64) - sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64) - cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) - sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) - - # Convert back to MLX (float32 for GPU compatibility) - cos_freq = mx.array(cos_freq.astype(np.float32)) - sin_freq = mx.array(sin_freq.astype(np.float32)) + cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32) + sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32) + cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1) return cos_freq, sin_freq diff --git a/mlx_video/models/ltx/transformer.py b/mlx_video/models/ltx/transformer.py index 60ee7ec..5a60989 100644 --- a/mlx_video/models/ltx/transformer.py +++ b/mlx_video/models/ltx/transformer.py @@ -12,12 +12,14 @@ from mlx_video.utils import rms_norm @dataclass(frozen=True) class Modality: - latent: mx.array - timesteps: mx.array - positions: mx.array - context: mx.array + latent: mx.array + timesteps: mx.array + positions: mx.array + context: mx.array enabled: bool = True context_mask: Optional[mx.array] = None + # Optional precomputed positional embeddings (RoPE) to avoid recomputation + positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None @dataclass(frozen=True)