Refactor GroupNorm3d: cast input, weight, and bias to float32 during computation and convert the result back to the input dtype, for consistency and performance
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
@@ -85,6 +85,10 @@ class GroupNorm3d(nn.Module):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, D, H, W, C)
|
||||
n, d, h, w, c = x.shape
|
||||
input_dtype = x.dtype
|
||||
|
||||
|
||||
x = x.astype(mx.float32)
|
||||
|
||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
||||
@@ -100,7 +104,12 @@ class GroupNorm3d(nn.Module):
|
||||
x = mx.reshape(x, (n, d, h, w, c))
|
||||
|
||||
# Apply weight and bias
|
||||
x = x * self.weight + self.bias
|
||||
weight = self.weight.astype(mx.float32)
|
||||
bias = self.bias.astype(mx.float32)
|
||||
x = x * weight + bias
|
||||
|
||||
# Convert back to input dtype
|
||||
x = x.astype(input_dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user