Add image-to-video (I2V) conditioning support

- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
Prince Canuma
2026-01-17 00:19:52 +01:00
parent 5f86e881d7
commit 146f5d2981
11 changed files with 937 additions and 67 deletions

View File

@@ -1,6 +1,7 @@
import argparse
import time
from pathlib import Path
from typing import Optional, List, Tuple
import mlx.core as mx
import numpy as np
@@ -22,10 +23,13 @@ class Colors:
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.utils import to_denoised
from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder_weights
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state
from mlx_video.utils import get_model_path
@@ -115,17 +119,49 @@ def denoise(
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
state: Optional[LatentState] = None,
) -> mx.array:
"""Run denoising loop."""
"""Run denoising loop with optional conditioning.
Args:
latents: Noisy latent tensor (B, C, F, H, W)
positions: Position embeddings
text_embeddings: Text conditioning embeddings
transformer: LTX model
sigmas: List of sigma values for denoising schedule
verbose: Whether to show progress bar
state: Optional LatentState for I2V conditioning
Returns:
Denoised latent tensor
"""
# If state is provided, use its latent (which may have conditioning applied)
if state is not None:
latents = state.latent
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
num_tokens = f * h * w
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
# Compute per-token timesteps
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
if state is not None:
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask
timesteps = sigma * denoise_mask_flat
else:
# All tokens get the same timestep
timesteps = mx.full((b, num_tokens), sigma)
video_modality = Modality(
latent=latents_flat,
timesteps=mx.full((1,), sigma),
timesteps=timesteps,
positions=positions,
context=text_embeddings,
context_mask=None,
@@ -137,6 +173,11 @@ def denoise(
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
# Apply conditioning mask if state is provided
if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
mx.eval(denoised)
if sigma_next > 0:
@@ -163,10 +204,15 @@ def generate_video(
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
):
"""Generate video from text prompt.
"""Generate video from text prompt, optionally conditioned on an image.
Args:
model_repo: Model repository ID
text_encoder_repo: Text encoder repository ID
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
@@ -175,6 +221,13 @@ def generate_video(
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
verbose: Whether to print progress
enhance_prompt: Whether to enhance prompt using Gemma
max_tokens: Max tokens for prompt enhancement
temperature: Temperature for prompt enhancement
image: Path to conditioning image for I2V (Image-to-Video)
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
image_frame_idx: Frame index to condition (0 = first frame)
"""
start_time = time.time()
@@ -188,8 +241,12 @@ def generate_video(
num_frames = adjusted_num_frames
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
is_i2v = image is not None
mode_str = "I2V" if is_i2v else "T2V"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
# Get model path
model_path = get_model_path(model_repo)
@@ -247,6 +304,31 @@ def generate_video(
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
# Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder
mx.clear_cache()
# Stage 1: Generate at half resolution
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
@@ -256,7 +338,25 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
# Apply I2V conditioning if provided
state1 = None
if is_i2v and stage1_image_latent is not None:
# Create state with conditioning
state1 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
state1 = apply_conditioning(state1, [conditioning])
latents = state1.latent
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
# Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
@@ -285,7 +385,24 @@ def generate_video(
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
# Apply I2V conditioning for stage 2 if provided
state2 = None
if is_i2v and stage2_image_latent is not None:
state2 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
state2 = apply_conditioning(state2, [conditioning])
latents = state2.latent
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
del transformer
mx.clear_cache()
@@ -335,13 +452,18 @@ def generate_video(
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2",
description="Generate videos with MLX LTX-2 (T2V and I2V)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video (T2V)
python -m mlx_video.generate --prompt "A cat walking on grass"
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
# Image-to-Video (I2V)
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8
"""
)
@@ -426,6 +548,24 @@ Examples:
default=0.7,
help="Temperature for prompt enhancement (default: 0.7)"
)
parser.add_argument(
"--image", "-i",
type=str,
default=None,
help="Path to conditioning image for I2V (Image-to-Video) generation"
)
parser.add_argument(
"--image-strength",
type=float,
default=1.0,
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)"
)
parser.add_argument(
"--image-frame-idx",
type=int,
default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)"
)
args = parser.parse_args()
generate_video(