Refactor video and audio latent generation in generate_video and generate_video_with_audio

- Removed direct initialization of latents with random noise, replacing it with a conditional approach based on I2V (Image-to-Video) conditioning.
- Introduced a structured flow for applying noise during the latent state creation, enhancing the conditioning process for both video and audio.
- Updated the noise application logic to ensure proper handling of conditioned and unconditioned frames in both stages of video generation.
- Improved code clarity and maintainability by consolidating latent shape definitions and restructuring noise application logic.
This commit is contained in:
Prince Canuma
2026-01-17 01:38:53 +01:00
parent d52e567c56
commit f607112407
2 changed files with 87 additions and 26 deletions

View File

@@ -498,9 +498,6 @@ def generate_video_with_audio(
# Initialize latents
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
video_latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
mx.eval(video_latents, audio_latents)
# Create position grids
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w)
@@ -509,10 +506,12 @@ def generate_video_with_audio(
# Apply I2V conditioning for stage 1 if provided
video_state1 = None
video_latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
if is_i2v and stage1_image_latent is not None:
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
video_state1 = LatentState(
latent=video_latents,
clean_latent=mx.zeros_like(video_latents),
latent=mx.zeros(video_latent_shape),
clean_latent=mx.zeros(video_latent_shape),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
@@ -521,8 +520,26 @@ def generate_video_with_audio(
strength=image_strength,
)
video_state1 = apply_conditioning(video_state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
noise = mx.random.normal(video_latent_shape)
noise_scale = STAGE_1_SIGMAS[0] # 1.0
scaled_mask = video_state1.denoise_mask * noise_scale
video_state1 = LatentState(
latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask),
clean_latent=video_state1.clean_latent,
denoise_mask=video_state1.denoise_mask,
)
video_latents = video_state1.latent
mx.eval(video_latents)
else:
# T2V: just use random noise
video_latents = mx.random.normal(video_latent_shape)
mx.eval(video_latents)
# Audio always uses pure noise (no I2V for audio)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
mx.eval(audio_latents)
# Stage 1 denoising
video_latents, audio_latents = denoise_av(
@@ -554,19 +571,12 @@ def generate_video_with_audio(
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(video_positions)
# Add noise for refinement
noise_scale = STAGE_2_SIGMAS[0]
video_noise = mx.random.normal(video_latents.shape)
audio_noise = mx.random.normal(audio_latents.shape)
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
mx.eval(video_latents, audio_latents)
# Apply I2V conditioning for stage 2 if provided
video_state2 = None
if is_i2v and stage2_image_latent is not None:
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
video_state2 = LatentState(
latent=video_latents,
latent=video_latents, # Start with upscaled latent
clean_latent=mx.zeros_like(video_latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
@@ -576,9 +586,32 @@ def generate_video_with_audio(
strength=image_strength,
)
video_state2 = apply_conditioning(video_state2, [conditioning])
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
video_noise = mx.random.normal(video_latents.shape)
noise_scale = STAGE_2_SIGMAS[0]
scaled_mask = video_state2.denoise_mask * noise_scale
video_state2 = LatentState(
latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask),
clean_latent=video_state2.clean_latent,
denoise_mask=video_state2.denoise_mask,
)
video_latents = video_state2.latent
mx.eval(video_latents)
# Audio still gets noise (no I2V for audio)
audio_noise = mx.random.normal(audio_latents.shape)
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
mx.eval(audio_latents)
else:
# T2V: add noise to all frames for refinement
noise_scale = STAGE_2_SIGMAS[0]
video_noise = mx.random.normal(video_latents.shape)
audio_noise = mx.random.normal(audio_latents.shape)
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
mx.eval(video_latents, audio_latents)
video_latents, audio_latents = denoise_av(
video_latents, audio_latents,
video_positions, audio_positions,