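Optimize positional embedding handling in TransformerArgsPreprocessor (reuse precomputed embeddings when they are provided) and speed up RoPE frequency computation in _precompute_freqs_cis_double_precision by keeping it in MLX float32 instead of round-tripping through NumPy float64. Note that this trades the previous float64 precision for performance.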
This commit is contained in:
@@ -121,13 +121,18 @@ class TransformerArgsPreprocessor:
|
|||||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||||
pe = self._prepare_positional_embeddings(
|
|
||||||
positions=modality.positions,
|
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||||
inner_dim=self.inner_dim,
|
if modality.positional_embeddings is not None:
|
||||||
max_pos=self.max_pos,
|
pe = modality.positional_embeddings
|
||||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
else:
|
||||||
num_attention_heads=self.num_attention_heads,
|
pe = self._prepare_positional_embeddings(
|
||||||
)
|
positions=modality.positions,
|
||||||
|
inner_dim=self.inner_dim,
|
||||||
|
max_pos=self.max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
return TransformerArgs(
|
return TransformerArgs(
|
||||||
x=x,
|
x=x,
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXRopeType
|
from mlx_video.models.ltx.config import LTXRopeType
|
||||||
|
|
||||||
@@ -429,66 +428,75 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
num_attention_heads: int,
|
num_attention_heads: int,
|
||||||
rope_type: LTXRopeType,
|
rope_type: LTXRopeType,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Compute RoPE frequencies with higher precision using float32.
|
||||||
|
|
||||||
|
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
|
||||||
|
Uses float32 for computation precision (sufficient for RoPE).
|
||||||
|
"""
|
||||||
# Warn if positions are bfloat16 - this causes quality degradation
|
# Warn if positions are bfloat16 - this causes quality degradation
|
||||||
if indices_grid.dtype == mx.bfloat16:
|
if indices_grid.dtype == mx.bfloat16:
|
||||||
import warnings
|
import warnings
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
|
"Position grid has dtype bfloat16, which causes precision loss in RoPE. "
|
||||||
"Use float32 for position grids to avoid quality degradation. "
|
"Use float32 for position grids to avoid quality degradation.",
|
||||||
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
|
|
||||||
UserWarning,
|
UserWarning,
|
||||||
stacklevel=2
|
stacklevel=2
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
|
||||||
# Note: If input is bfloat16, precision is already lost at this step
|
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||||
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
|
||||||
|
|
||||||
# Generate frequency indices in float64
|
n_pos_dims = indices_grid_f32.shape[1]
|
||||||
n_pos_dims = indices_grid_np.shape[1]
|
|
||||||
n_elem = 2 * n_pos_dims
|
n_elem = 2 * n_pos_dims
|
||||||
|
|
||||||
# Compute log-spaced frequencies
|
# Compute log-spaced frequencies in float32
|
||||||
log_start = math.log(1.0) / math.log(theta)
|
log_start = math.log(1.0) / math.log(theta)
|
||||||
log_end = math.log(theta) / math.log(theta)
|
log_end = math.log(theta) / math.log(theta)
|
||||||
num_indices = dim // n_elem
|
num_indices = dim // n_elem
|
||||||
if num_indices == 0:
|
if num_indices == 0:
|
||||||
num_indices = 1
|
num_indices = 1
|
||||||
lin_space = np.linspace(log_start, log_end, num_indices)
|
|
||||||
indices_np = np.power(theta, lin_space) * (math.pi / 2)
|
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||||
|
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
|
||||||
|
|
||||||
# Handle middle indices grid
|
# Handle middle indices grid
|
||||||
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||||
if use_middle_indices_grid:
|
if use_middle_indices_grid:
|
||||||
assert len(indices_grid_np.shape) == 4
|
assert len(indices_grid_f32.shape) == 4
|
||||||
assert indices_grid_np.shape[-1] == 2
|
assert indices_grid_f32.shape[-1] == 2
|
||||||
indices_grid_start = indices_grid_np[..., 0]
|
indices_grid_start = indices_grid_f32[..., 0]
|
||||||
indices_grid_end = indices_grid_np[..., 1]
|
indices_grid_end = indices_grid_f32[..., 1]
|
||||||
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
|
indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0
|
||||||
elif len(indices_grid_np.shape) == 4:
|
elif len(indices_grid_f32.shape) == 4:
|
||||||
indices_grid_np = indices_grid_np[..., 0]
|
indices_grid_f32 = indices_grid_f32[..., 0]
|
||||||
# After handling: indices_grid_np shape is (B, n_dims, T)
|
# After handling: indices_grid_f32 shape is (B, n_dims, T)
|
||||||
|
|
||||||
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
|
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
|
||||||
batch_size = indices_grid_np.shape[0]
|
# Compute fractional positions for each dimension
|
||||||
seq_len = indices_grid_np.shape[2]
|
fractional_list = []
|
||||||
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
|
|
||||||
for i in range(n_pos_dims):
|
for i in range(n_pos_dims):
|
||||||
# indices_grid_np[:, i, :] has shape (B, T)
|
frac = indices_grid_f32[:, i, :] / max_pos[i] # (B, T)
|
||||||
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
|
fractional_list.append(frac)
|
||||||
|
|
||||||
|
# Stack: (B, T, n_dims)
|
||||||
|
fractional_positions = mx.stack(fractional_list, axis=-1)
|
||||||
|
|
||||||
# Scale to [-1, 1]
|
# Scale to [-1, 1]
|
||||||
scaled_positions = fractional_positions * 2 - 1
|
scaled_positions = fractional_positions * 2 - 1
|
||||||
|
|
||||||
# Compute frequencies: outer product
|
# Compute frequencies: outer product
|
||||||
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
|
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
||||||
freqs = np.swapaxes(freqs, -1, -2)
|
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
||||||
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
|
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
|
||||||
|
# freqs: (B, T, n_dims, num_indices)
|
||||||
|
|
||||||
# Compute cos/sin in float64
|
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
||||||
cos_freq = np.cos(freqs)
|
freqs = mx.swapaxes(freqs, -1, -2)
|
||||||
sin_freq = np.sin(freqs)
|
freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))
|
||||||
|
|
||||||
|
# Compute cos/sin
|
||||||
|
cos_freq = mx.cos(freqs)
|
||||||
|
sin_freq = mx.sin(freqs)
|
||||||
|
|
||||||
# Prepare based on rope type
|
# Prepare based on rope type
|
||||||
if rope_type == LTXRopeType.SPLIT:
|
if rope_type == LTXRopeType.SPLIT:
|
||||||
@@ -498,31 +506,27 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
|
|
||||||
# Add padding
|
# Add padding
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||||
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||||
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
|
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
|
||||||
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||||
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
|
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
|
||||||
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
|
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
|
||||||
cos_freq = np.swapaxes(cos_freq, 1, 2)
|
cos_freq = mx.swapaxes(cos_freq, 1, 2)
|
||||||
sin_freq = np.swapaxes(sin_freq, 1, 2)
|
sin_freq = mx.swapaxes(sin_freq, 1, 2)
|
||||||
else:
|
else:
|
||||||
# Interleaved
|
# Interleaved
|
||||||
cos_freq = np.repeat(cos_freq, 2, axis=-1)
|
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
|
||||||
sin_freq = np.repeat(sin_freq, 2, axis=-1)
|
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
|
||||||
|
|
||||||
pad_size = dim % n_elem
|
pad_size = dim % n_elem
|
||||||
if pad_size > 0:
|
if pad_size > 0:
|
||||||
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||||
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||||
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
# Convert back to MLX (float32 for GPU compatibility)
|
|
||||||
cos_freq = mx.array(cos_freq.astype(np.float32))
|
|
||||||
sin_freq = mx.array(sin_freq.astype(np.float32))
|
|
||||||
|
|
||||||
return cos_freq, sin_freq
|
return cos_freq, sin_freq
|
||||||
|
|||||||
@@ -12,12 +12,14 @@ from mlx_video.utils import rms_norm
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Modality:
|
class Modality:
|
||||||
latent: mx.array
|
latent: mx.array
|
||||||
timesteps: mx.array
|
timesteps: mx.array
|
||||||
positions: mx.array
|
positions: mx.array
|
||||||
context: mx.array
|
context: mx.array
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
context_mask: Optional[mx.array] = None
|
context_mask: Optional[mx.array] = None
|
||||||
|
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
|
||||||
|
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user