Add image-to-video (I2V) conditioning support
- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
@@ -160,3 +163,122 @@ def get_timestep_embedding(
|
||||
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
def load_image(
    image_path: Union[str, Path],
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> mx.array:
    """Load and preprocess an image for I2V conditioning.

    The image is decoded as RGB and resized according to which target
    dimensions are given:

    * both ``height`` and ``width``: resized to exactly that size;
    * only one of them: the other side is derived from the original aspect
      ratio and floored to a multiple of 32;
    * neither: both original sides are floored to a multiple of 32.

    Derived sides are never allowed to drop below 32 pixels, so small or very
    elongated images are scaled up to 32 instead of producing a zero-sized
    resize.

    Args:
        image_path: Path to the image file.
        height: Target height (expected to be divisible by 32; this is not
            checked here). If None, derived as described above.
        width: Target width (expected to be divisible by 32; this is not
            checked here). If None, derived as described above.

    Returns:
        Image tensor of shape (H, W, 3) with float32 values in range [0, 1].
    """

    def _floor_to_32(value: int) -> int:
        # Floor to a multiple of 32, but keep at least one 32-pixel block:
        # a plain floor would give 0 for any side shorter than 32 pixels.
        return max(32, (value // 32) * 32)

    # Use a context manager so the underlying file handle is released as soon
    # as the pixels are decoded; convert() returns an independent image.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    orig_w, orig_h = image.size

    if height is not None and width is not None:
        image = image.resize((width, height), Image.Resampling.LANCZOS)
    elif height is not None:
        # Only the height is fixed: derive the width from the aspect ratio.
        new_w = _floor_to_32(int(orig_w * (height / orig_h)))
        image = image.resize((new_w, height), Image.Resampling.LANCZOS)
    elif width is not None:
        # Only the width is fixed: derive the height from the aspect ratio.
        new_h = _floor_to_32(int(orig_h * (width / orig_w)))
        image = image.resize((width, new_h), Image.Resampling.LANCZOS)
    else:
        # No target given: snap the original size down to multiples of 32.
        new_w = _floor_to_32(orig_w)
        new_h = _floor_to_32(orig_h)
        if new_w != orig_w or new_h != orig_h:
            image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Convert to numpy then MLX, scaling 8-bit channels into [0, 1].
    image_np = np.array(image).astype(np.float32) / 255.0
    return mx.array(image_np)
|
||||
|
||||
|
||||
def resize_image_aspect_ratio(
    image: mx.array,
    long_side: int = 512,
) -> mx.array:
    """Resize image preserving aspect ratio, making long side = long_side.

    The longer side is set to ``long_side`` and the shorter side is scaled
    proportionally. Both sides are then floored to a multiple of 32 and
    clamped to at least 32 pixels, so the result can be slightly smaller than
    ``long_side`` (or larger, when ``long_side`` is below 32) and the aspect
    ratio is only approximately preserved.

    Args:
        image: Image tensor of shape (H, W, 3). Floating-point input is
            treated as [0, 1] when its maximum is <= 1.0 and as [0, 255]
            otherwise; integer input is treated as [0, 255].
        long_side: Target size for the longer dimension.

    Returns:
        Resized image tensor of shape (H', W', 3), float32 in range [0, 1].
    """
    h, w = image.shape[:2]

    if h > w:
        new_h = long_side
        new_w = int(w * long_side / h)
    else:
        new_w = long_side
        new_h = int(h * long_side / w)

    # Floor to a multiple of 32, never below 32: a plain floor gives 0 for
    # extreme aspect ratios or small long_side, which PIL refuses to resize to.
    new_h = max(32, (new_h // 32) * 32)
    new_w = max(32, (new_w // 32) * 32)

    # Use PIL for high-quality resize; it needs 8-bit channels.
    image_np = np.array(image)
    if np.issubdtype(image_np.dtype, np.floating):
        if image_np.max() <= 1.0:
            image_np = image_np * 255.0
        # Round instead of truncating: values that went through a /255
        # round trip can sit just below an integer, and a bare astype(uint8)
        # would drop them by one level. Clip guards against wrap-around.
        image_np = np.clip(np.rint(image_np), 0, 255)
    image_np = image_np.astype(np.uint8)

    pil_image = Image.fromarray(image_np)
    pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    return mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
||||
|
||||
|
||||
def prepare_image_for_encoding(
    image: mx.array,
    target_height: int,
    target_width: int,
) -> mx.array:
    """Prepare image for VAE encoding by resizing and normalizing.

    The image is resized to the target size only when its current size
    differs, then mapped from [0, 1] to [-1, 1] and reshaped to a
    single-frame, channels-first batch.

    Args:
        image: Image tensor of shape (H, W, 3) in range [0, 1]. When no
            resize is needed the values are normalized as-is, so input
            outside [0, 1] is not corrected on that path.
        target_height: Target height for the video.
        target_width: Target width for the video.

    Returns:
        Image tensor ready for encoding, shape (1, 3, 1, H, W) in range
        [-1, 1].
    """
    h, w = image.shape[:2]

    # Resize if needed
    if h != target_height or w != target_width:
        # PIL works on 8-bit channels, so go through uint8 for the resize.
        image_np = np.array(image)
        if np.issubdtype(image_np.dtype, np.floating):
            if image_np.max() <= 1.0:
                image_np = image_np * 255.0
            # Round instead of truncating so a float32 /255 round trip does
            # not lose one intensity level; clip to avoid uint8 wrap-around.
            image_np = np.clip(np.rint(image_np), 0, 255)
        image_np = image_np.astype(np.uint8)

        pil_image = Image.fromarray(image_np)
        pil_image = pil_image.resize(
            (target_width, target_height), Image.Resampling.LANCZOS
        )
        image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)

    # Normalize to [-1, 1]
    image = image * 2.0 - 1.0

    # Convert to (B, C, 1, H, W)
    image = mx.transpose(image, (2, 0, 1))  # (3, H, W)
    image = mx.expand_dims(image, axis=0)  # (1, 3, H, W)
    image = mx.expand_dims(image, axis=2)  # (1, 3, 1, H, W)

    return image
|
||||
|
||||
Reference in New Issue
Block a user