diff --git a/mlx_video/generate.py b/mlx_video/generate.py index f43d5e6..1d716ee 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -330,11 +330,16 @@ def denoise_distilled( mx.eval(audio_denoised) if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr + # Compute Euler step in float32 for precision (matching PyTorch behavior) + latents_f32 = latents.astype(mx.float32) + denoised_f32 = denoised.astype(mx.float32) + sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) + sigma_f32 = mx.array(sigma, dtype=mx.float32) + latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) if enable_audio and audio_denoised is not None: - audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr + audio_latents_f32 = audio_latents.astype(mx.float32) + audio_denoised_f32 = audio_denoised.astype(mx.float32) + audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype) else: latents = denoised if enable_audio and audio_denoised is not None: @@ -452,9 +457,12 @@ def denoise_dev( denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr + # Compute Euler step in float32 for precision (matching PyTorch behavior) + latents_f32 = latents.astype(mx.float32) + denoised_f32 = denoised.astype(mx.float32) + sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) + sigma_f32 = mx.array(sigma, dtype=mx.float32) + latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) else: latents = denoised @@ -599,10 +607,17 @@ def denoise_dev_av( # Euler step if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr - audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr + # Compute Euler step in float32 for precision (matching PyTorch behavior) + sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) + sigma_f32 = mx.array(sigma, dtype=mx.float32) + + video_latents_f32 = video_latents.astype(mx.float32) + video_denoised_f32 = video_denoised.astype(mx.float32) + video_latents = (video_denoised_f32 + sigma_next_f32 * (video_latents_f32 - video_denoised_f32) / sigma_f32).astype(dtype) + + audio_latents_f32 = audio_latents.astype(mx.float32) + audio_denoised_f32 = audio_denoised.astype(mx.float32) + audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype) else: video_latents = video_denoised audio_latents = audio_denoised diff --git a/mlx_video/utils.py b/mlx_video/utils.py index cebbed7..2a6eefe 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple, Union +from typing import Optional, Union import mlx.core as mx import mlx.nn as nn @@ -61,6 +61,9 @@ def to_denoised( Given noisy input x_t and velocity prediction v, compute denoised x_0: x_0 = x_t - sigma * v + Uses float32 for computation precision (matching PyTorch behavior), + then converts back to input dtype. + Args: noisy: Noisy input tensor x_t velocity: Velocity prediction v @@ -69,16 +72,21 @@ def to_denoised( Returns: Denoised tensor x_0 """ + original_dtype = noisy.dtype + + # Cast to float32 for precision (PyTorch uses calc_dtype=torch.float32) + noisy_f32 = noisy.astype(mx.float32) + velocity_f32 = velocity.astype(mx.float32) + if isinstance(sigma, (int, float)): - # Convert to array with matching dtype to avoid float32 promotion - sigma_arr = mx.array(sigma, dtype=velocity.dtype) - return noisy - sigma_arr * velocity + sigma_f32 = mx.array(sigma, dtype=mx.float32) else: - # sigma is per-sample - ensure dtype matches - sigma = sigma.astype(velocity.dtype) - while sigma.ndim < velocity.ndim: - sigma = mx.expand_dims(sigma, axis=-1) - return noisy - sigma * velocity + sigma_f32 = sigma.astype(mx.float32) + while sigma_f32.ndim < velocity_f32.ndim: + sigma_f32 = mx.expand_dims(sigma_f32, axis=-1) + + result = noisy_f32 - sigma_f32 * velocity_f32 + return result.astype(original_dtype) def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array: