fix save tensors

This commit is contained in:
Prince Canuma
2026-03-15 23:08:12 +01:00
parent 38d46a6eda
commit df81bc852f
2 changed files with 6 additions and 9 deletions

View File

@@ -595,18 +595,15 @@ def convert(
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
"""Save weights in safetensors format.
Uses mx.save_safetensors to preserve exact dtype (especially bfloat16).
Converting through numpy loses bfloat16 fidelity since numpy lacks native
bfloat16 support.
Args:
path: Output directory
weights: Dictionary of weights
"""
from safetensors.numpy import save_file
import numpy as np
# Convert to numpy for safetensors
np_weights = {k: np.array(v) for k, v in weights.items()}
# Save to file
save_file(np_weights, path / "model.safetensors")
mx.save_safetensors(str(path / "model.safetensors"), weights)
def load_model(