Refactor GroupNorm3d: cast input, weight, and bias to float32 for consistency and performance, then convert the output back to the input dtype
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Tuple, Union
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
@@ -85,6 +85,10 @@ class GroupNorm3d(nn.Module):
|
|||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
# x: (N, D, H, W, C)
|
# x: (N, D, H, W, C)
|
||||||
n, d, h, w, c = x.shape
|
n, d, h, w, c = x.shape
|
||||||
|
input_dtype = x.dtype
|
||||||
|
|
||||||
|
|
||||||
|
x = x.astype(mx.float32)
|
||||||
|
|
||||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||||
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
||||||
@@ -100,7 +104,12 @@ class GroupNorm3d(nn.Module):
|
|||||||
x = mx.reshape(x, (n, d, h, w, c))
|
x = mx.reshape(x, (n, d, h, w, c))
|
||||||
|
|
||||||
# Apply weight and bias
|
# Apply weight and bias
|
||||||
x = x * self.weight + self.bias
|
weight = self.weight.astype(mx.float32)
|
||||||
|
bias = self.bias.astype(mx.float32)
|
||||||
|
x = x * weight + bias
|
||||||
|
|
||||||
|
# Convert back to input dtype
|
||||||
|
x = x.astype(input_dtype)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user