ensure dtype cast

This commit is contained in:
Prince Canuma
2026-01-17 13:03:48 +01:00
parent e4cdbb7eab
commit 883c6b0ad8
6 changed files with 52 additions and 32 deletions

View File

@@ -128,6 +128,7 @@ def apply_split_rotary_emb(
Returns:
Tensor with split rotary embeddings applied
"""
input_dtype = input_tensor.dtype
needs_reshape = False
original_shape = input_tensor.shape
@@ -139,6 +140,11 @@ def apply_split_rotary_emb(
input_tensor = mx.swapaxes(input_tensor, 1, 2)
needs_reshape = True
# Cast to float32 for computation precision
input_tensor = input_tensor.astype(mx.float32)
cos_freqs = cos_freqs.astype(mx.float32)
sin_freqs = sin_freqs.astype(mx.float32)
# Split into two halves: (..., dim) -> (..., 2, dim//2)
dim = input_tensor.shape[-1]
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
@@ -167,7 +173,7 @@ def apply_split_rotary_emb(
output = mx.swapaxes(output, 1, 2)
output = mx.reshape(output, (b, t, h * d))
return output
return output.astype(input_dtype)
def generate_freq_grid(
@@ -424,8 +430,8 @@ def _precompute_freqs_cis_double_precision(
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
# Convert to numpy float64
indices_grid_np = np.array(indices_grid).astype(np.float64)
# Convert to numpy float64 (first to float32 for numpy compatibility)
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# Generate frequency indices in float64
n_pos_dims = indices_grid_np.shape[1]