Enhance precision in denormalization and normalization processes
- Updated `denormalize` and `pixel_norm` methods in `LTX2VideoDecoder` and `PerChannelStatistics` classes to cast mean and standard deviation to float32 for improved precision.
- Ensured that the output of normalization operations retains the original data type of the input tensor.
This commit is contained in:
@@ -347,8 +347,9 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
def denormalize(self, x: mx.array) -> mx.array:
|
def denormalize(self, x: mx.array) -> mx.array:
|
||||||
"""Denormalize latents using per-channel statistics."""
|
"""Denormalize latents using per-channel statistics."""
|
||||||
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
|
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||||
std = self.latents_std.reshape(1, -1, 1, 1, 1)
|
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
|
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
return x * std + mean
|
return x * std + mean
|
||||||
|
|
||||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||||
|
|||||||
@@ -101,10 +101,12 @@ class PerChannelStatistics(nn.Module):
|
|||||||
Normalized tensor
|
Normalized tensor
|
||||||
"""
|
"""
|
||||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
dtype = x.dtype
|
||||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
# Cast to float32 for precision
|
||||||
|
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
|
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
|
|
||||||
return (x - mean) / std
|
return ((x - mean) / std).astype(dtype)
|
||||||
|
|
||||||
def un_normalize(self, x: mx.array) -> mx.array:
|
def un_normalize(self, x: mx.array) -> mx.array:
|
||||||
"""Denormalize latents using per-channel statistics.
|
"""Denormalize latents using per-channel statistics.
|
||||||
@@ -115,7 +117,9 @@ class PerChannelStatistics(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Denormalized tensor
|
Denormalized tensor
|
||||||
"""
|
"""
|
||||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
dtype = x.dtype
|
||||||
std = self.std.reshape(1, -1, 1, 1, 1)
|
# Cast to float32 for precision
|
||||||
|
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
|
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
|
|
||||||
return x * std + mean
|
return (x * std + mean).astype(dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user