From d52e567c5692470149fbd05fd76e346d43dff354 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 01:14:29 +0100 Subject: [PATCH] Enhance precision in denormalization and normalization processes - Updated `denormalize` and `pixel_norm` methods in `LTX2VideoDecoder` and `PerChannelStatistics` classes to cast mean and standard deviation to float32 for improved precision. - Ensured that the output of normalization operations retains the original data type of the input tensor. --- mlx_video/models/ltx/video_vae/decoder.py | 5 +++-- mlx_video/models/ltx/video_vae/ops.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index ab22374..1bc0983 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -347,8 +347,9 @@ class LTX2VideoDecoder(nn.Module): def denormalize(self, x: mx.array) -> mx.array: """Denormalize latents using per-channel statistics.""" - mean = self.latents_mean.reshape(1, -1, 1, 1, 1) - std = self.latents_std.reshape(1, -1, 1, 1, 1) + # Cast to float32 for precision (statistics may be in bfloat16) + mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) + std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1) return x * std + mean def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: diff --git a/mlx_video/models/ltx/video_vae/ops.py b/mlx_video/models/ltx/video_vae/ops.py index ca0457e..d730d2f 100644 --- a/mlx_video/models/ltx/video_vae/ops.py +++ b/mlx_video/models/ltx/video_vae/ops.py @@ -101,10 +101,12 @@ class PerChannelStatistics(nn.Module): Normalized tensor """ # Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1) - mean = self.mean.reshape(1, -1, 1, 1, 1) - std = self.std.reshape(1, -1, 1, 1, 1) + dtype = x.dtype + # Cast to float32 for precision + mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) + std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1) - return (x - mean) / std + return ((x - mean) / std).astype(dtype) def un_normalize(self, x: mx.array) -> mx.array: """Denormalize latents using per-channel statistics. @@ -115,7 +117,9 @@ class PerChannelStatistics(nn.Module): Returns: Denormalized tensor """ - mean = self.mean.reshape(1, -1, 1, 1, 1) - std = self.std.reshape(1, -1, 1, 1, 1) + dtype = x.dtype + # Cast to float32 for precision + mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) + std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1) - return x * std + mean + return (x * std + mean).astype(dtype)