Add image-to-video (I2V) conditioning support

- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.
This commit is contained in:
Prince Canuma
2026-01-17 00:19:52 +01:00
parent 5f86e881d7
commit 146f5d2981
11 changed files with 937 additions and 67 deletions

View File

@@ -273,9 +273,9 @@ class VideoEncoder(nn.Module):
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks
self.down_blocks = []
for block_name, block_params in encoder_blocks:
# Build encoder blocks - use dict with int keys for MLX parameter tracking
self.down_blocks = {}
for i, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
@@ -287,7 +287,7 @@ class VideoEncoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks.append(block)
self.down_blocks[i] = block
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
@@ -341,7 +341,7 @@ class VideoEncoder(nn.Module):
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks:
for down_block in self.down_blocks.values():
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else: