Refactor GroupNorm3d: Optimize data type handling by casting input, weight, and bias to float32 for consistency and performance

This commit is contained in:
Prince Canuma
2026-01-15 04:46:56 +01:00
parent 09c2b460a7
commit 349a82f763

View File

@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union from typing import Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -85,6 +85,10 @@ class GroupNorm3d(nn.Module):
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C) # x: (N, D, H, W, C)
n, d, h, w, c = x.shape n, d, h, w, c = x.shape
input_dtype = x.dtype
x = x.astype(mx.float32)
# Reshape to (N, D*H*W, num_groups, C//num_groups) # Reshape to (N, D*H*W, num_groups, C//num_groups)
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups)) x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
@@ -100,7 +104,12 @@ class GroupNorm3d(nn.Module):
x = mx.reshape(x, (n, d, h, w, c)) x = mx.reshape(x, (n, d, h, w, c))
# Apply weight and bias # Apply weight and bias
x = x * self.weight + self.bias weight = self.weight.astype(mx.float32)
bias = self.bias.astype(mx.float32)
x = x * weight + bias
# Convert back to input dtype
x = x.astype(input_dtype)
return x return x