Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -428,11 +428,14 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float32.
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
Uses float32 for computation precision (sufficient for RoPE).
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
computation (log-spaced values), then converts to float32 for the final tensor.
This provides better numerical precision in the frequency generation phase.
"""
import numpy as np
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
@@ -443,21 +446,27 @@ def _precompute_freqs_cis_double_precision(
stacklevel=2
)
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
# Cast to float32 for position computation
indices_grid_f32 = indices_grid.astype(mx.float32)
n_pos_dims = indices_grid_f32.shape[1]
n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies in float32
log_start = math.log(1.0) / math.log(theta)
log_end = math.log(theta) / math.log(theta)
# Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
# This is the critical precision step - PyTorch uses np.float64 here
log_start = np.log(1.0) / np.log(theta)
log_end = np.log(theta) / np.log(theta) # = 1.0
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = mx.linspace(log_start, log_end, num_indices)
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
# Use numpy float64 for the linspace computation (matches PyTorch)
pow_indices = np.power(
theta,
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
)
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
# Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise