Use float32 per-channel statistics when (de)normalizing video VAE latents.

* video_vae/decoder.py, LTX2VideoDecoder.denormalize: latents_mean and
  latents_std are cast to float32 before being reshaped for broadcasting.
* video_vae/ops.py, PerChannelStatistics.normalize / un_normalize: mean and
  std are cast to float32, and the result is cast back to the dtype of x.

NOTE(review): unlike the two ops.py methods, decoder.denormalize does not
cast its result back to x.dtype -- confirm this asymmetry is intended.
NOTE(review): leading indentation below was reconstructed (4 spaces per
level); line counts agree with the hunk headers -- verify against the tree.

diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py
index ab22374..1bc0983 100644
--- a/mlx_video/models/ltx/video_vae/decoder.py
+++ b/mlx_video/models/ltx/video_vae/decoder.py
@@ -347,8 +347,9 @@ class LTX2VideoDecoder(nn.Module):
 
     def denormalize(self, x: mx.array) -> mx.array:
         """Denormalize latents using per-channel statistics."""
-        mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
-        std = self.latents_std.reshape(1, -1, 1, 1, 1)
+        # Cast to float32 for precision (statistics may be in bfloat16)
+        mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
+        std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
         return x * std + mean
 
     def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
diff --git a/mlx_video/models/ltx/video_vae/ops.py b/mlx_video/models/ltx/video_vae/ops.py
index ca0457e..d730d2f 100644
--- a/mlx_video/models/ltx/video_vae/ops.py
+++ b/mlx_video/models/ltx/video_vae/ops.py
@@ -101,10 +101,12 @@ class PerChannelStatistics(nn.Module):
             Normalized tensor
         """
         # Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
-        mean = self.mean.reshape(1, -1, 1, 1, 1)
-        std = self.std.reshape(1, -1, 1, 1, 1)
+        dtype = x.dtype
+        # Cast to float32 for precision
+        mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
+        std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
 
-        return (x - mean) / std
+        return ((x - mean) / std).astype(dtype)
 
     def un_normalize(self, x: mx.array) -> mx.array:
         """Denormalize latents using per-channel statistics.
@@ -115,7 +117,9 @@ class PerChannelStatistics(nn.Module):
         Returns:
             Denormalized tensor
         """
-        mean = self.mean.reshape(1, -1, 1, 1, 1)
-        std = self.std.reshape(1, -1, 1, 1, 1)
+        dtype = x.dtype
+        # Cast to float32 for precision
+        mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
+        std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
 
-        return x * std + mean
+        return (x * std + mean).astype(dtype)