Enhance precision in denormalization and normalization processes
- Updated `denormalize` and `pixel_norm` methods in `LTX2VideoDecoder` and `PerChannelStatistics` classes to cast mean and standard deviation to float32 for improved precision. - Ensured that the output of normalization operations retains the original data type of the input tensor.
This commit is contained in:
@@ -101,10 +101,12 @@ class PerChannelStatistics(nn.Module):
|
||||
Normalized tensor
|
||||
"""
|
||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return (x - mean) / std
|
||||
return ((x - mean) / std).astype(dtype)
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics.
|
||||
@@ -115,7 +117,9 @@ class PerChannelStatistics(nn.Module):
|
||||
Returns:
|
||||
Denormalized tensor
|
||||
"""
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return x * std + mean
|
||||
return (x * std + mean).astype(dtype)
|
||||
|
||||
Reference in New Issue
Block a user