Add image-to-video (I2V) conditioning support

- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
Prince Canuma
2026-01-17 00:19:52 +01:00
parent 5f86e881d7
commit 146f5d2981
11 changed files with 937 additions and 67 deletions

View File

@@ -9,6 +9,15 @@ from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingMode
class SpaceToDepthDownsample(nn.Module):
"""Space-to-depth downsampling with 3x3 conv and skip connection.
PyTorch-compatible implementation:
1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
2. Space-to-depth on conv output: channels * prod(stride)
3. Space-to-depth on input with group averaging for skip connection
4. Add skip connection
"""
def __init__(
self,
dims: int,
@@ -17,7 +26,6 @@ class SpaceToDepthDownsample(nn.Module):
stride: Union[int, Tuple[int, int, int]],
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
@@ -25,61 +33,74 @@ class SpaceToDepthDownsample(nn.Module):
self.stride = stride
self.dims = dims
self.out_channels = out_channels
# Calculate the multiplier for channels
# Calculate channels
multiplier = stride[0] * stride[1] * stride[2]
intermediate_channels = in_channels * multiplier
self.group_size = in_channels * multiplier // out_channels
conv_out_channels = out_channels // multiplier
# 1x1x1 convolution to adjust channels
# 3x3 convolution (not 1x1)
self.conv = CausalConv3d(
in_channels=intermediate_channels,
out_channels=out_channels,
kernel_size=1,
in_channels=in_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=0,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
def _space_to_depth(self, x: mx.array) -> mx.array:
"""Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Reshape to group spatial elements
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Permute: (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Temporal padding for causal mode
if st == 2:
# Duplicate first frame for padding
x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
d = d + 1
# Pad if necessary to make dimensions divisible by stride
pad_d = (st - d % st) % st
pad_h = (sh - h % sh) % sh
pad_w = (sw - w % sw) % sw
if pad_d > 0 or pad_h > 0 or pad_w > 0:
# For causal, pad at the end of temporal dimension
if causal:
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
else:
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
b, c, d, h, w = x.shape
# Skip connection: space-to-depth on input, then group mean
x_in = self._space_to_depth(x)
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
b2, c2, d2, h2, w2 = x_in.shape
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)
# Reshape to group spatial elements
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Conv branch: apply conv then space-to-depth
x_conv = self.conv(x, causal=causal)
x_conv = self._space_to_depth(x_conv)
# Permute to move stride elements to channel dim
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
# Apply 1x1 conv to adjust channels
x = self.conv(x, causal=causal)
return x
# Add skip connection
return x_conv + x_in
class DepthToSpaceUpsample(nn.Module):