ensure dtype cast
This commit is contained in:
@@ -44,10 +44,9 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||
class_predicate=get_class_predicate,
|
||||
)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
|
||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
|
||||
|
||||
|
||||
|
||||
@@ -71,9 +70,12 @@ def to_denoised(
|
||||
Denoised tensor x_0
|
||||
"""
|
||||
if isinstance(sigma, (int, float)):
|
||||
return noisy - sigma * velocity
|
||||
# Convert to array with matching dtype to avoid float32 promotion
|
||||
sigma_arr = mx.array(sigma, dtype=velocity.dtype)
|
||||
return noisy - sigma_arr * velocity
|
||||
else:
|
||||
# sigma is per-sample
|
||||
# sigma is per-sample - ensure dtype matches
|
||||
sigma = sigma.astype(velocity.dtype)
|
||||
while sigma.ndim < velocity.ndim:
|
||||
sigma = mx.expand_dims(sigma, axis=-1)
|
||||
return noisy - sigma * velocity
|
||||
@@ -251,6 +253,7 @@ def prepare_image_for_encoding(
|
||||
image: mx.array,
|
||||
target_height: int,
|
||||
target_width: int,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> mx.array:
|
||||
"""Prepare image for VAE encoding by resizing and normalizing.
|
||||
|
||||
@@ -281,4 +284,4 @@ def prepare_image_for_encoding(
|
||||
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
|
||||
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
|
||||
|
||||
return image
|
||||
return image.astype(dtype)
|
||||
|
||||
Reference in New Issue
Block a user