Fix saving tensors: use mx.save_safetensors to preserve bfloat16
This commit is contained in:
@@ -595,18 +595,15 @@ def convert(
|
|||||||
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||||
"""Save weights in safetensors format.
|
"""Save weights in safetensors format.
|
||||||
|
|
||||||
|
Uses mx.save_safetensors to preserve exact dtype (especially bfloat16).
|
||||||
|
Converting through numpy loses bfloat16 fidelity since numpy lacks native
|
||||||
|
bfloat16 support.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: Output directory
|
path: Output directory
|
||||||
weights: Dictionary of weights
|
weights: Dictionary of weights
|
||||||
"""
|
"""
|
||||||
from safetensors.numpy import save_file
|
mx.save_safetensors(str(path / "model.safetensors"), weights)
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Convert to numpy for safetensors
|
|
||||||
np_weights = {k: np.array(v) for k, v in weights.items()}
|
|
||||||
|
|
||||||
# Save to file
|
|
||||||
save_file(np_weights, path / "model.safetensors")
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
|||||||
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
|
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
|
||||||
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
|
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
|
||||||
|
|
||||||
B, T, F = cos_ref.shape
|
B, T, _ = cos_ref.shape
|
||||||
dim_per_head = dim // num_heads
|
dim_per_head = dim // num_heads
|
||||||
cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
|
cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
|
||||||
sin_ref = sin_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
|
sin_ref = sin_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
|
||||||
|
|||||||
Reference in New Issue
Block a user