Refactor audio handling in generate_video function to preserve stage 1 audio latents during stage 2 processing. Remove redundant audio re-denoising steps, ensuring audio integrity while refining video output. Update comments for clarity on audio processing logic.
This commit is contained in:
@@ -1268,6 +1268,10 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
|
||||
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
|
||||
stage1_audio_latents = audio_latents
|
||||
|
||||
state2 = None
|
||||
if is_i2v and stage2_image_latent is not None:
|
||||
state2 = LatentState(
|
||||
@@ -1288,12 +1292,6 @@ def generate_video(
|
||||
)
|
||||
latents = state2.latent
|
||||
mx.eval(latents)
|
||||
|
||||
if audio and audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||
mx.eval(audio_latents)
|
||||
else:
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
@@ -1301,16 +1299,13 @@ def generate_video(
|
||||
latents = noise * noise_scale + latents * one_minus_scale
|
||||
mx.eval(latents)
|
||||
|
||||
if audio and audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||
mx.eval(audio_latents)
|
||||
|
||||
latents, audio_latents = denoise_distilled(
|
||||
# Stage 2 refines video only (no audio re-denoising)
|
||||
latents, _ = denoise_distilled(
|
||||
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
|
||||
)
|
||||
# Restore audio latents from stage 1
|
||||
audio_latents = stage1_audio_latents
|
||||
|
||||
elif pipeline == PipelineType.DEV:
|
||||
# ======================================================================
|
||||
@@ -1531,6 +1526,10 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
|
||||
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
|
||||
stage1_audio_latents = audio_latents
|
||||
|
||||
state2 = None
|
||||
if is_i2v and stage2_image_latent is not None:
|
||||
state2 = LatentState(
|
||||
@@ -1551,12 +1550,6 @@ def generate_video(
|
||||
)
|
||||
latents = state2.latent
|
||||
mx.eval(latents)
|
||||
|
||||
if audio and audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||
mx.eval(audio_latents)
|
||||
else:
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
@@ -1564,18 +1557,13 @@ def generate_video(
|
||||
latents = noise * noise_scale + latents * one_minus_scale
|
||||
mx.eval(latents)
|
||||
|
||||
if audio and audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Stage 2 uses distilled denoising (no CFG)
|
||||
latents, audio_latents = denoise_distilled(
|
||||
# Stage 2 refines video only (no audio re-denoising)
|
||||
latents, _ = denoise_distilled(
|
||||
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
||||
audio_embeddings=audio_embeddings_pos if audio else None,
|
||||
)
|
||||
# Restore audio latents from stage 1
|
||||
audio_latents = stage1_audio_latents
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
Reference in New Issue
Block a user