Refactor video and audio latent generation in generate_video and generate_video_with_audio

- Removed direct initialization of latents with random noise, replacing it with a conditional approach based on I2V (Image-to-Video) conditioning.
- Introduced a structured flow for applying noise during the latent state creation, enhancing the conditioning process for both video and audio.
- Updated the noise application logic to ensure proper handling of conditioned and unconditioned frames in both stages of video generation.
- Improved code clarity and maintainability by consolidating latent shape definitions and restructuring noise application logic.
This commit is contained in:
Prince Canuma
2026-01-17 01:38:53 +01:00
parent d52e567c56
commit f607112407
2 changed files with 87 additions and 26 deletions

View File

@@ -332,8 +332,6 @@ def generate_video(
# Stage 1: Generate at half resolution
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
@@ -341,10 +339,12 @@ def generate_video(
# Apply I2V conditioning if provided
state1 = None
if is_i2v and stage1_image_latent is not None:
# Create state with conditioning
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
# Create initial state with zeros
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
state1 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
latent=mx.zeros(latent_shape),
clean_latent=mx.zeros(latent_shape),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
@@ -353,8 +353,23 @@ def generate_video(
strength=image_strength,
)
state1 = apply_conditioning(state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 1, noise_scale = 1.0 (first sigma)
noise = mx.random.normal(latent_shape)
noise_scale = STAGE_1_SIGMAS[0] # 1.0
scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask),
clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask,
)
latents = state1.latent
mx.eval(latents)
else:
# T2V: just use random noise
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
@@ -379,17 +394,12 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
# Add noise for refinement
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
# Apply I2V conditioning for stage 2 if provided
state2 = None
if is_i2v and stage2_image_latent is not None:
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
state2 = LatentState(
latent=latents,
latent=latents, # Start with upscaled latent
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
@@ -399,8 +409,26 @@ def generate_video(
strength=image_strength,
)
state2 = apply_conditioning(state2, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 2, noise_scale = stage_2_sigmas[0]
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
noise = mx.random.normal(latents.shape)
noise_scale = STAGE_2_SIGMAS[0]
scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask),
clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask,
)
latents = state2.latent
mx.eval(latents)
else:
# T2V: add noise to all frames for refinement
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)