diff --git a/mlx_video/convert.py b/mlx_video/convert.py
index cbefd68..de9f01d 100644
--- a/mlx_video/convert.py
+++ b/mlx_video/convert.py
@@ -595,18 +595,15 @@ def convert(
 def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
     """Save weights in safetensors format.
 
+    Uses mx.save_safetensors to preserve exact dtype (especially bfloat16).
+    Converting through numpy loses bfloat16 fidelity since numpy lacks native
+    bfloat16 support.
+
     Args:
         path: Output directory
         weights: Dictionary of weights
     """
-    from safetensors.numpy import save_file
-    import numpy as np
-
-    # Convert to numpy for safetensors
-    np_weights = {k: np.array(v) for k, v in weights.items()}
-
-    # Save to file
-    save_file(np_weights, path / "model.safetensors")
+    mx.save_safetensors(str(path / "model.safetensors"), weights)
 
 
 def load_model(
diff --git a/tests/test_rope.py b/tests/test_rope.py
index f64a0c2..7406cf2 100644
--- a/tests/test_rope.py
+++ b/tests/test_rope.py
@@ -87,7 +87,7 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
     cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
     sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
 
-    B, T, F = cos_ref.shape
+    B, T, _ = cos_ref.shape
     dim_per_head = dim // num_heads
     cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
     sin_ref = sin_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)