Refactor audio handling in generate_video function to preserve stage 1 audio latents during stage 2 processing. Remove redundant audio re-denoising steps, ensuring audio integrity while refining video output. Update comments for clarity on audio processing logic.

This commit is contained in:
Prince Canuma
2026-03-13 16:09:07 +01:00
parent 387d4fc301
commit f346e09de4

View File

@@ -1268,6 +1268,10 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
stage1_audio_latents = audio_latents
state2 = None state2 = None
if is_i2v and stage2_image_latent is not None: if is_i2v and stage2_image_latent is not None:
state2 = LatentState( state2 = LatentState(
@@ -1288,12 +1292,6 @@ def generate_video(
) )
latents = state2.latent latents = state2.latent
mx.eval(latents) mx.eval(latents)
if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
else: else:
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
@@ -1301,16 +1299,13 @@ def generate_video(
latents = noise * noise_scale + latents * one_minus_scale latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents) mx.eval(latents)
if audio and audio_latents is not None: # Stage 2 refines video only (no audio re-denoising)
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) latents, _ = denoise_distilled(
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
latents, audio_latents = denoise_distilled(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2, verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
) )
# Restore audio latents from stage 1
audio_latents = stage1_audio_latents
elif pipeline == PipelineType.DEV: elif pipeline == PipelineType.DEV:
# ====================================================================== # ======================================================================
@@ -1531,6 +1526,10 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
stage1_audio_latents = audio_latents
state2 = None state2 = None
if is_i2v and stage2_image_latent is not None: if is_i2v and stage2_image_latent is not None:
state2 = LatentState( state2 = LatentState(
@@ -1551,12 +1550,6 @@ def generate_video(
) )
latents = state2.latent latents = state2.latent
mx.eval(latents) mx.eval(latents)
if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
else: else:
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
@@ -1564,18 +1557,13 @@ def generate_video(
latents = noise * noise_scale + latents * one_minus_scale latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents) mx.eval(latents)
if audio and audio_latents is not None: # Stage 2 refines video only (no audio re-denoising)
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) latents, _ = denoise_distilled(
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
# Stage 2 uses distilled denoising (no CFG)
latents, audio_latents = denoise_distilled(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2, verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos if audio else None,
) )
# Restore audio latents from stage 1
audio_latents = stage1_audio_latents
del transformer del transformer
mx.clear_cache() mx.clear_cache()