Add image-to-video (I2V) conditioning support

- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.
This commit is contained in:
Prince Canuma
2026-01-17 00:19:52 +01:00
parent 5f86e881d7
commit 146f5d2981
11 changed files with 937 additions and 67 deletions

View File

@@ -1 +1,3 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder

View File

@@ -0,0 +1,187 @@
"""Video VAE Encoder for LTX-2 Image-to-Video.
The encoder compresses input images/videos to latent representations.
Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
from pathlib import Path
from typing import List, Tuple, Any, Optional
import json
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
def load_vae_encoder(model_path: str) -> VideoEncoder:
    """Load the VAE encoder from a safetensors checkpoint.

    The encoder layout (blocks, norm layer, log-variance mode, patch size) is
    read from the ``"config"`` entry of the safetensors metadata when present.
    If the metadata cannot be read, has no config, or lists no usable encoder
    blocks, a built-in default block layout is used instead.

    Args:
        model_path: Path to a ``.safetensors`` file, or to a directory that
            contains ``ltx-2-19b-distilled.safetensors`` or
            ``vae/diffusion_pytorch_model.safetensors``.

    Returns:
        A ``VideoEncoder`` with the checkpoint's encoder weights and, when
        available, the per-channel latent statistics loaded.

    Raises:
        FileNotFoundError: If no weights file can be located at ``model_path``.
    """
    from safetensors import safe_open

    model_path = Path(model_path)

    # Resolve the weights file: explicit file first, then known directory layouts.
    if model_path.is_file() and model_path.suffix == ".safetensors":
        weights_path = model_path
    elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
        weights_path = model_path / "ltx-2-19b-distilled.safetensors"
    elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
        weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
    else:
        raise FileNotFoundError(f"VAE weights not found at {model_path}")

    print(f"Loading VAE encoder from {weights_path}...")

    # Defaults, overridden by the checkpoint metadata when it carries a config.
    encoder_blocks = []
    norm_layer = NormLayerType.PIXEL_NORM
    latent_log_var = LogVarianceType.UNIFORM
    patch_size = 4

    try:
        with safe_open(str(weights_path), framework="numpy") as f:
            metadata = f.metadata()
            if metadata and "config" in metadata:
                configs = json.loads(metadata["config"])
                vae_config = configs.get("vae", {})

                # Keep only well-formed [name, params] pairs.
                raw_blocks = vae_config.get("encoder_blocks", [])
                for block in raw_blocks:
                    if isinstance(block, list) and len(block) == 2:
                        name, params = block
                        encoder_blocks.append((name, params))

                # Any norm name other than "pixel_norm" selects group norm.
                norm_str = vae_config.get("norm_layer", "pixel_norm")
                norm_layer = (
                    NormLayerType.PIXEL_NORM
                    if norm_str == "pixel_norm"
                    else NormLayerType.GROUP_NORM
                )

                var_str = vae_config.get("latent_log_var", "uniform")
                if var_str == "uniform":
                    latent_log_var = LogVarianceType.UNIFORM
                elif var_str == "per_channel":
                    latent_log_var = LogVarianceType.PER_CHANNEL
                elif var_str == "constant":
                    latent_log_var = LogVarianceType.CONSTANT
                else:
                    latent_log_var = LogVarianceType.NONE

                patch_size = vae_config.get("patch_size", 4)
                print(
                    f" Loaded config: {len(encoder_blocks)} encoder blocks, "
                    f"norm={norm_str}, patch_size={patch_size}"
                )
    except Exception as e:
        # Best effort: an unreadable or malformed config must not prevent
        # loading; discard anything partially parsed and use the defaults.
        print(f" Could not read config from metadata: {e}")
        encoder_blocks = []

    # Fall back to the default layout whenever no blocks were obtained, not
    # only on an exception: metadata without a usable config would otherwise
    # build an encoder with no down blocks, and strict=False would hide it.
    if not encoder_blocks:
        encoder_blocks = [
            ("res_x", {"num_layers": 4}),
            ("compress_space_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_time_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
        ]
        print(f" Using default encoder config with {len(encoder_blocks)} blocks")

    # Create encoder
    encoder = VideoEncoder(
        convolution_dimensions=3,
        in_channels=3,
        out_channels=128,
        encoder_blocks=encoder_blocks,
        patch_size=patch_size,
        norm_layer=norm_layer,
        latent_log_var=latent_log_var,
        encoder_spatial_padding_mode=PaddingModeType.ZEROS,
    )

    # Load weights
    weights = mx.load(str(weights_path))

    # Checkpoints either namespace everything under "vae." or store the
    # encoder and statistics keys at the top level.
    has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
    if has_vae_prefix:
        prefix = "vae.encoder."
        stats_prefix = "vae.per_channel_statistics."
    else:
        prefix = "encoder."
        stats_prefix = "per_channel_statistics."

    # Load per-channel statistics for normalization
    mean_key = f"{stats_prefix}mean-of-means"
    std_key = f"{stats_prefix}std-of-means"
    if mean_key in weights:
        encoder.per_channel_statistics.mean = weights[mean_key]
        print(f" Loaded latent mean: shape {weights[mean_key].shape}")
    if std_key in weights:
        encoder.per_channel_statistics.std = weights[std_key]
        print(f" Loaded latent std: shape {weights[std_key].shape}")

    # Build encoder weights dict with key remapping
    encoder_weights = {}
    for key, value in weights.items():
        if not key.startswith(prefix):
            continue
        # Remove prefix
        new_key = key[len(prefix):]
        # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
        if ".weight" in key and value.ndim == 5:
            value = mx.transpose(value, (0, 2, 3, 4, 1))
        encoder_weights[new_key] = value

    print(f" Found {len(encoder_weights)} encoder weights")

    # strict=False: the checkpoint and the module tree may not match key-for-key.
    encoder.load_weights(list(encoder_weights.items()), strict=False)
    print("VAE encoder loaded successfully")
    return encoder
def encode_image(
    image: mx.array,
    encoder: VideoEncoder,
) -> mx.array:
    """Encode an image (or a batch of images) to latent space.

    Args:
        image: Channels-last image tensor of shape (H, W, 3) or (B, H, W, 3).
            Values may be in [0, 1], or in the 0-255 range (integer dtype, or
            any value above 1.0), in which case they are divided by 255.
        encoder: Loaded VAE encoder.

    Returns:
        The encoder output for the (B, 3, 1, H, W) input scaled to [-1, 1].
        Presumably (B, 128, 1, H//32, W//32) for the encoder built by
        ``load_vae_encoder`` - confirm against ``VideoEncoder``.

    Raises:
        ValueError: If ``image`` is neither 3- nor 4-dimensional.
    """
    # Add batch dimension if needed
    if image.ndim == 3:
        image = mx.expand_dims(image, axis=0)  # (1, H, W, 3)
    elif image.ndim != 4:
        raise ValueError(
            f"image must have shape (H, W, 3) or (B, H, W, 3), got {image.shape}"
        )

    # Convert from (B, H, W, C) to (B, C, H, W)
    image = mx.transpose(image, (0, 3, 1, 2))  # (B, 3, H, W)

    # Bring 0-255 data into [0, 1]. Integer images are always treated as
    # 0-255: a max-value test alone would misread a very dark integer image
    # (maximum pixel 0 or 1) as already normalised.
    if mx.issubdtype(image.dtype, mx.integer) or image.max() > 1.0:
        image = image / 255.0

    # Normalize to [-1, 1]
    image = image * 2.0 - 1.0

    # Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
    image = mx.expand_dims(image, axis=2)  # (B, 3, 1, H, W)

    return encoder(image)

View File

@@ -141,9 +141,10 @@ class UNetMidBlock3D(nn.Module):
self.num_layers = num_layers
# Create ResNet blocks
self.resnets = [
ResnetBlock3D(
# Create ResNet blocks - use dict for MLX parameter tracking
# Named res_blocks to match PyTorch weight keys
self.res_blocks = {
i: ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
@@ -154,8 +155,8 @@ class UNetMidBlock3D(nn.Module):
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
for i in range(num_layers)
}
def __call__(
self,
@@ -164,8 +165,8 @@ class UNetMidBlock3D(nn.Module):
timestep: Optional[mx.array] = None,
generator: Optional[int] = None,
) -> mx.array:
for resnet in self.resnets:
for resnet in self.res_blocks.values():
x = resnet(x, causal=causal, generator=generator)
return x

View File

@@ -9,6 +9,15 @@ from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingMode
class SpaceToDepthDownsample(nn.Module):
"""Space-to-depth downsampling with 3x3 conv and skip connection.
PyTorch-compatible implementation:
1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
2. Space-to-depth on conv output: channels * prod(stride)
3. Space-to-depth on input with group averaging for skip connection
4. Add skip connection
"""
def __init__(
self,
dims: int,
@@ -17,7 +26,6 @@ class SpaceToDepthDownsample(nn.Module):
stride: Union[int, Tuple[int, int, int]],
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
@@ -25,61 +33,74 @@ class SpaceToDepthDownsample(nn.Module):
self.stride = stride
self.dims = dims
self.out_channels = out_channels
# Calculate the multiplier for channels
# Calculate channels
multiplier = stride[0] * stride[1] * stride[2]
intermediate_channels = in_channels * multiplier
self.group_size = in_channels * multiplier // out_channels
conv_out_channels = out_channels // multiplier
# 1x1x1 convolution to adjust channels
# 3x3 convolution (not 1x1)
self.conv = CausalConv3d(
in_channels=intermediate_channels,
out_channels=out_channels,
kernel_size=1,
in_channels=in_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=0,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
def _space_to_depth(self, x: mx.array) -> mx.array:
"""Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Reshape to group spatial elements
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Permute: (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Temporal padding for causal mode
if st == 2:
# Duplicate first frame for padding
x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
d = d + 1
# Pad if necessary to make dimensions divisible by stride
pad_d = (st - d % st) % st
pad_h = (sh - h % sh) % sh
pad_w = (sw - w % sw) % sw
if pad_d > 0 or pad_h > 0 or pad_w > 0:
# For causal, pad at the end of temporal dimension
if causal:
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
else:
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
b, c, d, h, w = x.shape
# Skip connection: space-to-depth on input, then group mean
x_in = self._space_to_depth(x)
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
b2, c2, d2, h2, w2 = x_in.shape
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)
# Reshape to group spatial elements
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Conv branch: apply conv then space-to-depth
x_conv = self.conv(x, causal=causal)
x_conv = self._space_to_depth(x_conv)
# Permute to move stride elements to channel dim
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
# Apply 1x1 conv to adjust channels
x = self.conv(x, causal=causal)
return x
# Add skip connection
return x_conv + x_in
class DepthToSpaceUpsample(nn.Module):

View File

@@ -273,9 +273,9 @@ class VideoEncoder(nn.Module):
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks
self.down_blocks = []
for block_name, block_params in encoder_blocks:
# Build encoder blocks - use dict with int keys for MLX parameter tracking
self.down_blocks = {}
for i, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
@@ -287,7 +287,7 @@ class VideoEncoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks.append(block)
self.down_blocks[i] = block
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
@@ -341,7 +341,7 @@ class VideoEncoder(nn.Module):
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks:
for down_block in self.down_blocks.values():
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else: