Enhance precision in denormalization and normalization processes
- Updated `denormalize` and `pixel_norm` methods in `LTX2VideoDecoder` and `PerChannelStatistics` classes to cast mean and standard deviation to float32 for improved precision. - Ensured that the output of normalization operations retains the original data type of the input tensor.
This commit is contained in:
@@ -347,8 +347,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.reshape(1, -1, 1, 1, 1)
|
||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
return x * std + mean
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user