Optimize positional embedding handling in TransformerArgsPreprocessor and improve RoPE frequency computation in _precompute_freqs_cis_double_precision for enhanced performance and precision.

This commit is contained in:
Prince Canuma
2026-01-18 11:13:32 +01:00
parent 62fc4805a0
commit e483eab039
3 changed files with 72 additions and 61 deletions

View File

@@ -121,13 +121,18 @@ class TransformerArgsPreprocessor:
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
if modality.positional_embeddings is not None:
pe = modality.positional_embeddings
else:
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return TransformerArgs(
x=x,

View File

@@ -1,9 +1,8 @@
import math
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.config import LTXRopeType
@@ -429,66 +428,75 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float32.
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
Uses float32 for computation precision (sufficient for RoPE).
"""
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
"Use float32 for position grids to avoid quality degradation. "
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
"Use float32 for position grids to avoid quality degradation.",
UserWarning,
stacklevel=2
)
# Convert to numpy float64 (first to float32 for numpy compatibility)
# Note: If input is bfloat16, precision is already lost at this step
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
indices_grid_f32 = indices_grid.astype(mx.float32)
# Generate frequency indices in float64
n_pos_dims = indices_grid_np.shape[1]
n_pos_dims = indices_grid_f32.shape[1]
n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies
# Compute log-spaced frequencies in float32
log_start = math.log(1.0) / math.log(theta)
log_end = math.log(theta) / math.log(theta)
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = np.linspace(log_start, log_end, num_indices)
indices_np = np.power(theta, lin_space) * (math.pi / 2)
lin_space = mx.linspace(log_start, log_end, num_indices)
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
# Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
if use_middle_indices_grid:
assert len(indices_grid_np.shape) == 4
assert indices_grid_np.shape[-1] == 2
indices_grid_start = indices_grid_np[..., 0]
indices_grid_end = indices_grid_np[..., 1]
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid_np.shape) == 4:
indices_grid_np = indices_grid_np[..., 0]
# After handling: indices_grid_np shape is (B, n_dims, T)
assert len(indices_grid_f32.shape) == 4
assert indices_grid_f32.shape[-1] == 2
indices_grid_start = indices_grid_f32[..., 0]
indices_grid_end = indices_grid_f32[..., 1]
indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid_f32.shape) == 4:
indices_grid_f32 = indices_grid_f32[..., 0]
# After handling: indices_grid_f32 shape is (B, n_dims, T)
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
batch_size = indices_grid_np.shape[0]
seq_len = indices_grid_np.shape[2]
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
# Compute fractional positions for each dimension
fractional_list = []
for i in range(n_pos_dims):
# indices_grid_np[:, i, :] has shape (B, T)
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
frac = indices_grid_f32[:, i, :] / max_pos[i] # (B, T)
fractional_list.append(frac)
# Stack: (B, T, n_dims)
fractional_positions = mx.stack(fractional_list, axis=-1)
# Scale to [-1, 1]
scaled_positions = fractional_positions * 2 - 1
# Compute frequencies: outer product
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
# freqs: (B, T, n_dims, num_indices)
# Compute cos/sin in float64
cos_freq = np.cos(freqs)
sin_freq = np.sin(freqs)
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
freqs = mx.swapaxes(freqs, -1, -2)
freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))
# Compute cos/sin
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
# Prepare based on rope type
if rope_type == LTXRopeType.SPLIT:
@@ -498,31 +506,27 @@ def _precompute_freqs_cis_double_precision(
# Add padding
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
b, t = cos_freq.shape[0], cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
cos_freq = np.swapaxes(cos_freq, 1, 2)
sin_freq = np.swapaxes(sin_freq, 1, 2)
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
cos_freq = mx.swapaxes(cos_freq, 1, 2)
sin_freq = mx.swapaxes(sin_freq, 1, 2)
else:
# Interleaved
cos_freq = np.repeat(cos_freq, 2, axis=-1)
sin_freq = np.repeat(sin_freq, 2, axis=-1)
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
pad_size = dim % n_elem
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
# Convert back to MLX (float32 for GPU compatibility)
cos_freq = mx.array(cos_freq.astype(np.float32))
sin_freq = mx.array(sin_freq.astype(np.float32))
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
return cos_freq, sin_freq

View File

@@ -12,12 +12,14 @@ from mlx_video.utils import rms_norm
@dataclass(frozen=True)
class Modality:
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
enabled: bool = True
context_mask: Optional[mx.array] = None
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
@dataclass(frozen=True)