Refactor denoising functions in generate.py and utils.py to use float32 for improved precision, aligning with PyTorch behavior. Update calculations for latents and denoised outputs to ensure consistent dtype handling across audio and video processing.
This commit is contained in:
@@ -330,11 +330,16 @@ def denoise_distilled(
|
||||
mx.eval(audio_denoised)
|
||||
|
||||
if sigma_next > 0:
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
|
||||
# Compute Euler step in float32 for precision (matching PyTorch behavior)
|
||||
latents_f32 = latents.astype(mx.float32)
|
||||
denoised_f32 = denoised.astype(mx.float32)
|
||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
|
||||
if enable_audio and audio_denoised is not None:
|
||||
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
|
||||
audio_latents_f32 = audio_latents.astype(mx.float32)
|
||||
audio_denoised_f32 = audio_denoised.astype(mx.float32)
|
||||
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
|
||||
else:
|
||||
latents = denoised
|
||||
if enable_audio and audio_denoised is not None:
|
||||
@@ -452,9 +457,12 @@ def denoise_dev(
|
||||
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
|
||||
|
||||
if sigma_next > 0:
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
|
||||
# Compute Euler step in float32 for precision (matching PyTorch behavior)
|
||||
latents_f32 = latents.astype(mx.float32)
|
||||
denoised_f32 = denoised.astype(mx.float32)
|
||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
|
||||
else:
|
||||
latents = denoised
|
||||
|
||||
@@ -599,10 +607,17 @@ def denoise_dev_av(
|
||||
|
||||
# Euler step
|
||||
if sigma_next > 0:
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
|
||||
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
|
||||
# Compute Euler step in float32 for precision (matching PyTorch behavior)
|
||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||
|
||||
video_latents_f32 = video_latents.astype(mx.float32)
|
||||
video_denoised_f32 = video_denoised.astype(mx.float32)
|
||||
video_latents = (video_denoised_f32 + sigma_next_f32 * (video_latents_f32 - video_denoised_f32) / sigma_f32).astype(dtype)
|
||||
|
||||
audio_latents_f32 = audio_latents.astype(mx.float32)
|
||||
audio_denoised_f32 = audio_denoised.astype(mx.float32)
|
||||
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
|
||||
else:
|
||||
video_latents = video_denoised
|
||||
audio_latents = audio_denoised
|
||||
|
||||
Reference in New Issue
Block a user