fix save tensors

This commit is contained in:
Prince Canuma
2026-03-15 23:08:12 +01:00
parent 38d46a6eda
commit df81bc852f
2 changed files with 6 additions and 9 deletions

View File

@@ -87,7 +87,7 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
B, T, F = cos_ref.shape
B, T, _ = cos_ref.shape
dim_per_head = dim // num_heads
cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
sin_ref = sin_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)