Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -428,11 +428,14 @@ def _precompute_freqs_cis_double_precision(
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies with higher precision using float32.
|
||||
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
|
||||
|
||||
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
|
||||
Uses float32 for computation precision (sufficient for RoPE).
|
||||
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
|
||||
computation (log-spaced values), then converts to float32 for the final tensor.
|
||||
This provides better numerical precision in the frequency generation phase.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
@@ -443,21 +446,27 @@ def _precompute_freqs_cis_double_precision(
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
|
||||
# Cast to float32 for position computation
|
||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||
|
||||
n_pos_dims = indices_grid_f32.shape[1]
|
||||
n_elem = 2 * n_pos_dims
|
||||
|
||||
# Compute log-spaced frequencies in float32
|
||||
log_start = math.log(1.0) / math.log(theta)
|
||||
log_end = math.log(theta) / math.log(theta)
|
||||
# Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
|
||||
# This is the critical precision step - PyTorch uses np.float64 here
|
||||
log_start = np.log(1.0) / np.log(theta)
|
||||
log_end = np.log(theta) / np.log(theta) # = 1.0
|
||||
num_indices = dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
|
||||
# Use numpy float64 for the linspace computation (matches PyTorch)
|
||||
pow_indices = np.power(
|
||||
theta,
|
||||
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
|
||||
)
|
||||
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
|
||||
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
|
||||
|
||||
# Handle middle indices grid
|
||||
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||
|
||||
Reference in New Issue
Block a user