From f5134fa172b93e2fba5f8126d7b4980da845391c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 15 Jan 2026 12:49:21 +0100 Subject: [PATCH] adjust gelu and precision --- mlx_video/models/ltx/rope.py | 8 +++++++- mlx_video/models/ltx/text_encoder.py | 5 ++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index 65fc82d..54b721a 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -49,6 +49,12 @@ def apply_interleaved_rotary_emb( Returns: Tensor with interleaved rotary embeddings applied """ + # Compute in float32 for better precision + input_dtype = input_tensor.dtype + input_tensor = input_tensor.astype(mx.float32) + cos_freqs = cos_freqs.astype(mx.float32) + sin_freqs = sin_freqs.astype(mx.float32) + # Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2) shape = input_tensor.shape input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2)) @@ -67,7 +73,7 @@ def apply_interleaved_rotary_emb( # Apply rotary embeddings out = input_tensor * cos_freqs + t_rot * sin_freqs - return out + return out.astype(input_dtype) def rotate_half_interleaved(x: mx.array) -> mx.array: diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 8837f8c..cbcbcb8 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -227,7 +227,7 @@ class GEGLU(nn.Module): self.proj = nn.Linear(in_dim, out_dim, bias=True) def __call__(self, x: mx.array) -> mx.array: - return nn.gelu_approx(self.proj(x)) + return nn.gelu(self.proj(x)) class ConnectorFeedForward(nn.Module): @@ -308,7 +308,6 @@ class Embeddings1DConnector(nn.Module): def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]: """Compute RoPE frequencies for connector (INTERLEAVED type). - Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim). """ @@ -327,7 +326,7 @@ class Embeddings1DConnector(nn.Module): log_start = np.log(start) / np.log(theta) # = 0 log_end = np.log(end) / np.log(theta) # = 1 lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64) - indices = np.power(theta, lin_space) * (np.pi / 2) + indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32) # Generate positions and compute freqs (matches generate_freqs) positions = np.arange(seq_len, dtype=np.float64)