diff --git a/mlx_video/models/ltx/upsampler.py b/mlx_video/models/ltx/upsampler.py index 05cb4e8..7f43536 100644 --- a/mlx_video/models/ltx/upsampler.py +++ b/mlx_video/models/ltx/upsampler.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -85,6 +85,10 @@ class GroupNorm3d(nn.Module): def __call__(self, x: mx.array) -> mx.array: # x: (N, D, H, W, C) n, d, h, w, c = x.shape + input_dtype = x.dtype + + + x = x.astype(mx.float32) # Reshape to (N, D*H*W, num_groups, C//num_groups) x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups)) @@ -100,7 +104,12 @@ class GroupNorm3d(nn.Module): x = mx.reshape(x, (n, d, h, w, c)) # Apply weight and bias - x = x * self.weight + self.bias + weight = self.weight.astype(mx.float32) + bias = self.bias.astype(mx.float32) + x = x * weight + bias + + # Convert back to input dtype + x = x.astype(input_dtype) return x