Add image-to-video (I2V) conditioning support
- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation. - Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning. - Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`. - Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning. - Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
@@ -1 +1,3 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
|
||||
|
||||
187
mlx_video/models/ltx/video_vae/encoder.py
Normal file
187
mlx_video/models/ltx/video_vae/encoder.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Video VAE Encoder for LTX-2 Image-to-Video.
|
||||
|
||||
The encoder compresses input images/videos to latent representations.
|
||||
Used for I2V (image-to-video) conditioning by encoding the input image
|
||||
to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Any, Optional
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
|
||||
|
||||
|
||||
def load_vae_encoder(model_path: str) -> VideoEncoder:
|
||||
"""Load VAE encoder from safetensors file.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model weights (safetensors file or directory)
|
||||
|
||||
Returns:
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE encoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata
|
||||
encoder_blocks = []
|
||||
norm_layer = NormLayerType.PIXEL_NORM
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
patch_size = 4
|
||||
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
|
||||
# Parse encoder blocks
|
||||
raw_blocks = vae_config.get("encoder_blocks", [])
|
||||
for block in raw_blocks:
|
||||
if isinstance(block, list) and len(block) == 2:
|
||||
name, params = block
|
||||
encoder_blocks.append((name, params))
|
||||
|
||||
# Parse other config
|
||||
norm_str = vae_config.get("norm_layer", "pixel_norm")
|
||||
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
|
||||
|
||||
var_str = vae_config.get("latent_log_var", "uniform")
|
||||
if var_str == "uniform":
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
elif var_str == "per_channel":
|
||||
latent_log_var = LogVarianceType.PER_CHANNEL
|
||||
elif var_str == "constant":
|
||||
latent_log_var = LogVarianceType.CONSTANT
|
||||
else:
|
||||
latent_log_var = LogVarianceType.NONE
|
||||
|
||||
patch_size = vae_config.get("patch_size", 4)
|
||||
|
||||
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}")
|
||||
# Use default config
|
||||
encoder_blocks = [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
|
||||
|
||||
# Create encoder
|
||||
encoder = VideoEncoder(
|
||||
convolution_dimensions=3,
|
||||
in_channels=3,
|
||||
out_channels=128,
|
||||
encoder_blocks=encoder_blocks,
|
||||
patch_size=patch_size,
|
||||
norm_layer=norm_layer,
|
||||
latent_log_var=latent_log_var,
|
||||
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.encoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
else:
|
||||
prefix = "encoder."
|
||||
stats_prefix = "per_channel_statistics."
|
||||
|
||||
# Load per-channel statistics for normalization
|
||||
mean_key = f"{stats_prefix}mean-of-means"
|
||||
std_key = f"{stats_prefix}std-of-means"
|
||||
|
||||
if mean_key in weights:
|
||||
encoder.per_channel_statistics.mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
|
||||
if std_key in weights:
|
||||
encoder.per_channel_statistics.std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {weights[std_key].shape}")
|
||||
|
||||
# Build encoder weights dict with key remapping
|
||||
encoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
encoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(encoder_weights)} encoder weights")
|
||||
|
||||
# Load weights
|
||||
encoder.load_weights(list(encoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE encoder loaded successfully")
|
||||
return encoder
|
||||
|
||||
|
||||
def encode_image(
|
||||
image: mx.array,
|
||||
encoder: VideoEncoder,
|
||||
) -> mx.array:
|
||||
"""Encode a single image to latent space.
|
||||
|
||||
Args:
|
||||
image: Image tensor of shape (H, W, 3) in range [0, 1] or (B, H, W, 3)
|
||||
encoder: Loaded VAE encoder
|
||||
|
||||
Returns:
|
||||
Latent tensor of shape (1, 128, 1, H//32, W//32)
|
||||
"""
|
||||
# Add batch dimension if needed
|
||||
if image.ndim == 3:
|
||||
image = mx.expand_dims(image, axis=0) # (1, H, W, 3)
|
||||
|
||||
# Convert from (B, H, W, C) to (B, C, H, W)
|
||||
image = mx.transpose(image, (0, 3, 1, 2)) # (B, 3, H, W)
|
||||
|
||||
# Normalize to [-1, 1]
|
||||
if image.max() > 1.0:
|
||||
image = image / 255.0
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
|
||||
image = mx.expand_dims(image, axis=2) # (B, 3, 1, H, W)
|
||||
|
||||
# Encode
|
||||
latent = encoder(image)
|
||||
|
||||
return latent
|
||||
@@ -141,9 +141,10 @@ class UNetMidBlock3D(nn.Module):
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Create ResNet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock3D(
|
||||
# Create ResNet blocks - use dict for MLX parameter tracking
|
||||
# Named res_blocks to match PyTorch weight keys
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
@@ -154,8 +155,8 @@ class UNetMidBlock3D(nn.Module):
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -164,8 +165,8 @@ class UNetMidBlock3D(nn.Module):
|
||||
timestep: Optional[mx.array] = None,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
for resnet in self.resnets:
|
||||
|
||||
for resnet in self.res_blocks.values():
|
||||
x = resnet(x, causal=causal, generator=generator)
|
||||
|
||||
return x
|
||||
|
||||
@@ -9,6 +9,15 @@ from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingMode
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
"""Space-to-depth downsampling with 3x3 conv and skip connection.
|
||||
|
||||
PyTorch-compatible implementation:
|
||||
1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
|
||||
2. Space-to-depth on conv output: channels * prod(stride)
|
||||
3. Space-to-depth on input with group averaging for skip connection
|
||||
4. Add skip connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
@@ -17,7 +26,6 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
@@ -25,61 +33,74 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
self.stride = stride
|
||||
self.dims = dims
|
||||
self.out_channels = out_channels
|
||||
|
||||
# Calculate the multiplier for channels
|
||||
# Calculate channels
|
||||
multiplier = stride[0] * stride[1] * stride[2]
|
||||
intermediate_channels = in_channels * multiplier
|
||||
self.group_size = in_channels * multiplier // out_channels
|
||||
conv_out_channels = out_channels // multiplier
|
||||
|
||||
# 1x1x1 convolution to adjust channels
|
||||
# 3x3 convolution (not 1x1)
|
||||
self.conv = CausalConv3d(
|
||||
in_channels=intermediate_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
in_channels=in_channels,
|
||||
out_channels=conv_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
|
||||
def _space_to_depth(self, x: mx.array) -> mx.array:
|
||||
"""Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Reshape to group spatial elements
|
||||
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
||||
|
||||
# Permute: (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
|
||||
# Reshape to combine channels
|
||||
new_c = c * st * sh * sw
|
||||
new_d = d // st
|
||||
new_h = h // sh
|
||||
new_w = w // sw
|
||||
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Temporal padding for causal mode
|
||||
if st == 2:
|
||||
# Duplicate first frame for padding
|
||||
x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
|
||||
d = d + 1
|
||||
|
||||
# Pad if necessary to make dimensions divisible by stride
|
||||
pad_d = (st - d % st) % st
|
||||
pad_h = (sh - h % sh) % sh
|
||||
pad_w = (sw - w % sw) % sw
|
||||
|
||||
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
||||
# For causal, pad at the end of temporal dimension
|
||||
if causal:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
|
||||
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
# Skip connection: space-to-depth on input, then group mean
|
||||
x_in = self._space_to_depth(x)
|
||||
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
|
||||
b2, c2, d2, h2, w2 = x_in.shape
|
||||
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
|
||||
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)
|
||||
|
||||
# Reshape to group spatial elements
|
||||
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
|
||||
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
||||
# Conv branch: apply conv then space-to-depth
|
||||
x_conv = self.conv(x, causal=causal)
|
||||
x_conv = self._space_to_depth(x_conv)
|
||||
|
||||
# Permute to move stride elements to channel dim
|
||||
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
|
||||
# Reshape to combine channels
|
||||
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
|
||||
new_c = c * st * sh * sw
|
||||
new_d = d // st
|
||||
new_h = h // sh
|
||||
new_w = w // sw
|
||||
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
||||
|
||||
# Apply 1x1 conv to adjust channels
|
||||
x = self.conv(x, causal=causal)
|
||||
|
||||
return x
|
||||
# Add skip connection
|
||||
return x_conv + x_in
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
@@ -273,9 +273,9 @@ class VideoEncoder(nn.Module):
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build encoder blocks
|
||||
self.down_blocks = []
|
||||
for block_name, block_params in encoder_blocks:
|
||||
# Build encoder blocks - use dict with int keys for MLX parameter tracking
|
||||
self.down_blocks = {}
|
||||
for i, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
@@ -287,7 +287,7 @@ class VideoEncoder(nn.Module):
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks.append(block)
|
||||
self.down_blocks[i] = block
|
||||
|
||||
# Output normalization and convolution
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
@@ -341,7 +341,7 @@ class VideoEncoder(nn.Module):
|
||||
sample = self.conv_in(sample, causal=True)
|
||||
|
||||
# Process through encoder blocks
|
||||
for down_block in self.down_blocks:
|
||||
for down_block in self.down_blocks.values():
|
||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||
sample = down_block(sample, causal=True)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user