Compute the denoising Euler step in generate.py and `to_denoised` in utils.py in float32 for improved precision, matching PyTorch behavior. Latents and denoised outputs (both video and audio) are upcast to float32 for the arithmetic and cast back to the working dtype afterwards, so the dtype seen by callers is unchanged.

This commit is contained in:
Prince Canuma
2026-01-19 09:13:04 +01:00
parent e0ee934b99
commit 8a2ea38c88
2 changed files with 43 additions and 20 deletions

View File

@@ -330,11 +330,16 @@ def denoise_distilled(
mx.eval(audio_denoised) mx.eval(audio_denoised)
if sigma_next > 0: if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype) # Compute Euler step in float32 for precision (matching PyTorch behavior)
sigma_arr = mx.array(sigma, dtype=dtype) latents_f32 = latents.astype(mx.float32)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr denoised_f32 = denoised.astype(mx.float32)
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
if enable_audio and audio_denoised is not None: if enable_audio and audio_denoised is not None:
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr audio_latents_f32 = audio_latents.astype(mx.float32)
audio_denoised_f32 = audio_denoised.astype(mx.float32)
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
else: else:
latents = denoised latents = denoised
if enable_audio and audio_denoised is not None: if enable_audio and audio_denoised is not None:
@@ -452,9 +457,12 @@ def denoise_dev(
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
if sigma_next > 0: if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype) # Compute Euler step in float32 for precision (matching PyTorch behavior)
sigma_arr = mx.array(sigma, dtype=dtype) latents_f32 = latents.astype(mx.float32)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr denoised_f32 = denoised.astype(mx.float32)
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
else: else:
latents = denoised latents = denoised
@@ -599,10 +607,17 @@ def denoise_dev_av(
# Euler step # Euler step
if sigma_next > 0: if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype) # Compute Euler step in float32 for precision (matching PyTorch behavior)
sigma_arr = mx.array(sigma, dtype=dtype) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr sigma_f32 = mx.array(sigma, dtype=mx.float32)
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
video_latents_f32 = video_latents.astype(mx.float32)
video_denoised_f32 = video_denoised.astype(mx.float32)
video_latents = (video_denoised_f32 + sigma_next_f32 * (video_latents_f32 - video_denoised_f32) / sigma_f32).astype(dtype)
audio_latents_f32 = audio_latents.astype(mx.float32)
audio_denoised_f32 = audio_denoised.astype(mx.float32)
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
else: else:
video_latents = video_denoised video_latents = video_denoised
audio_latents = audio_denoised audio_latents = audio_denoised

View File

@@ -1,5 +1,5 @@
import math import math
from typing import Optional, Tuple, Union from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -61,6 +61,9 @@ def to_denoised(
Given noisy input x_t and velocity prediction v, compute denoised x_0: Given noisy input x_t and velocity prediction v, compute denoised x_0:
x_0 = x_t - sigma * v x_0 = x_t - sigma * v
Uses float32 for computation precision (matching PyTorch behavior),
then converts back to input dtype.
Args: Args:
noisy: Noisy input tensor x_t noisy: Noisy input tensor x_t
velocity: Velocity prediction v velocity: Velocity prediction v
@@ -69,16 +72,21 @@ def to_denoised(
Returns: Returns:
Denoised tensor x_0 Denoised tensor x_0
""" """
original_dtype = noisy.dtype
# Cast to float32 for precision (PyTorch uses calc_dtype=torch.float32)
noisy_f32 = noisy.astype(mx.float32)
velocity_f32 = velocity.astype(mx.float32)
if isinstance(sigma, (int, float)): if isinstance(sigma, (int, float)):
# Convert to array with matching dtype to avoid float32 promotion sigma_f32 = mx.array(sigma, dtype=mx.float32)
sigma_arr = mx.array(sigma, dtype=velocity.dtype)
return noisy - sigma_arr * velocity
else: else:
# sigma is per-sample - ensure dtype matches sigma_f32 = sigma.astype(mx.float32)
sigma = sigma.astype(velocity.dtype) while sigma_f32.ndim < velocity_f32.ndim:
while sigma.ndim < velocity.ndim: sigma_f32 = mx.expand_dims(sigma_f32, axis=-1)
sigma = mx.expand_dims(sigma, axis=-1)
return noisy - sigma * velocity result = noisy_f32 - sigma_f32 * velocity_f32
return result.astype(original_dtype)
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array: def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array: