From 349a82f763851a690f870326a11e979347613f70 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 15 Jan 2026 04:46:56 +0100 Subject: [PATCH] Refactor GroupNorm3d: Optimize data type handling by casting input, weight, and bias to float32 for consistency and performance --- mlx_video/models/ltx/upsampler.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mlx_video/models/ltx/upsampler.py b/mlx_video/models/ltx/upsampler.py index 05cb4e8..7f43536 100644 --- a/mlx_video/models/ltx/upsampler.py +++ b/mlx_video/models/ltx/upsampler.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -85,6 +85,10 @@ class GroupNorm3d(nn.Module): def __call__(self, x: mx.array) -> mx.array: # x: (N, D, H, W, C) n, d, h, w, c = x.shape + input_dtype = x.dtype + + + x = x.astype(mx.float32) # Reshape to (N, D*H*W, num_groups, C//num_groups) x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups)) @@ -100,7 +104,12 @@ class GroupNorm3d(nn.Module): x = mx.reshape(x, (n, d, h, w, c)) # Apply weight and bias - x = x * self.weight + self.bias + weight = self.weight.astype(mx.float32) + bias = self.bias.astype(mx.float32) + x = x * weight + bias + + # Convert back to input dtype + x = x.astype(input_dtype) return x