diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 2486ef0..a146031 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -249,6 +249,11 @@ def denoise_distilled( if state is not None: latents = state.latent + # Keep latents in float32 throughout to avoid quantization noise accumulation. + latents = latents.astype(mx.float32) + if enable_audio: + audio_latents = audio_latents.astype(mx.float32) + desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]" num_steps = len(sigmas) - 1 @@ -268,7 +273,8 @@ def denoise_distilled( b, c, f, h, w = latents.shape num_tokens = f * h * w - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + # Cast to model dtype for transformer input + latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype) if state is not None: denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) @@ -291,7 +297,7 @@ def denoise_distilled( if enable_audio: ab, ac, at, af = audio_latents.shape audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) - audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype) audio_modality = Modality( latent=audio_flat, @@ -307,34 +313,36 @@ def denoise_distilled( if audio_velocity is not None: mx.eval(audio_velocity) - velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) - denoised = to_denoised(latents, velocity, sigma) + # Compute denoised (x0) using per-token timesteps in float32 + # x0 = latent - timestep * velocity + # For conditioned tokens (timestep=0): x0 = latent + # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity + sigma_f32 = mx.array(sigma, dtype=mx.float32) + latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) + x0_f32 = latents_flat_f32 - timesteps_f32 * velocity.astype(mx.float32) + denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w)) audio_denoised = None if enable_audio and audio_velocity is not None: ab, ac, at, af = audio_latents.shape audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) - audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype(mx.float32) if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) + denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask) mx.eval(denoised) if audio_denoised is not None: mx.eval(audio_denoised) + # Euler step in float32 (latents stay in float32) if sigma_next > 0: - # Compute Euler step in float32 for precision (matching PyTorch behavior) - latents_f32 = latents.astype(mx.float32) - denoised_f32 = denoised.astype(mx.float32) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) - sigma_f32 = mx.array(sigma, dtype=mx.float32) - latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) + latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32 if enable_audio and audio_denoised is not None: - audio_latents_f32 = audio_latents.astype(mx.float32) - audio_denoised_f32 = audio_denoised.astype(mx.float32) - audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype) + audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32 else: latents = denoised if enable_audio and audio_denoised is not None: @@ -346,7 +354,7 @@ def denoise_distilled( progress.advance(task) - return latents, audio_latents if enable_audio else None + return latents.astype(dtype), audio_latents.astype(dtype) if enable_audio else None # ============================================================================= @@ -371,6 +379,11 @@ def denoise_dev( if state is not None: latents = state.latent + # Keep latents in float32 throughout the denoising loop to avoid + # quantization noise accumulation over many steps. + # Model input is cast to model dtype; all denoising math stays in float32. + latents = latents.astype(mx.float32) + sigmas_list = sigmas.tolist() use_cfg = cfg_scale != 1.0 num_steps = len(sigmas_list) - 1 @@ -405,7 +418,8 @@ def denoise_dev( b, c, f, h, w = latents.shape num_tokens = f * h * w - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + # Cast to model dtype for transformer input + latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype) if state is not None: denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) @@ -427,6 +441,14 @@ def denoise_dev( ) velocity_pos, _ = transformer(video=video_modality_pos, audio=None) + # Convert velocity to x0 (denoised) using per-token timesteps + # Matches PyTorch's X0Model: x0 = latent - timestep * velocity + # For conditioned tokens (timestep=0): x0 = latent (correct regardless of velocity) + # For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity + latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) + x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32) + if use_cfg: # Negative conditioning pass video_modality_neg = Modality( @@ -440,31 +462,34 @@ def denoise_dev( ) velocity_neg, _ = transformer(video=video_modality_neg, audio=None) - # Apply CFG - velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg) - else: - velocity_flat = velocity_pos + # Convert negative velocity to x0 using per-token timesteps + x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32) - velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w)) - denoised = to_denoised(latents, velocity, sigma) + # Apply CFG to x0 predictions (matches PyTorch CFGGuider) + # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0 + x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32) + else: + x0_guided_f32 = x0_pos_f32 + + # Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w) + denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w)) + + sigma_f32 = mx.array(sigma, dtype=mx.float32) if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) + denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask) + # Euler step in float32 (latents stay in float32) if sigma_next > 0: - # Compute Euler step in float32 for precision (matching PyTorch behavior) - latents_f32 = latents.astype(mx.float32) - denoised_f32 = denoised.astype(mx.float32) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) - sigma_f32 = mx.array(sigma, dtype=mx.float32) - latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype) + latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32 else: latents = denoised mx.eval(latents) progress.advance(task) - return latents + return latents.astype(dtype) def denoise_dev_av( @@ -1055,9 +1080,8 @@ def generate_video( mx.clear_cache() console.print("[green]✓[/] VAE encoder loaded and image encoded") - # Generate sigma schedule - # PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses - # the default MAX_SHIFT_ANCHOR (4096) for the shift calculation + # Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation) + num_tokens = latent_frames * latent_h * latent_w sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") @@ -1117,7 +1141,7 @@ def generate_video( latents, video_positions, video_embeddings_pos, video_embeddings_neg, transformer, sigmas, cfg_scale=cfg_scale, - cfg_rescale=cfg_rescale, verbose=verbose, state=video_state + verbose=verbose, state=video_state ) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) @@ -1347,7 +1371,6 @@ Examples: parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") - args = parser.parse_args() pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED