Add image-to-video (I2V) conditioning support

- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
Prince Canuma
2026-01-17 00:19:52 +01:00
parent 5f86e881d7
commit 146f5d2981
11 changed files with 937 additions and 67 deletions

View File

@@ -1,10 +1,13 @@
import math
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download
from PIL import Image
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
@@ -160,3 +163,122 @@ def get_timestep_embedding(
emb = mx.pad(emb, [(0, 0), (0, 1)])
return emb
def load_image(
image_path: Union[str, Path],
height: Optional[int] = None,
width: Optional[int] = None,
) -> mx.array:
"""Load and preprocess an image for I2V conditioning.
Args:
image_path: Path to the image file
height: Target height (must be divisible by 32). If None, uses original.
width: Target width (must be divisible by 32). If None, uses original.
Returns:
Image tensor of shape (H, W, 3) in range [0, 1]
"""
image = Image.open(image_path).convert("RGB")
# Resize if dimensions specified
if height is not None and width is not None:
image = image.resize((width, height), Image.Resampling.LANCZOS)
elif height is not None or width is not None:
# If only one dimension specified, resize preserving aspect ratio
orig_w, orig_h = image.size
if height is not None:
scale = height / orig_h
new_w = int(orig_w * scale)
new_w = (new_w // 32) * 32 # Round to nearest 32
image = image.resize((new_w, height), Image.Resampling.LANCZOS)
else:
scale = width / orig_w
new_h = int(orig_h * scale)
new_h = (new_h // 32) * 32 # Round to nearest 32
image = image.resize((width, new_h), Image.Resampling.LANCZOS)
else:
# Round to nearest 32
orig_w, orig_h = image.size
new_w = (orig_w // 32) * 32
new_h = (orig_h // 32) * 32
if new_w != orig_w or new_h != orig_h:
image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
# Convert to numpy then MLX
image_np = np.array(image).astype(np.float32) / 255.0
return mx.array(image_np)
def resize_image_aspect_ratio(
image: mx.array,
long_side: int = 512,
) -> mx.array:
"""Resize image preserving aspect ratio, making long side = long_side.
Args:
image: Image tensor of shape (H, W, 3)
long_side: Target size for the longer dimension
Returns:
Resized image tensor
"""
h, w = image.shape[:2]
if h > w:
new_h = long_side
new_w = int(w * long_side / h)
else:
new_w = long_side
new_h = int(h * long_side / w)
# Round to nearest 32
new_h = (new_h // 32) * 32
new_w = (new_w // 32) * 32
# Use PIL for high-quality resize
image_np = np.array(image)
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
pil_image = Image.fromarray(image_np)
pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
return mx.array(np.array(pil_image).astype(np.float32) / 255.0)
def prepare_image_for_encoding(
image: mx.array,
target_height: int,
target_width: int,
) -> mx.array:
"""Prepare image for VAE encoding by resizing and normalizing.
Args:
image: Image tensor of shape (H, W, 3) in range [0, 1]
target_height: Target height for the video
target_width: Target width for the video
Returns:
Image tensor ready for encoding, shape (1, 3, 1, H, W) in range [-1, 1]
"""
h, w = image.shape[:2]
# Resize if needed
if h != target_height or w != target_width:
image_np = np.array(image)
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
pil_image = Image.fromarray(image_np)
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
# Normalize to [-1, 1]
image = image * 2.0 - 1.0
# Convert to (B, C, 1, H, W)
image = mx.transpose(image, (2, 0, 1)) # (3, H, W)
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
return image