ensure dtype cast

This commit is contained in:
Prince Canuma
2026-01-17 13:03:48 +01:00
parent e4cdbb7eab
commit 883c6b0ad8
6 changed files with 52 additions and 32 deletions

View File

@@ -44,10 +44,9 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
class_predicate=get_class_predicate,
)
@partial(mx.compile, shapeless=True)
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
@@ -71,9 +70,12 @@ def to_denoised(
Denoised tensor x_0
"""
if isinstance(sigma, (int, float)):
return noisy - sigma * velocity
# Convert to array with matching dtype to avoid float32 promotion
sigma_arr = mx.array(sigma, dtype=velocity.dtype)
return noisy - sigma_arr * velocity
else:
# sigma is per-sample
# sigma is per-sample - ensure dtype matches
sigma = sigma.astype(velocity.dtype)
while sigma.ndim < velocity.ndim:
sigma = mx.expand_dims(sigma, axis=-1)
return noisy - sigma * velocity
@@ -251,6 +253,7 @@ def prepare_image_for_encoding(
image: mx.array,
target_height: int,
target_width: int,
dtype: mx.Dtype = mx.float32,
) -> mx.array:
"""Prepare image for VAE encoding by resizing and normalizing.
@@ -281,4 +284,4 @@ def prepare_image_for_encoding(
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
return image
return image.astype(dtype)