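Ensure dtype casts: apply split rotary embeddings in float32 and cast the output back to the input dtype; cast the indices grid to float32 before converting it to NumPy float64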
This commit is contained in:
@@ -128,6 +128,7 @@ def apply_split_rotary_emb(
|
||||
Returns:
|
||||
Tensor with split rotary embeddings applied
|
||||
"""
|
||||
input_dtype = input_tensor.dtype
|
||||
needs_reshape = False
|
||||
original_shape = input_tensor.shape
|
||||
|
||||
@@ -139,6 +140,11 @@ def apply_split_rotary_emb(
|
||||
input_tensor = mx.swapaxes(input_tensor, 1, 2)
|
||||
needs_reshape = True
|
||||
|
||||
# Cast to float32 for computation precision
|
||||
input_tensor = input_tensor.astype(mx.float32)
|
||||
cos_freqs = cos_freqs.astype(mx.float32)
|
||||
sin_freqs = sin_freqs.astype(mx.float32)
|
||||
|
||||
# Split into two halves: (..., dim) -> (..., 2, dim//2)
|
||||
dim = input_tensor.shape[-1]
|
||||
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
|
||||
@@ -167,7 +173,7 @@ def apply_split_rotary_emb(
|
||||
output = mx.swapaxes(output, 1, 2)
|
||||
output = mx.reshape(output, (b, t, h * d))
|
||||
|
||||
return output
|
||||
return output.astype(input_dtype)
|
||||
|
||||
|
||||
def generate_freq_grid(
|
||||
@@ -424,8 +430,8 @@ def _precompute_freqs_cis_double_precision(
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
# Convert to numpy float64
|
||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
||||
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||
|
||||
# Generate frequency indices in float64
|
||||
n_pos_dims = indices_grid_np.shape[1]
|
||||
|
||||
Reference in New Issue
Block a user