Add image-to-video (I2V) conditioning support
- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation. - Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning. - Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`. - Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning. - Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
@@ -27,9 +27,12 @@ from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeTyp
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
|
||||
from mlx_video.utils import to_denoised, get_model_path
|
||||
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
||||
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
|
||||
|
||||
|
||||
# Distilled sigma schedules
|
||||
@@ -141,13 +144,35 @@ def denoise_av(
|
||||
transformer: LTXModel,
|
||||
sigmas: list,
|
||||
verbose: bool = True,
|
||||
video_state: Optional[LatentState] = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run denoising loop for audio-video generation."""
|
||||
"""Run denoising loop for audio-video generation with optional I2V conditioning.
|
||||
|
||||
Args:
|
||||
video_latents: Video latent tensor (B, C, F, H, W)
|
||||
audio_latents: Audio latent tensor (B, C, T, F)
|
||||
video_positions: Video position embeddings
|
||||
audio_positions: Audio position embeddings
|
||||
video_embeddings: Video text embeddings
|
||||
audio_embeddings: Audio text embeddings
|
||||
transformer: LTX model
|
||||
sigmas: List of sigma values
|
||||
verbose: Whether to show progress bar
|
||||
video_state: Optional LatentState for I2V conditioning
|
||||
|
||||
Returns:
|
||||
Tuple of (video_latents, audio_latents)
|
||||
"""
|
||||
# If video state is provided, use its latent
|
||||
if video_state is not None:
|
||||
video_latents = video_state.latent
|
||||
|
||||
for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
# Flatten video latents
|
||||
b, c, f, h, w = video_latents.shape
|
||||
num_video_tokens = f * h * w
|
||||
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
||||
|
||||
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
|
||||
@@ -155,9 +180,22 @@ def denoise_av(
|
||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
||||
|
||||
# Compute per-token timesteps for video
|
||||
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
|
||||
if video_state is not None:
|
||||
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
|
||||
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
||||
# Per-token timesteps: sigma * mask
|
||||
video_timesteps = sigma * denoise_mask_flat
|
||||
else:
|
||||
# All tokens get the same timestep
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma)
|
||||
|
||||
video_modality = Modality(
|
||||
latent=video_flat,
|
||||
timesteps=mx.full((1,), sigma),
|
||||
timesteps=video_timesteps,
|
||||
positions=video_positions,
|
||||
context=video_embeddings,
|
||||
context_mask=None,
|
||||
@@ -166,7 +204,7 @@ def denoise_av(
|
||||
|
||||
audio_modality = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=mx.full((1,), sigma),
|
||||
timesteps=mx.full((ab, at), sigma),
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings,
|
||||
context_mask=None,
|
||||
@@ -184,6 +222,11 @@ def denoise_av(
|
||||
# Compute denoised
|
||||
video_denoised = to_denoised(video_latents, video_velocity, sigma)
|
||||
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
||||
|
||||
# Apply conditioning mask for video if state is provided
|
||||
if video_state is not None:
|
||||
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
|
||||
|
||||
mx.eval(video_denoised, audio_denoised)
|
||||
|
||||
# Euler step
|
||||
@@ -317,8 +360,31 @@ def generate_video_with_audio(
|
||||
enhance_prompt: bool = False,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
image: Optional[str] = None,
|
||||
image_strength: float = 1.0,
|
||||
image_frame_idx: int = 0,
|
||||
):
|
||||
"""Generate video with synchronized audio from text prompt."""
|
||||
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
|
||||
|
||||
Args:
|
||||
model_repo: Model repository ID
|
||||
text_encoder_repo: Text encoder repository ID
|
||||
prompt: Text description of the video to generate
|
||||
height: Output video height (must be divisible by 64)
|
||||
width: Output video width (must be divisible by 64)
|
||||
num_frames: Number of frames
|
||||
seed: Random seed
|
||||
fps: Frames per second
|
||||
output_path: Output video path
|
||||
output_audio_path: Output audio path
|
||||
verbose: Whether to print progress
|
||||
enhance_prompt: Whether to enhance prompt using Gemma
|
||||
max_tokens: Max tokens for prompt enhancement
|
||||
temperature: Temperature for prompt enhancement
|
||||
image: Path to conditioning image for I2V
|
||||
image_strength: Conditioning strength (1.0 = full denoise)
|
||||
image_frame_idx: Frame index to condition (0 = first frame)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Validate dimensions
|
||||
@@ -333,9 +399,13 @@ def generate_video_with_audio(
|
||||
# Calculate audio frames
|
||||
audio_frames = compute_audio_frames(num_frames, fps)
|
||||
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
|
||||
is_i2v = image is not None
|
||||
mode_str = "I2V+Audio" if is_i2v else "T2V+Audio"
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||
if is_i2v:
|
||||
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
|
||||
|
||||
model_path = get_model_path(model_repo)
|
||||
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||
@@ -400,6 +470,31 @@ def generate_video_with_audio(
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
|
||||
# Load VAE encoder and encode image for I2V conditioning
|
||||
stage1_image_latent = None
|
||||
stage2_image_latent = None
|
||||
if is_i2v:
|
||||
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
|
||||
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
mx.eval(vae_encoder.parameters())
|
||||
|
||||
# Load and prepare image for stage 1 (half resolution)
|
||||
input_image = load_image(image, height=height // 2, width=width // 2)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
||||
|
||||
# Load and prepare image for stage 2 (full resolution)
|
||||
input_image = load_image(image, height=height, width=width)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
||||
|
||||
del vae_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Initialize latents
|
||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||
mx.random.seed(seed)
|
||||
@@ -412,12 +507,30 @@ def generate_video_with_audio(
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
mx.eval(video_positions, audio_positions)
|
||||
|
||||
# Apply I2V conditioning for stage 1 if provided
|
||||
video_state1 = None
|
||||
if is_i2v and stage1_image_latent is not None:
|
||||
video_state1 = LatentState(
|
||||
latent=video_latents,
|
||||
clean_latent=mx.zeros_like(video_latents),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
latent=stage1_image_latent,
|
||||
frame_idx=image_frame_idx,
|
||||
strength=image_strength,
|
||||
)
|
||||
video_state1 = apply_conditioning(video_state1, [conditioning])
|
||||
video_latents = video_state1.latent
|
||||
mx.eval(video_latents)
|
||||
|
||||
# Stage 1 denoising
|
||||
video_latents, audio_latents = denoise_av(
|
||||
video_latents, audio_latents,
|
||||
video_positions, audio_positions,
|
||||
video_embeddings, audio_embeddings,
|
||||
transformer, STAGE_1_SIGMAS, verbose=verbose
|
||||
transformer, STAGE_1_SIGMAS, verbose=verbose,
|
||||
video_state=video_state1
|
||||
)
|
||||
|
||||
# Upsample video latents
|
||||
@@ -449,11 +562,29 @@ def generate_video_with_audio(
|
||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 2 if provided
|
||||
video_state2 = None
|
||||
if is_i2v and stage2_image_latent is not None:
|
||||
video_state2 = LatentState(
|
||||
latent=video_latents,
|
||||
clean_latent=mx.zeros_like(video_latents),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
latent=stage2_image_latent,
|
||||
frame_idx=image_frame_idx,
|
||||
strength=image_strength,
|
||||
)
|
||||
video_state2 = apply_conditioning(video_state2, [conditioning])
|
||||
video_latents = video_state2.latent
|
||||
mx.eval(video_latents)
|
||||
|
||||
video_latents, audio_latents = denoise_av(
|
||||
video_latents, audio_latents,
|
||||
video_positions, audio_positions,
|
||||
video_embeddings, audio_embeddings,
|
||||
transformer, STAGE_2_SIGMAS, verbose=verbose
|
||||
transformer, STAGE_2_SIGMAS, verbose=verbose,
|
||||
video_state=video_state2
|
||||
)
|
||||
|
||||
del transformer
|
||||
@@ -549,13 +680,18 @@ def generate_video_with_audio(
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate videos with synchronized audio using MLX LTX-2",
|
||||
description="Generate videos with synchronized audio using MLX LTX-2 (T2V and I2V)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Text-to-Video with Audio (T2V+Audio)
|
||||
python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach"
|
||||
python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt
|
||||
python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav
|
||||
|
||||
# Image-to-Video with Audio (I2V+Audio)
|
||||
python -m mlx_video.generate_av --prompt "A person dancing" --image photo.jpg
|
||||
python -m mlx_video.generate_av --prompt "Waves crashing" --image beach.png --image-strength 0.8
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -587,6 +723,12 @@ Examples:
|
||||
help="Max tokens for prompt enhancement")
|
||||
parser.add_argument("--temperature", type=float, default=0.7,
|
||||
help="Temperature for prompt enhancement")
|
||||
parser.add_argument("--image", "-i", type=str, default=None,
|
||||
help="Path to conditioning image for I2V (Image-to-Video) generation")
|
||||
parser.add_argument("--image-strength", type=float, default=1.0,
|
||||
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
|
||||
parser.add_argument("--image-frame-idx", type=int, default=0,
|
||||
help="Frame index to condition for I2V (0 = first frame, default: 0)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -605,6 +747,9 @@ Examples:
|
||||
enhance_prompt=args.enhance_prompt,
|
||||
max_tokens=args.max_tokens,
|
||||
temperature=args.temperature,
|
||||
image=args.image,
|
||||
image_strength=args.image_strength,
|
||||
image_frame_idx=args.image_frame_idx,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user