From 8a2ea38c886bf94d4fa054a966642bce73eb553c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 19 Jan 2026 09:13:04 +0100 Subject: [PATCH] Refactor denoising functions in generate.py and utils.py to use float32 for improved precision, aligning with PyTorch behavior. Update calculations for latents and denoised outputs to ensure consistent dtype handling across audio and video processing. --- mlx_video/generate.py | 37 ++++++++++++++++++++++++++----------- mlx_video/utils.py | 26 +++++++++++++++++--------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index f43d5e6..1d716ee 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -330,11 +330,16 @@ def denoise_distilled( mx.eval(audio_denoised) if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr + # Compute Euler step in float32 for precision (matching PyTorch behavior) + latents_f32 = latents.astype(mx.float32) + denoised_f32 = denoised.astype(mx.float32) + sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) + sigma_f32 = mx.array(sigma, dtype=mx.float32) + latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) if enable_audio and audio_denoised is not None: - audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr + audio_latents_f32 = audio_latents.astype(mx.float32) + audio_denoised_f32 = audio_denoised.astype(mx.float32) + audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype) else: latents = denoised if enable_audio and audio_denoised is not None: @@ -452,9 +457,12 @@ def denoise_dev( denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr + # Compute Euler step in float32 for precision (matching PyTorch behavior) + latents_f32 = latents.astype(mx.float32) + denoised_f32 = denoised.astype(mx.float32) + sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) + sigma_f32 = mx.array(sigma, dtype=mx.float32) + latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) else: latents = denoised @@ -599,10 +607,17 @@ def denoise_dev_av( # Euler step if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr - audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr + # Compute Euler step in float32 for precision (matching PyTorch behavior) + sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) + sigma_f32 = mx.array(sigma, dtype=mx.float32) + + video_latents_f32 = video_latents.astype(mx.float32) + video_denoised_f32 = video_denoised.astype(mx.float32) + video_latents = (video_denoised_f32 + sigma_next_f32 * (video_latents_f32 - video_denoised_f32) / sigma_f32).astype(dtype) + + audio_latents_f32 = audio_latents.astype(mx.float32) + audio_denoised_f32 = audio_denoised.astype(mx.float32) + audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype) else: video_latents = video_denoised audio_latents = audio_denoised diff --git a/mlx_video/utils.py b/mlx_video/utils.py index cebbed7..2a6eefe 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple, Union +from typing import Optional, Union import mlx.core as mx import mlx.nn as nn @@ -61,6 +61,9 @@ def to_denoised( Given noisy input x_t and velocity prediction v, compute denoised x_0: x_0 = x_t - sigma * v + Uses float32 for computation precision (matching PyTorch behavior), + then converts back to input dtype. + Args: noisy: Noisy input tensor x_t velocity: Velocity prediction v @@ -69,16 +72,21 @@ def to_denoised( Returns: Denoised tensor x_0 """ + original_dtype = noisy.dtype + + # Cast to float32 for precision (PyTorch uses calc_dtype=torch.float32) + noisy_f32 = noisy.astype(mx.float32) + velocity_f32 = velocity.astype(mx.float32) + if isinstance(sigma, (int, float)): - # Convert to array with matching dtype to avoid float32 promotion - sigma_arr = mx.array(sigma, dtype=velocity.dtype) - return noisy - sigma_arr * velocity + sigma_f32 = mx.array(sigma, dtype=mx.float32) else: - # sigma is per-sample - ensure dtype matches - sigma = sigma.astype(velocity.dtype) - while sigma.ndim < velocity.ndim: - sigma = mx.expand_dims(sigma, axis=-1) - return noisy - sigma * velocity + sigma_f32 = sigma.astype(mx.float32) + while sigma_f32.ndim < velocity_f32.ndim: + sigma_f32 = mx.expand_dims(sigma_f32, axis=-1) + + result = noisy_f32 - sigma_f32 * velocity_f32 + return result.astype(original_dtype) def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array: