Refactor video and audio latent generation in generate_video and generate_video_with_audio
- Removed the unconditional initialization of latents with random noise; latents are now initialized conditionally, depending on whether I2V (Image-to-Video) conditioning is provided.
- Introduced a structured flow for building the latent state (start from zeros in stage 1 or from the upscaled latent in stage 2, apply conditioning, then apply noise), improving the conditioning process for video; audio latents always start from pure noise.
- Updated the noise application logic so that conditioned and unconditioned frames are handled correctly in both stages of video generation: the noise is scaled by the denoise mask, so conditioned frames (mask = 0) keep the image latent.
- Improved code clarity and maintainability by consolidating latent shape definitions and restructuring the noise application logic.
This commit is contained in:
@@ -332,8 +332,6 @@ def generate_video(
|
|||||||
# Stage 1: Generate at half resolution
|
# Stage 1: Generate at half resolution
|
||||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
|
||||||
mx.eval(latents)
|
|
||||||
|
|
||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
@@ -341,10 +339,12 @@ def generate_video(
|
|||||||
# Apply I2V conditioning if provided
|
# Apply I2V conditioning if provided
|
||||||
state1 = None
|
state1 = None
|
||||||
if is_i2v and stage1_image_latent is not None:
|
if is_i2v and stage1_image_latent is not None:
|
||||||
# Create state with conditioning
|
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
|
||||||
|
# Create initial state with zeros
|
||||||
|
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
||||||
state1 = LatentState(
|
state1 = LatentState(
|
||||||
latent=latents,
|
latent=mx.zeros(latent_shape),
|
||||||
clean_latent=mx.zeros_like(latents),
|
clean_latent=mx.zeros(latent_shape),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
)
|
)
|
||||||
conditioning = VideoConditionByLatentIndex(
|
conditioning = VideoConditionByLatentIndex(
|
||||||
@@ -353,8 +353,23 @@ def generate_video(
|
|||||||
strength=image_strength,
|
strength=image_strength,
|
||||||
)
|
)
|
||||||
state1 = apply_conditioning(state1, [conditioning])
|
state1 = apply_conditioning(state1, [conditioning])
|
||||||
|
|
||||||
|
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||||
|
# For Stage 1, noise_scale = 1.0 (first sigma)
|
||||||
|
noise = mx.random.normal(latent_shape)
|
||||||
|
noise_scale = STAGE_1_SIGMAS[0] # 1.0
|
||||||
|
scaled_mask = state1.denoise_mask * noise_scale
|
||||||
|
state1 = LatentState(
|
||||||
|
latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask),
|
||||||
|
clean_latent=state1.clean_latent,
|
||||||
|
denoise_mask=state1.denoise_mask,
|
||||||
|
)
|
||||||
latents = state1.latent
|
latents = state1.latent
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
else:
|
||||||
|
# T2V: just use random noise
|
||||||
|
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
|
||||||
|
|
||||||
@@ -379,17 +394,12 @@ def generate_video(
|
|||||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
# Add noise for refinement
|
|
||||||
noise_scale = STAGE_2_SIGMAS[0]
|
|
||||||
noise = mx.random.normal(latents.shape)
|
|
||||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
|
||||||
mx.eval(latents)
|
|
||||||
|
|
||||||
# Apply I2V conditioning for stage 2 if provided
|
# Apply I2V conditioning for stage 2 if provided
|
||||||
state2 = None
|
state2 = None
|
||||||
if is_i2v and stage2_image_latent is not None:
|
if is_i2v and stage2_image_latent is not None:
|
||||||
|
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
|
||||||
state2 = LatentState(
|
state2 = LatentState(
|
||||||
latent=latents,
|
latent=latents, # Start with upscaled latent
|
||||||
clean_latent=mx.zeros_like(latents),
|
clean_latent=mx.zeros_like(latents),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
)
|
)
|
||||||
@@ -399,8 +409,26 @@ def generate_video(
|
|||||||
strength=image_strength,
|
strength=image_strength,
|
||||||
)
|
)
|
||||||
state2 = apply_conditioning(state2, [conditioning])
|
state2 = apply_conditioning(state2, [conditioning])
|
||||||
|
|
||||||
|
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||||
|
# For Stage 2, noise_scale = stage_2_sigmas[0]
|
||||||
|
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
||||||
|
noise = mx.random.normal(latents.shape)
|
||||||
|
noise_scale = STAGE_2_SIGMAS[0]
|
||||||
|
scaled_mask = state2.denoise_mask * noise_scale
|
||||||
|
state2 = LatentState(
|
||||||
|
latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask),
|
||||||
|
clean_latent=state2.clean_latent,
|
||||||
|
denoise_mask=state2.denoise_mask,
|
||||||
|
)
|
||||||
latents = state2.latent
|
latents = state2.latent
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
else:
|
||||||
|
# T2V: add noise to all frames for refinement
|
||||||
|
noise_scale = STAGE_2_SIGMAS[0]
|
||||||
|
noise = mx.random.normal(latents.shape)
|
||||||
|
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
|
||||||
|
|
||||||
|
|||||||
@@ -498,9 +498,6 @@ def generate_video_with_audio(
|
|||||||
# Initialize latents
|
# Initialize latents
|
||||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
video_latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
|
||||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
|
|
||||||
mx.eval(video_latents, audio_latents)
|
|
||||||
|
|
||||||
# Create position grids
|
# Create position grids
|
||||||
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w)
|
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
@@ -509,10 +506,12 @@ def generate_video_with_audio(
|
|||||||
|
|
||||||
# Apply I2V conditioning for stage 1 if provided
|
# Apply I2V conditioning for stage 1 if provided
|
||||||
video_state1 = None
|
video_state1 = None
|
||||||
|
video_latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
||||||
if is_i2v and stage1_image_latent is not None:
|
if is_i2v and stage1_image_latent is not None:
|
||||||
|
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
|
||||||
video_state1 = LatentState(
|
video_state1 = LatentState(
|
||||||
latent=video_latents,
|
latent=mx.zeros(video_latent_shape),
|
||||||
clean_latent=mx.zeros_like(video_latents),
|
clean_latent=mx.zeros(video_latent_shape),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
)
|
)
|
||||||
conditioning = VideoConditionByLatentIndex(
|
conditioning = VideoConditionByLatentIndex(
|
||||||
@@ -521,8 +520,26 @@ def generate_video_with_audio(
|
|||||||
strength=image_strength,
|
strength=image_strength,
|
||||||
)
|
)
|
||||||
video_state1 = apply_conditioning(video_state1, [conditioning])
|
video_state1 = apply_conditioning(video_state1, [conditioning])
|
||||||
|
|
||||||
|
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||||
|
noise = mx.random.normal(video_latent_shape)
|
||||||
|
noise_scale = STAGE_1_SIGMAS[0] # 1.0
|
||||||
|
scaled_mask = video_state1.denoise_mask * noise_scale
|
||||||
|
video_state1 = LatentState(
|
||||||
|
latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask),
|
||||||
|
clean_latent=video_state1.clean_latent,
|
||||||
|
denoise_mask=video_state1.denoise_mask,
|
||||||
|
)
|
||||||
video_latents = video_state1.latent
|
video_latents = video_state1.latent
|
||||||
mx.eval(video_latents)
|
mx.eval(video_latents)
|
||||||
|
else:
|
||||||
|
# T2V: just use random noise
|
||||||
|
video_latents = mx.random.normal(video_latent_shape)
|
||||||
|
mx.eval(video_latents)
|
||||||
|
|
||||||
|
# Audio always uses pure noise (no I2V for audio)
|
||||||
|
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
|
||||||
|
mx.eval(audio_latents)
|
||||||
|
|
||||||
# Stage 1 denoising
|
# Stage 1 denoising
|
||||||
video_latents, audio_latents = denoise_av(
|
video_latents, audio_latents = denoise_av(
|
||||||
@@ -554,19 +571,12 @@ def generate_video_with_audio(
|
|||||||
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w)
|
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
mx.eval(video_positions)
|
mx.eval(video_positions)
|
||||||
|
|
||||||
# Add noise for refinement
|
|
||||||
noise_scale = STAGE_2_SIGMAS[0]
|
|
||||||
video_noise = mx.random.normal(video_latents.shape)
|
|
||||||
audio_noise = mx.random.normal(audio_latents.shape)
|
|
||||||
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
|
|
||||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
|
||||||
mx.eval(video_latents, audio_latents)
|
|
||||||
|
|
||||||
# Apply I2V conditioning for stage 2 if provided
|
# Apply I2V conditioning for stage 2 if provided
|
||||||
video_state2 = None
|
video_state2 = None
|
||||||
if is_i2v and stage2_image_latent is not None:
|
if is_i2v and stage2_image_latent is not None:
|
||||||
|
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
|
||||||
video_state2 = LatentState(
|
video_state2 = LatentState(
|
||||||
latent=video_latents,
|
latent=video_latents, # Start with upscaled latent
|
||||||
clean_latent=mx.zeros_like(video_latents),
|
clean_latent=mx.zeros_like(video_latents),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
)
|
)
|
||||||
@@ -576,9 +586,32 @@ def generate_video_with_audio(
|
|||||||
strength=image_strength,
|
strength=image_strength,
|
||||||
)
|
)
|
||||||
video_state2 = apply_conditioning(video_state2, [conditioning])
|
video_state2 = apply_conditioning(video_state2, [conditioning])
|
||||||
|
|
||||||
|
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
||||||
|
video_noise = mx.random.normal(video_latents.shape)
|
||||||
|
noise_scale = STAGE_2_SIGMAS[0]
|
||||||
|
scaled_mask = video_state2.denoise_mask * noise_scale
|
||||||
|
video_state2 = LatentState(
|
||||||
|
latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask),
|
||||||
|
clean_latent=video_state2.clean_latent,
|
||||||
|
denoise_mask=video_state2.denoise_mask,
|
||||||
|
)
|
||||||
video_latents = video_state2.latent
|
video_latents = video_state2.latent
|
||||||
mx.eval(video_latents)
|
mx.eval(video_latents)
|
||||||
|
|
||||||
|
# Audio still gets noise (no I2V for audio)
|
||||||
|
audio_noise = mx.random.normal(audio_latents.shape)
|
||||||
|
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||||
|
mx.eval(audio_latents)
|
||||||
|
else:
|
||||||
|
# T2V: add noise to all frames for refinement
|
||||||
|
noise_scale = STAGE_2_SIGMAS[0]
|
||||||
|
video_noise = mx.random.normal(video_latents.shape)
|
||||||
|
audio_noise = mx.random.normal(audio_latents.shape)
|
||||||
|
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
|
||||||
|
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||||
|
mx.eval(video_latents, audio_latents)
|
||||||
|
|
||||||
video_latents, audio_latents = denoise_av(
|
video_latents, audio_latents = denoise_av(
|
||||||
video_latents, audio_latents,
|
video_latents, audio_latents,
|
||||||
video_positions, audio_positions,
|
video_positions, audio_positions,
|
||||||
|
|||||||
Reference in New Issue
Block a user