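Reuse precomputed positional embeddings in TransformerArgsPreprocessor when the modality provides them (avoiding RoPE recomputation), and compute RoPE frequencies in _precompute_freqs_cis_double_precision entirely in MLX float32 instead of round-tripping through NumPy float64, for better performance; float32 is treated as sufficient precision for RoPE.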
This commit is contained in:
@@ -121,6 +121,11 @@ class TransformerArgsPreprocessor:
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
|
||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||
if modality.positional_embeddings is not None:
|
||||
pe = modality.positional_embeddings
|
||||
else:
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.config import LTXRopeType
|
||||
|
||||
@@ -429,66 +428,75 @@ def _precompute_freqs_cis_double_precision(
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies with higher precision using float32.
|
||||
|
||||
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
|
||||
Uses float32 for computation precision (sufficient for RoPE).
|
||||
"""
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
|
||||
"Use float32 for position grids to avoid quality degradation. "
|
||||
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
|
||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
|
||||
"Use float32 for position grids to avoid quality degradation.",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
||||
# Note: If input is bfloat16, precision is already lost at this step
|
||||
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
|
||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||
|
||||
# Generate frequency indices in float64
|
||||
n_pos_dims = indices_grid_np.shape[1]
|
||||
n_pos_dims = indices_grid_f32.shape[1]
|
||||
n_elem = 2 * n_pos_dims
|
||||
|
||||
# Compute log-spaced frequencies
|
||||
# Compute log-spaced frequencies in float32
|
||||
log_start = math.log(1.0) / math.log(theta)
|
||||
log_end = math.log(theta) / math.log(theta)
|
||||
num_indices = dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
lin_space = np.linspace(log_start, log_end, num_indices)
|
||||
indices_np = np.power(theta, lin_space) * (math.pi / 2)
|
||||
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
|
||||
|
||||
# Handle middle indices grid
|
||||
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||
if use_middle_indices_grid:
|
||||
assert len(indices_grid_np.shape) == 4
|
||||
assert indices_grid_np.shape[-1] == 2
|
||||
indices_grid_start = indices_grid_np[..., 0]
|
||||
indices_grid_end = indices_grid_np[..., 1]
|
||||
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid_np.shape) == 4:
|
||||
indices_grid_np = indices_grid_np[..., 0]
|
||||
# After handling: indices_grid_np shape is (B, n_dims, T)
|
||||
assert len(indices_grid_f32.shape) == 4
|
||||
assert indices_grid_f32.shape[-1] == 2
|
||||
indices_grid_start = indices_grid_f32[..., 0]
|
||||
indices_grid_end = indices_grid_f32[..., 1]
|
||||
indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid_f32.shape) == 4:
|
||||
indices_grid_f32 = indices_grid_f32[..., 0]
|
||||
# After handling: indices_grid_f32 shape is (B, n_dims, T)
|
||||
|
||||
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
|
||||
batch_size = indices_grid_np.shape[0]
|
||||
seq_len = indices_grid_np.shape[2]
|
||||
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
|
||||
# Compute fractional positions for each dimension
|
||||
fractional_list = []
|
||||
for i in range(n_pos_dims):
|
||||
# indices_grid_np[:, i, :] has shape (B, T)
|
||||
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
|
||||
frac = indices_grid_f32[:, i, :] / max_pos[i] # (B, T)
|
||||
fractional_list.append(frac)
|
||||
|
||||
# Stack: (B, T, n_dims)
|
||||
fractional_positions = mx.stack(fractional_list, axis=-1)
|
||||
|
||||
# Scale to [-1, 1]
|
||||
scaled_positions = fractional_positions * 2 - 1
|
||||
|
||||
# Compute frequencies: outer product
|
||||
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
|
||||
freqs = np.swapaxes(freqs, -1, -2)
|
||||
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
|
||||
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
||||
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
|
||||
# freqs: (B, T, n_dims, num_indices)
|
||||
|
||||
# Compute cos/sin in float64
|
||||
cos_freq = np.cos(freqs)
|
||||
sin_freq = np.sin(freqs)
|
||||
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
||||
freqs = mx.swapaxes(freqs, -1, -2)
|
||||
freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))
|
||||
|
||||
# Compute cos/sin
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
# Prepare based on rope type
|
||||
if rope_type == LTXRopeType.SPLIT:
|
||||
@@ -498,31 +506,27 @@ def _precompute_freqs_cis_double_precision(
|
||||
|
||||
# Add padding
|
||||
if pad_size > 0:
|
||||
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
|
||||
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
|
||||
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
|
||||
cos_freq = np.swapaxes(cos_freq, 1, 2)
|
||||
sin_freq = np.swapaxes(sin_freq, 1, 2)
|
||||
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
|
||||
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
|
||||
cos_freq = mx.swapaxes(cos_freq, 1, 2)
|
||||
sin_freq = mx.swapaxes(sin_freq, 1, 2)
|
||||
else:
|
||||
# Interleaved
|
||||
cos_freq = np.repeat(cos_freq, 2, axis=-1)
|
||||
sin_freq = np.repeat(sin_freq, 2, axis=-1)
|
||||
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
|
||||
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
pad_size = dim % n_elem
|
||||
if pad_size > 0:
|
||||
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Convert back to MLX (float32 for GPU compatibility)
|
||||
cos_freq = mx.array(cos_freq.astype(np.float32))
|
||||
sin_freq = mx.array(sin_freq.astype(np.float32))
|
||||
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
|
||||
@@ -18,6 +18,8 @@ class Modality:
|
||||
context: mx.array
|
||||
enabled: bool = True
|
||||
context_mask: Optional[mx.array] = None
|
||||
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
|
||||
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
Reference in New Issue
Block a user