Enhance precision in denormalization and normalization processes

- Updated `denormalize` and `pixel_norm` methods in `LTX2VideoDecoder` and `PerChannelStatistics` classes to cast mean and standard deviation to float32 for improved precision.
- Ensured that the output of normalization operations retains the original data type of the input tensor.
This commit is contained in:
Prince Canuma
2026-01-17 01:14:29 +01:00
parent 146f5d2981
commit d52e567c56
2 changed files with 13 additions and 8 deletions

View File

@@ -347,8 +347,9 @@ class LTX2VideoDecoder(nn.Module):
def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics."""
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
std = self.latents_std.reshape(1, -1, 1, 1, 1)
# Cast to float32 for precision (statistics may be in bfloat16)
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return x * std + mean
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:

View File

@@ -101,10 +101,12 @@ class PerChannelStatistics(nn.Module):
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return (x - mean) / std
return ((x - mean) / std).astype(dtype)
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics.
@@ -115,7 +117,9 @@ class PerChannelStatistics(nn.Module):
Returns:
Denormalized tensor
"""
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return x * std + mean
return (x * std + mean).astype(dtype)