- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation. - Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning. - Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`. - Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning. - Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
285 lines
8.5 KiB
Python
285 lines
8.5 KiB
Python
import math
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import numpy as np
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from huggingface_hub import snapshot_download
|
|
from PIL import Image
|
|
|
|
def get_model_path(model_repo: str):
|
|
"""Get or download LTX-2 model path."""
|
|
try:
|
|
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
|
except Exception:
|
|
print("Downloading LTX-2 model weights...")
|
|
return Path(snapshot_download(
|
|
repo_id=model_repo,
|
|
local_files_only=False,
|
|
resume_download=True,
|
|
allow_patterns=["*.safetensors", "*.json"],
|
|
))
|
|
|
|
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
|
if quantization is not None:
|
|
def get_class_predicate(p, m):
|
|
# Handle custom per layer quantizations
|
|
if p in quantization:
|
|
return quantization[p]
|
|
if not hasattr(m, "to_quantized"):
|
|
return False
|
|
# Skip layers not divisible by 64
|
|
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
|
|
return False
|
|
# Handle legacy models which may not have everything quantized
|
|
return f"{p}.scales" in weights
|
|
|
|
nn.quantize(
|
|
model,
|
|
group_size=quantization["group_size"],
|
|
bits=quantization["bits"],
|
|
mode=quantization.get("mode", "affine"),
|
|
class_predicate=get_class_predicate,
|
|
)
|
|
|
|
|
|
@partial(mx.compile, shapeless=True)
|
|
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
|
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
|
|
|
|
|
|
|
|
@partial(mx.compile, shapeless=True)
|
|
def to_denoised(
|
|
noisy: mx.array,
|
|
velocity: mx.array,
|
|
sigma: mx.array | float
|
|
) -> mx.array:
|
|
"""Convert velocity prediction to denoised output.
|
|
|
|
Given noisy input x_t and velocity prediction v, compute denoised x_0:
|
|
x_0 = x_t - sigma * v
|
|
|
|
Args:
|
|
noisy: Noisy input tensor x_t
|
|
velocity: Velocity prediction v
|
|
sigma: Noise level (scalar or per-sample)
|
|
|
|
Returns:
|
|
Denoised tensor x_0
|
|
"""
|
|
if isinstance(sigma, (int, float)):
|
|
return noisy - sigma * velocity
|
|
else:
|
|
# sigma is per-sample
|
|
while sigma.ndim < velocity.ndim:
|
|
sigma = mx.expand_dims(sigma, axis=-1)
|
|
return noisy - sigma * velocity
|
|
|
|
|
|
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array:
|
|
"""Repeat elements of tensor along an axis, similar to torch.repeat_interleave.
|
|
|
|
Args:
|
|
x: Input tensor
|
|
repeats: Number of repetitions for each element
|
|
axis: The axis along which to repeat values
|
|
|
|
Returns:
|
|
Tensor with repeated values
|
|
"""
|
|
# Handle negative axis
|
|
if axis < 0:
|
|
axis = x.ndim + axis
|
|
|
|
# Get shape
|
|
shape = list(x.shape)
|
|
|
|
# Expand dims, repeat, then reshape
|
|
x = mx.expand_dims(x, axis=axis + 1)
|
|
|
|
# Create tile pattern
|
|
tile_pattern = [1] * x.ndim
|
|
tile_pattern[axis + 1] = repeats
|
|
|
|
x = mx.tile(x, tile_pattern)
|
|
|
|
# Reshape to merge the repeated dimension
|
|
new_shape = shape.copy()
|
|
new_shape[axis] *= repeats
|
|
|
|
return mx.reshape(x, new_shape)
|
|
|
|
|
|
class PixelNorm(nn.Module):
|
|
def __init__(self, eps: float = 1e-6):
|
|
super().__init__()
|
|
self.eps = eps
|
|
|
|
def __call__(self, x: mx.array) -> mx.array:
|
|
return x / mx.sqrt(mx.mean(x * x, axis=1, keepdims=True) + self.eps)
|
|
|
|
|
|
def get_timestep_embedding(
|
|
timesteps: mx.array,
|
|
embedding_dim: int,
|
|
flip_sin_to_cos: bool = False,
|
|
downscale_freq_shift: float = 1.0,
|
|
scale: float = 1.0,
|
|
max_period: int = 10000,
|
|
) -> mx.array:
|
|
"""Create sinusoidal timestep embeddings.
|
|
|
|
Args:
|
|
timesteps: 1D tensor of timesteps
|
|
embedding_dim: Dimension of the embeddings to create
|
|
flip_sin_to_cos: If True, flip sin and cos ordering
|
|
downscale_freq_shift: Frequency shift factor
|
|
scale: Scale factor for timesteps
|
|
max_period: Maximum period for the sinusoids
|
|
|
|
Returns:
|
|
Tensor of shape (len(timesteps), embedding_dim)
|
|
"""
|
|
assert timesteps.ndim == 1, "Timesteps should be 1D"
|
|
|
|
half_dim = embedding_dim // 2
|
|
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
|
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
|
|
|
emb = mx.exp(exponent)
|
|
emb = (timesteps[:, None].astype(mx.float32) * scale) * emb[None, :]
|
|
|
|
# Compute sin and cos embeddings
|
|
if flip_sin_to_cos:
|
|
emb = mx.concatenate([mx.cos(emb), mx.sin(emb)], axis=-1)
|
|
else:
|
|
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
|
|
|
|
# Zero pad if odd embedding dimension
|
|
if embedding_dim % 2 == 1:
|
|
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
|
|
|
return emb
|
|
|
|
|
|
def load_image(
|
|
image_path: Union[str, Path],
|
|
height: Optional[int] = None,
|
|
width: Optional[int] = None,
|
|
) -> mx.array:
|
|
"""Load and preprocess an image for I2V conditioning.
|
|
|
|
Args:
|
|
image_path: Path to the image file
|
|
height: Target height (must be divisible by 32). If None, uses original.
|
|
width: Target width (must be divisible by 32). If None, uses original.
|
|
|
|
Returns:
|
|
Image tensor of shape (H, W, 3) in range [0, 1]
|
|
"""
|
|
image = Image.open(image_path).convert("RGB")
|
|
|
|
# Resize if dimensions specified
|
|
if height is not None and width is not None:
|
|
image = image.resize((width, height), Image.Resampling.LANCZOS)
|
|
elif height is not None or width is not None:
|
|
# If only one dimension specified, resize preserving aspect ratio
|
|
orig_w, orig_h = image.size
|
|
if height is not None:
|
|
scale = height / orig_h
|
|
new_w = int(orig_w * scale)
|
|
new_w = (new_w // 32) * 32 # Round to nearest 32
|
|
image = image.resize((new_w, height), Image.Resampling.LANCZOS)
|
|
else:
|
|
scale = width / orig_w
|
|
new_h = int(orig_h * scale)
|
|
new_h = (new_h // 32) * 32 # Round to nearest 32
|
|
image = image.resize((width, new_h), Image.Resampling.LANCZOS)
|
|
else:
|
|
# Round to nearest 32
|
|
orig_w, orig_h = image.size
|
|
new_w = (orig_w // 32) * 32
|
|
new_h = (orig_h // 32) * 32
|
|
if new_w != orig_w or new_h != orig_h:
|
|
image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
|
|
# Convert to numpy then MLX
|
|
image_np = np.array(image).astype(np.float32) / 255.0
|
|
return mx.array(image_np)
|
|
|
|
|
|
def resize_image_aspect_ratio(
|
|
image: mx.array,
|
|
long_side: int = 512,
|
|
) -> mx.array:
|
|
"""Resize image preserving aspect ratio, making long side = long_side.
|
|
|
|
Args:
|
|
image: Image tensor of shape (H, W, 3)
|
|
long_side: Target size for the longer dimension
|
|
|
|
Returns:
|
|
Resized image tensor
|
|
"""
|
|
h, w = image.shape[:2]
|
|
|
|
if h > w:
|
|
new_h = long_side
|
|
new_w = int(w * long_side / h)
|
|
else:
|
|
new_w = long_side
|
|
new_h = int(h * long_side / w)
|
|
|
|
# Round to nearest 32
|
|
new_h = (new_h // 32) * 32
|
|
new_w = (new_w // 32) * 32
|
|
|
|
# Use PIL for high-quality resize
|
|
image_np = np.array(image)
|
|
if image_np.max() <= 1.0:
|
|
image_np = (image_np * 255).astype(np.uint8)
|
|
pil_image = Image.fromarray(image_np)
|
|
pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
|
|
return mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
|
|
|
|
|
def prepare_image_for_encoding(
|
|
image: mx.array,
|
|
target_height: int,
|
|
target_width: int,
|
|
) -> mx.array:
|
|
"""Prepare image for VAE encoding by resizing and normalizing.
|
|
|
|
Args:
|
|
image: Image tensor of shape (H, W, 3) in range [0, 1]
|
|
target_height: Target height for the video
|
|
target_width: Target width for the video
|
|
|
|
Returns:
|
|
Image tensor ready for encoding, shape (1, 3, 1, H, W) in range [-1, 1]
|
|
"""
|
|
h, w = image.shape[:2]
|
|
|
|
# Resize if needed
|
|
if h != target_height or w != target_width:
|
|
image_np = np.array(image)
|
|
if image_np.max() <= 1.0:
|
|
image_np = (image_np * 255).astype(np.uint8)
|
|
pil_image = Image.fromarray(image_np)
|
|
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
|
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
|
|
|
# Normalize to [-1, 1]
|
|
image = image * 2.0 - 1.0
|
|
|
|
# Convert to (B, C, 1, H, W)
|
|
image = mx.transpose(image, (2, 0, 1)) # (3, H, W)
|
|
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
|
|
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
|
|
|
|
return image
|