Optimize positional embedding handling in TransformerArgsPreprocessor and improve RoPE frequency computation in _precompute_freqs_cis_double_precision for enhanced performance and precision.

This commit is contained in:
Prince Canuma
2026-01-18 11:13:32 +01:00
parent 62fc4805a0
commit e483eab039
3 changed files with 72 additions and 61 deletions

View File

@@ -121,13 +121,18 @@ class TransformerArgsPreprocessor:
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions, # Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
inner_dim=self.inner_dim, if modality.positional_embeddings is not None:
max_pos=self.max_pos, pe = modality.positional_embeddings
use_middle_indices_grid=self.use_middle_indices_grid, else:
num_attention_heads=self.num_attention_heads, pe = self._prepare_positional_embeddings(
) positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return TransformerArgs( return TransformerArgs(
x=x, x=x,

View File

@@ -1,9 +1,8 @@
import math import math
from typing import Callable, List, Optional, Tuple from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.config import LTXRopeType from mlx_video.models.ltx.config import LTXRopeType
@@ -429,66 +428,75 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int, num_attention_heads: int,
rope_type: LTXRopeType, rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float32.
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
Uses float32 for computation precision (sufficient for RoPE).
"""
# Warn if positions are bfloat16 - this causes quality degradation # Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16: if indices_grid.dtype == mx.bfloat16:
import warnings import warnings
warnings.warn( warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. " "Position grid has dtype bfloat16, which causes precision loss in RoPE. "
"Use float32 for position grids to avoid quality degradation. " "Use float32 for position grids to avoid quality degradation.",
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
UserWarning, UserWarning,
stacklevel=2 stacklevel=2
) )
# Convert to numpy float64 (first to float32 for numpy compatibility) # Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
# Note: If input is bfloat16, precision is already lost at this step indices_grid_f32 = indices_grid.astype(mx.float32)
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# Generate frequency indices in float64 n_pos_dims = indices_grid_f32.shape[1]
n_pos_dims = indices_grid_np.shape[1]
n_elem = 2 * n_pos_dims n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies # Compute log-spaced frequencies in float32
log_start = math.log(1.0) / math.log(theta) log_start = math.log(1.0) / math.log(theta)
log_end = math.log(theta) / math.log(theta) log_end = math.log(theta) / math.log(theta)
num_indices = dim // n_elem num_indices = dim // n_elem
if num_indices == 0: if num_indices == 0:
num_indices = 1 num_indices = 1
lin_space = np.linspace(log_start, log_end, num_indices)
indices_np = np.power(theta, lin_space) * (math.pi / 2) lin_space = mx.linspace(log_start, log_end, num_indices)
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
# Handle middle indices grid # Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise # Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
if use_middle_indices_grid: if use_middle_indices_grid:
assert len(indices_grid_np.shape) == 4 assert len(indices_grid_f32.shape) == 4
assert indices_grid_np.shape[-1] == 2 assert indices_grid_f32.shape[-1] == 2
indices_grid_start = indices_grid_np[..., 0] indices_grid_start = indices_grid_f32[..., 0]
indices_grid_end = indices_grid_np[..., 1] indices_grid_end = indices_grid_f32[..., 1]
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0 indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid_np.shape) == 4: elif len(indices_grid_f32.shape) == 4:
indices_grid_np = indices_grid_np[..., 0] indices_grid_f32 = indices_grid_f32[..., 0]
# After handling: indices_grid_np shape is (B, n_dims, T) # After handling: indices_grid_f32 shape is (B, n_dims, T)
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims) # Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
batch_size = indices_grid_np.shape[0] # Compute fractional positions for each dimension
seq_len = indices_grid_np.shape[2] fractional_list = []
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
for i in range(n_pos_dims): for i in range(n_pos_dims):
# indices_grid_np[:, i, :] has shape (B, T) frac = indices_grid_f32[:, i, :] / max_pos[i] # (B, T)
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i] fractional_list.append(frac)
# Stack: (B, T, n_dims)
fractional_positions = mx.stack(fractional_list, axis=-1)
# Scale to [-1, 1] # Scale to [-1, 1]
scaled_positions = fractional_positions * 2 - 1 scaled_positions = fractional_positions * 2 - 1
# Compute frequencies: outer product # Compute frequencies: outer product
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1) # scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
freqs = np.swapaxes(freqs, -1, -2) # freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
freqs = freqs.reshape(freqs.shape[:-2] + (-1,)) freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
# freqs: (B, T, n_dims, num_indices)
# Compute cos/sin in float64 # Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
cos_freq = np.cos(freqs) freqs = mx.swapaxes(freqs, -1, -2)
sin_freq = np.sin(freqs) freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))
# Compute cos/sin
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
# Prepare based on rope type # Prepare based on rope type
if rope_type == LTXRopeType.SPLIT: if rope_type == LTXRopeType.SPLIT:
@@ -498,31 +506,27 @@ def _precompute_freqs_cis_double_precision(
# Add padding # Add padding
if pad_size > 0: if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64) cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64) sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H) # Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
b, t = cos_freq.shape[0], cos_freq.shape[1] b, t = cos_freq.shape[0], cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
cos_freq = np.swapaxes(cos_freq, 1, 2) cos_freq = mx.swapaxes(cos_freq, 1, 2)
sin_freq = np.swapaxes(sin_freq, 1, 2) sin_freq = mx.swapaxes(sin_freq, 1, 2)
else: else:
# Interleaved # Interleaved
cos_freq = np.repeat(cos_freq, 2, axis=-1) cos_freq = mx.repeat(cos_freq, 2, axis=-1)
sin_freq = np.repeat(sin_freq, 2, axis=-1) sin_freq = mx.repeat(sin_freq, 2, axis=-1)
pad_size = dim % n_elem pad_size = dim % n_elem
if pad_size > 0: if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64) cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64) sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
# Convert back to MLX (float32 for GPU compatibility)
cos_freq = mx.array(cos_freq.astype(np.float32))
sin_freq = mx.array(sin_freq.astype(np.float32))
return cos_freq, sin_freq return cos_freq, sin_freq

View File

@@ -18,6 +18,8 @@ class Modality:
context: mx.array context: mx.array
enabled: bool = True enabled: bool = True
context_mask: Optional[mx.array] = None context_mask: Optional[mx.array] = None
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
@dataclass(frozen=True) @dataclass(frozen=True)