From f607112407a9a04a0b14fe8e558320f9149b4914 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 01:38:53 +0100 Subject: [PATCH] Refactor video and audio latent generation in `generate_video` and `generate_video_with_audio` - Removed direct initialization of latents with random noise, replacing it with a conditional approach based on I2V (Image-to-Video) conditioning. - Introduced a structured flow for applying noise during the latent state creation, enhancing the conditioning process for both video and audio. - Updated the noise application logic to ensure proper handling of conditioned and unconditioned frames in both stages of video generation. - Improved code clarity and maintainability by consolidating latent shape definitions and restructuring noise application logic. --- mlx_video/generate.py | 52 ++++++++++++++++++++++++++-------- mlx_video/generate_av.py | 61 +++++++++++++++++++++++++++++++--------- 2 files changed, 87 insertions(+), 26 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 973fa60..8a6c5d5 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -332,8 +332,6 @@ def generate_video( # Stage 1: Generate at half resolution print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) - latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) - mx.eval(latents) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) @@ -341,10 +339,12 @@ def generate_video( # Apply I2V conditioning if provided state1 = None if is_i2v and stage1_image_latent is not None: - # Create state with conditioning + # PyTorch flow: create zeros -> apply conditioning -> apply noiser + # Create initial state with zeros + latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) state1 = LatentState( - latent=latents, - clean_latent=mx.zeros_like(latents), + latent=mx.zeros(latent_shape), + clean_latent=mx.zeros(latent_shape), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), ) conditioning = VideoConditionByLatentIndex( @@ -353,8 +353,23 @@ def generate_video( strength=image_strength, ) state1 = apply_conditioning(state1, [conditioning]) + + # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) + # For Stage 1, noise_scale = 1.0 (first sigma) + noise = mx.random.normal(latent_shape) + noise_scale = STAGE_1_SIGMAS[0] # 1.0 + scaled_mask = state1.denoise_mask * noise_scale + state1 = LatentState( + latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask), + clean_latent=state1.clean_latent, + denoise_mask=state1.denoise_mask, + ) latents = state1.latent mx.eval(latents) + else: + # T2V: just use random noise + latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) + mx.eval(latents) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) @@ -379,17 +394,12 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) - # Add noise for refinement - noise_scale = STAGE_2_SIGMAS[0] - noise = mx.random.normal(latents.shape) - latents = noise * noise_scale + latents * (1 - noise_scale) - mx.eval(latents) - # Apply I2V conditioning for stage 2 if provided state2 = None if is_i2v and stage2_image_latent is not None: + # PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser state2 = LatentState( - latent=latents, + latent=latents, # Start with upscaled latent clean_latent=mx.zeros_like(latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), ) @@ -399,8 +409,26 @@ def generate_video( strength=image_strength, ) state2 = apply_conditioning(state2, [conditioning]) + + # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) + # For Stage 2, noise_scale = stage_2_sigmas[0] + # Conditioned frames (mask=0) keep image latent, unconditioned get partial noise + noise = mx.random.normal(latents.shape) + noise_scale = STAGE_2_SIGMAS[0] + scaled_mask = state2.denoise_mask * noise_scale + state2 = LatentState( + latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask), + clean_latent=state2.clean_latent, + denoise_mask=state2.denoise_mask, + ) latents = state2.latent mx.eval(latents) + else: + # T2V: add noise to all frames for refinement + noise_scale = STAGE_2_SIGMAS[0] + noise = mx.random.normal(latents.shape) + latents = noise * noise_scale + latents * (1 - noise_scale) + mx.eval(latents) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py index f73998d..a481c23 100644 --- a/mlx_video/generate_av.py +++ b/mlx_video/generate_av.py @@ -498,9 +498,6 @@ def generate_video_with_audio( # Initialize latents print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) - video_latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)) - mx.eval(video_latents, audio_latents) # Create position grids video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) @@ -509,10 +506,12 @@ def generate_video_with_audio( # Apply I2V conditioning for stage 1 if provided video_state1 = None + video_latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) if is_i2v and stage1_image_latent is not None: + # PyTorch flow: create zeros -> apply conditioning -> apply noiser video_state1 = LatentState( - latent=video_latents, - clean_latent=mx.zeros_like(video_latents), + latent=mx.zeros(video_latent_shape), + clean_latent=mx.zeros(video_latent_shape), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), ) conditioning = VideoConditionByLatentIndex( @@ -521,8 +520,26 @@ def generate_video_with_audio( strength=image_strength, ) video_state1 = apply_conditioning(video_state1, [conditioning]) + + # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) + noise = mx.random.normal(video_latent_shape) + noise_scale = STAGE_1_SIGMAS[0] # 1.0 + scaled_mask = video_state1.denoise_mask * noise_scale + video_state1 = LatentState( + latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask), + clean_latent=video_state1.clean_latent, + denoise_mask=video_state1.denoise_mask, + ) video_latents = video_state1.latent mx.eval(video_latents) + else: + # T2V: just use random noise + video_latents = mx.random.normal(video_latent_shape) + mx.eval(video_latents) + + # Audio always uses pure noise (no I2V for audio) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)) + mx.eval(audio_latents) # Stage 1 denoising video_latents, audio_latents = denoise_av( @@ -554,19 +571,12 @@ def generate_video_with_audio( video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(video_positions) - # Add noise for refinement - noise_scale = STAGE_2_SIGMAS[0] - video_noise = mx.random.normal(video_latents.shape) - audio_noise = mx.random.normal(audio_latents.shape) - video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale) - audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) - mx.eval(video_latents, audio_latents) - # Apply I2V conditioning for stage 2 if provided video_state2 = None if is_i2v and stage2_image_latent is not None: + # PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser video_state2 = LatentState( - latent=video_latents, + latent=video_latents, # Start with upscaled latent clean_latent=mx.zeros_like(video_latents), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), ) @@ -576,9 +586,32 @@ def generate_video_with_audio( strength=image_strength, ) video_state2 = apply_conditioning(video_state2, [conditioning]) + + # Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise + video_noise = mx.random.normal(video_latents.shape) + noise_scale = STAGE_2_SIGMAS[0] + scaled_mask = video_state2.denoise_mask * noise_scale + video_state2 = LatentState( + latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask), + clean_latent=video_state2.clean_latent, + denoise_mask=video_state2.denoise_mask, + ) video_latents = video_state2.latent mx.eval(video_latents) + # Audio still gets noise (no I2V for audio) + audio_noise = mx.random.normal(audio_latents.shape) + audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) + mx.eval(audio_latents) + else: + # T2V: add noise to all frames for refinement + noise_scale = STAGE_2_SIGMAS[0] + video_noise = mx.random.normal(video_latents.shape) + audio_noise = mx.random.normal(audio_latents.shape) + video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale) + audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) + mx.eval(video_latents, audio_latents) + video_latents, audio_latents = denoise_av( video_latents, audio_latents, video_positions, audio_positions,