diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 7383464..1f0d2e1 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1268,6 +1268,10 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) + # Save stage 1 audio latents — stage 2 only refines video (spatial upsampling). + # Audio is already fully denoised from stage 1; re-noising would destroy the signal. + stage1_audio_latents = audio_latents + state2 = None if is_i2v and stage2_image_latent is not None: state2 = LatentState( @@ -1288,12 +1292,6 @@ def generate_video( ) latents = state2.latent mx.eval(latents) - - if audio and audio_latents is not None: - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(audio_latents) else: noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) @@ -1301,16 +1299,13 @@ def generate_video( latents = noise * noise_scale + latents * one_minus_scale mx.eval(latents) - if audio and audio_latents is not None: - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(audio_latents) - - latents, audio_latents = denoise_distilled( + # Stage 2 refines video only (no audio re-denoising) + latents, _ = denoise_distilled( latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2, - audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, ) + # Restore audio latents from stage 1 + audio_latents = stage1_audio_latents elif pipeline == PipelineType.DEV: # ====================================================================== @@ -1531,6 +1526,10 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) + # Save stage 1 audio latents — stage 2 only refines video (spatial upsampling). + # Audio is already fully denoised from stage 1; re-noising would destroy the signal. + stage1_audio_latents = audio_latents + state2 = None if is_i2v and stage2_image_latent is not None: state2 = LatentState( @@ -1551,12 +1550,6 @@ def generate_video( ) latents = state2.latent mx.eval(latents) - - if audio and audio_latents is not None: - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(audio_latents) else: noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) @@ -1564,18 +1557,13 @@ def generate_video( latents = noise * noise_scale + latents * one_minus_scale mx.eval(latents) - if audio and audio_latents is not None: - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(audio_latents) - - # Stage 2 uses distilled denoising (no CFG) - latents, audio_latents = denoise_distilled( + # Stage 2 refines video only (no audio re-denoising) + latents, _ = denoise_distilled( latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2, - audio_latents=audio_latents, audio_positions=audio_positions, - audio_embeddings=audio_embeddings_pos if audio else None, ) + # Restore audio latents from stage 1 + audio_latents = stage1_audio_latents del transformer mx.clear_cache()