Enhance precision in denormalization and normalization processes
- Updated `denormalize` and `pixel_norm` methods in `LTX2VideoDecoder` and `PerChannelStatistics` classes to cast mean and standard deviation to float32 for improved precision. - Ensured that the output of normalization operations retains the original data type of the input tensor.
This commit is contained in:
@@ -347,8 +347,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.reshape(1, -1, 1, 1, 1)
|
||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
return x * std + mean
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
|
||||
@@ -101,10 +101,12 @@ class PerChannelStatistics(nn.Module):
|
||||
Normalized tensor
|
||||
"""
|
||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return (x - mean) / std
|
||||
return ((x - mean) / std).astype(dtype)
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics.
|
||||
@@ -115,7 +117,9 @@ class PerChannelStatistics(nn.Module):
|
||||
Returns:
|
||||
Denormalized tensor
|
||||
"""
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return x * std + mean
|
||||
return (x * std + mean).astype(dtype)
|
||||
|
||||
Reference in New Issue
Block a user