Refactor audio handling in generate_video function to preserve stage 1 audio latents during stage 2 processing. Remove redundant audio re-denoising steps, ensuring audio integrity while refining video output. Update comments for clarity on audio processing logic.
This commit is contained in:
@@ -1268,6 +1268,10 @@ def generate_video(
|
|||||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
|
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
|
||||||
|
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
|
||||||
|
stage1_audio_latents = audio_latents
|
||||||
|
|
||||||
state2 = None
|
state2 = None
|
||||||
if is_i2v and stage2_image_latent is not None:
|
if is_i2v and stage2_image_latent is not None:
|
||||||
state2 = LatentState(
|
state2 = LatentState(
|
||||||
@@ -1288,12 +1292,6 @@ def generate_video(
|
|||||||
)
|
)
|
||||||
latents = state2.latent
|
latents = state2.latent
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
if audio and audio_latents is not None:
|
|
||||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
|
||||||
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
|
||||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
|
||||||
mx.eval(audio_latents)
|
|
||||||
else:
|
else:
|
||||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
@@ -1301,16 +1299,13 @@ def generate_video(
|
|||||||
latents = noise * noise_scale + latents * one_minus_scale
|
latents = noise * noise_scale + latents * one_minus_scale
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
if audio and audio_latents is not None:
|
# Stage 2 refines video only (no audio re-denoising)
|
||||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
latents, _ = denoise_distilled(
|
||||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
|
||||||
mx.eval(audio_latents)
|
|
||||||
|
|
||||||
latents, audio_latents = denoise_distilled(
|
|
||||||
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
||||||
verbose=verbose, state=state2,
|
verbose=verbose, state=state2,
|
||||||
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
|
|
||||||
)
|
)
|
||||||
|
# Restore audio latents from stage 1
|
||||||
|
audio_latents = stage1_audio_latents
|
||||||
|
|
||||||
elif pipeline == PipelineType.DEV:
|
elif pipeline == PipelineType.DEV:
|
||||||
# ======================================================================
|
# ======================================================================
|
||||||
@@ -1531,6 +1526,10 @@ def generate_video(
|
|||||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
|
# Save stage 1 audio latents — stage 2 only refines video (spatial upsampling).
|
||||||
|
# Audio is already fully denoised from stage 1; re-noising would destroy the signal.
|
||||||
|
stage1_audio_latents = audio_latents
|
||||||
|
|
||||||
state2 = None
|
state2 = None
|
||||||
if is_i2v and stage2_image_latent is not None:
|
if is_i2v and stage2_image_latent is not None:
|
||||||
state2 = LatentState(
|
state2 = LatentState(
|
||||||
@@ -1551,12 +1550,6 @@ def generate_video(
|
|||||||
)
|
)
|
||||||
latents = state2.latent
|
latents = state2.latent
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
if audio and audio_latents is not None:
|
|
||||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
|
||||||
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
|
||||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
|
||||||
mx.eval(audio_latents)
|
|
||||||
else:
|
else:
|
||||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
@@ -1564,18 +1557,13 @@ def generate_video(
|
|||||||
latents = noise * noise_scale + latents * one_minus_scale
|
latents = noise * noise_scale + latents * one_minus_scale
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
if audio and audio_latents is not None:
|
# Stage 2 refines video only (no audio re-denoising)
|
||||||
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
latents, _ = denoise_distilled(
|
||||||
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
|
||||||
mx.eval(audio_latents)
|
|
||||||
|
|
||||||
# Stage 2 uses distilled denoising (no CFG)
|
|
||||||
latents, audio_latents = denoise_distilled(
|
|
||||||
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
||||||
verbose=verbose, state=state2,
|
verbose=verbose, state=state2,
|
||||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
|
||||||
audio_embeddings=audio_embeddings_pos if audio else None,
|
|
||||||
)
|
)
|
||||||
|
# Restore audio latents from stage 1
|
||||||
|
audio_latents = stage1_audio_latents
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user