fix save tensors
This commit is contained in:
@@ -595,18 +595,15 @@ def convert(
|
||||
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||
"""Save weights in safetensors format.
|
||||
|
||||
Uses mx.save_safetensors to preserve exact dtype (especially bfloat16).
|
||||
Converting through numpy loses bfloat16 fidelity since numpy lacks native
|
||||
bfloat16 support.
|
||||
|
||||
Args:
|
||||
path: Output directory
|
||||
weights: Dictionary of weights
|
||||
"""
|
||||
from safetensors.numpy import save_file
|
||||
import numpy as np
|
||||
|
||||
# Convert to numpy for safetensors
|
||||
np_weights = {k: np.array(v) for k, v in weights.items()}
|
||||
|
||||
# Save to file
|
||||
save_file(np_weights, path / "model.safetensors")
|
||||
mx.save_safetensors(str(path / "model.safetensors"), weights)
|
||||
|
||||
|
||||
def load_model(
|
||||
|
||||
Reference in New Issue
Block a user