From 87962c7f831b71ebbe0154cccdec7522b97116e3 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 24 Jan 2026 15:40:42 +0100 Subject: [PATCH] Enhance precision in denoising functions by ensuring all latents and calculations are consistently handled in float32. Update model input casting and return types to maintain dtype integrity across audio and video processing. Add precision parameter to video generation for improved memory management. --- mlx_video/generate.py | 91 +++++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 2486ef0..a146031 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -249,6 +249,11 @@ def denoise_distilled( if state is not None: latents = state.latent + # Keep latents in float32 throughout to avoid quantization noise accumulation. + latents = latents.astype(mx.float32) + if enable_audio: + audio_latents = audio_latents.astype(mx.float32) + desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]" num_steps = len(sigmas) - 1 @@ -268,7 +273,8 @@ def denoise_distilled( b, c, f, h, w = latents.shape num_tokens = f * h * w - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + # Cast to model dtype for transformer input + latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype) if state is not None: denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) @@ -291,7 +297,7 @@ def denoise_distilled( if enable_audio: ab, ac, at, af = audio_latents.shape audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) - audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype) audio_modality = Modality( latent=audio_flat, @@ -307,34 +313,36 @@ def denoise_distilled( if audio_velocity is not None: mx.eval(audio_velocity) - velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) - denoised = to_denoised(latents, velocity, sigma) + # Compute denoised (x0) using per-token timesteps in float32 + # x0 = latent - timestep * velocity + # For conditioned tokens (timestep=0): x0 = latent + # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity + sigma_f32 = mx.array(sigma, dtype=mx.float32) + latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) + x0_f32 = latents_flat_f32 - timesteps_f32 * velocity.astype(mx.float32) + denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w)) audio_denoised = None if enable_audio and audio_velocity is not None: ab, ac, at, af = audio_latents.shape audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) - audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype(mx.float32) if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) + denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask) mx.eval(denoised) if audio_denoised is not None: mx.eval(audio_denoised) + # Euler step in float32 (latents stay in float32) if sigma_next > 0: - # Compute Euler step in float32 for precision (matching PyTorch behavior) - latents_f32 = latents.astype(mx.float32) - denoised_f32 = denoised.astype(mx.float32) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) - sigma_f32 = mx.array(sigma, dtype=mx.float32) - latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) + latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32 if enable_audio and audio_denoised is not None: - audio_latents_f32 = audio_latents.astype(mx.float32) - audio_denoised_f32 = audio_denoised.astype(mx.float32) - audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype) + audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32 else: latents = denoised if enable_audio and audio_denoised is not None: @@ -346,7 +354,7 @@ def denoise_distilled( progress.advance(task) - return latents, audio_latents if enable_audio else None + return latents.astype(dtype), audio_latents.astype(dtype) if enable_audio else None # ============================================================================= @@ -371,6 +379,11 @@ def denoise_dev( if state is not None: latents = state.latent + # Keep latents in float32 throughout the denoising loop to avoid + # quantization noise accumulation over many steps. + # Model input is cast to model dtype; all denoising math stays in float32. + latents = latents.astype(mx.float32) + sigmas_list = sigmas.tolist() use_cfg = cfg_scale != 1.0 num_steps = len(sigmas_list) - 1 @@ -405,7 +418,8 @@ def denoise_dev( b, c, f, h, w = latents.shape num_tokens = f * h * w - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + # Cast to model dtype for transformer input + latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype) if state is not None: denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) @@ -427,6 +441,14 @@ def denoise_dev( ) velocity_pos, _ = transformer(video=video_modality_pos, audio=None) + # Convert velocity to x0 (denoised) using per-token timesteps + # Matches PyTorch's X0Model: x0 = latent - timestep * velocity + # For conditioned tokens (timestep=0): x0 = latent (correct regardless of velocity) + # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity + latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) + x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32) + if use_cfg: # Negative conditioning pass video_modality_neg = Modality( @@ -440,31 +462,34 @@ def denoise_dev( ) velocity_neg, _ = transformer(video=video_modality_neg, audio=None) - # Apply CFG - velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg) - else: - velocity_flat = velocity_pos + # Convert negative velocity to x0 using per-token timesteps + x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32) - velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w)) - denoised = to_denoised(latents, velocity, sigma) + # Apply CFG to x0 predictions (matches PyTorch CFGGuider) + # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 + x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) + else: + x0_guided_f32 = x0_pos_f32 + + # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) + denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) + + sigma_f32 = mx.array(sigma, dtype=mx.float32) if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) + denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask) + # Euler step in float32 (latents stay in float32) if sigma_next > 0: - # Compute Euler step in float32 for precision (matching PyTorch behavior) - latents_f32 = latents.astype(mx.float32) - denoised_f32 = denoised.astype(mx.float32) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) - sigma_f32 = mx.array(sigma, dtype=mx.float32) - latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) + latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32 else: latents = denoised mx.eval(latents) progress.advance(task) - return latents + return latents.astype(dtype) def denoise_dev_av( @@ -1055,9 +1080,8 @@ def generate_video( mx.clear_cache() console.print("[green]✓[/] VAE encoder loaded and image encoded") - # Generate sigma schedule - # PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses - # the default MAX_SHIFT_ANCHOR (4096) for the shift calculation + # Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation) + num_tokens = latent_frames * latent_h * latent_w sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") @@ -1117,7 +1141,7 @@ def generate_video( latents, video_positions, video_embeddings_pos, video_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, - cfg_rescale=cfg_rescale, verbose=verbose, state=video_state + verbose=verbose, state=video_state ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) @@ -1347,7 +1371,6 @@ Examples: parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") - args = parser.parse_args() pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED