Refactor video and audio latent generation in generate_video and generate_video_with_audio
- Removed direct initialization of latents with random noise, replacing it with a conditional approach based on I2V (Image-to-Video) conditioning. - Introduced a structured flow for applying noise during the latent state creation, enhancing the conditioning process for both video and audio. - Updated the noise application logic to ensure proper handling of conditioned and unconditioned frames in both stages of video generation. - Improved code clarity and maintainability by consolidating latent shape definitions and restructuring noise application logic.
This commit is contained in:
@@ -498,9 +498,6 @@ def generate_video_with_audio(
|
||||
# Initialize latents
|
||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||
mx.random.seed(seed)
|
||||
video_latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
# Create position grids
|
||||
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
@@ -509,10 +506,12 @@ def generate_video_with_audio(
|
||||
|
||||
# Apply I2V conditioning for stage 1 if provided
|
||||
video_state1 = None
|
||||
video_latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
||||
if is_i2v and stage1_image_latent is not None:
|
||||
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
|
||||
video_state1 = LatentState(
|
||||
latent=video_latents,
|
||||
clean_latent=mx.zeros_like(video_latents),
|
||||
latent=mx.zeros(video_latent_shape),
|
||||
clean_latent=mx.zeros(video_latent_shape),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
@@ -521,8 +520,26 @@ def generate_video_with_audio(
|
||||
strength=image_strength,
|
||||
)
|
||||
video_state1 = apply_conditioning(video_state1, [conditioning])
|
||||
|
||||
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||
noise = mx.random.normal(video_latent_shape)
|
||||
noise_scale = STAGE_1_SIGMAS[0] # 1.0
|
||||
scaled_mask = video_state1.denoise_mask * noise_scale
|
||||
video_state1 = LatentState(
|
||||
latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask),
|
||||
clean_latent=video_state1.clean_latent,
|
||||
denoise_mask=video_state1.denoise_mask,
|
||||
)
|
||||
video_latents = video_state1.latent
|
||||
mx.eval(video_latents)
|
||||
else:
|
||||
# T2V: just use random noise
|
||||
video_latents = mx.random.normal(video_latent_shape)
|
||||
mx.eval(video_latents)
|
||||
|
||||
# Audio always uses pure noise (no I2V for audio)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Stage 1 denoising
|
||||
video_latents, audio_latents = denoise_av(
|
||||
@@ -554,19 +571,12 @@ def generate_video_with_audio(
|
||||
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(video_positions)
|
||||
|
||||
# Add noise for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
video_noise = mx.random.normal(video_latents.shape)
|
||||
audio_noise = mx.random.normal(audio_latents.shape)
|
||||
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 2 if provided
|
||||
video_state2 = None
|
||||
if is_i2v and stage2_image_latent is not None:
|
||||
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
|
||||
video_state2 = LatentState(
|
||||
latent=video_latents,
|
||||
latent=video_latents, # Start with upscaled latent
|
||||
clean_latent=mx.zeros_like(video_latents),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
)
|
||||
@@ -576,9 +586,32 @@ def generate_video_with_audio(
|
||||
strength=image_strength,
|
||||
)
|
||||
video_state2 = apply_conditioning(video_state2, [conditioning])
|
||||
|
||||
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
||||
video_noise = mx.random.normal(video_latents.shape)
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
scaled_mask = video_state2.denoise_mask * noise_scale
|
||||
video_state2 = LatentState(
|
||||
latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask),
|
||||
clean_latent=video_state2.clean_latent,
|
||||
denoise_mask=video_state2.denoise_mask,
|
||||
)
|
||||
video_latents = video_state2.latent
|
||||
mx.eval(video_latents)
|
||||
|
||||
# Audio still gets noise (no I2V for audio)
|
||||
audio_noise = mx.random.normal(audio_latents.shape)
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||
mx.eval(audio_latents)
|
||||
else:
|
||||
# T2V: add noise to all frames for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
video_noise = mx.random.normal(video_latents.shape)
|
||||
audio_noise = mx.random.normal(audio_latents.shape)
|
||||
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
video_latents, audio_latents = denoise_av(
|
||||
video_latents, audio_latents,
|
||||
video_positions, audio_positions,
|
||||
|
||||
Reference in New Issue
Block a user