From 957093c29b0a5f6ccee8e05d4b5211a30b5a188b Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 14 Jan 2026 00:03:00 +0100 Subject: [PATCH] use numpy for improved float64 precision and performance --- mlx_video/models/ltx/text_encoder.py | 29 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 168b7b9..8837f8c 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -311,25 +311,26 @@ class Embeddings1DConnector(nn.Module): Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim). """ - import math + + import numpy as np dim = self.num_heads * self.head_dim # inner_dim = 3840 theta = self.positional_embedding_theta max_pos = [1] # Default for connector n_elem = 2 * len(max_pos) # = 2 - # Generate frequency indices (matches generate_freq_grid_pytorch) start = 1.0 end = theta num_indices = dim // n_elem # 1920 - log_start = math.log(start) / math.log(theta) # = 0 - log_end = math.log(end) / math.log(theta) # = 1 - lin_space = mx.linspace(log_start, log_end, num_indices) - indices = (theta ** lin_space) * (math.pi / 2) + # Use numpy float64 for precision + log_start = np.log(start) / np.log(theta) # = 0 + log_end = np.log(end) / np.log(theta) # = 1 + lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64) + indices = np.power(theta, lin_space) * (np.pi / 2) # Generate positions and compute freqs (matches generate_freqs) - positions = mx.arange(seq_len).astype(mx.float32) + positions = np.arange(seq_len, dtype=np.float64) # fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1) # scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1 scaled_positions = positions * 2 - 1 # Shape: (seq_len,) @@ -339,17 +340,17 @@ class Embeddings1DConnector(nn.Module): freqs = scaled_positions[:, None] * indices[None, :] # Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis) - cos_freq = mx.cos(freqs) - sin_freq = mx.sin(freqs) + cos_freq = np.cos(freqs) + sin_freq = np.sin(freqs) # repeat_interleave: (seq_len, num_indices) -> (seq_len, dim) # Pattern: [c0, c0, c1, c1, c2, c2, ...] - cos_full = mx.repeat(cos_freq, 2, axis=-1) - sin_full = mx.repeat(sin_freq, 2, axis=-1) + cos_full = np.repeat(cos_freq, 2, axis=-1) + sin_full = np.repeat(sin_freq, 2, axis=-1) - # Add batch dimension: (1, seq_len, dim) - cos_full = cos_full[None, :, :] - sin_full = sin_full[None, :, :] + # Add batch dimension and convert to MLX: (1, seq_len, dim) + cos_full = mx.array(cos_full[None, :, :].astype(np.float32)) + sin_full = mx.array(sin_full[None, :, :].astype(np.float32)) return cos_full.astype(dtype), sin_full.astype(dtype)