Add image-to-video (I2V) conditioning support

- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
Prince Canuma
2026-01-17 00:19:52 +01:00
parent 5f86e881d7
commit 146f5d2981
11 changed files with 937 additions and 67 deletions

View File

@@ -141,9 +141,10 @@ class UNetMidBlock3D(nn.Module):
self.num_layers = num_layers
# Create ResNet blocks
self.resnets = [
ResnetBlock3D(
# Create ResNet blocks - use dict for MLX parameter tracking
# Named res_blocks to match PyTorch weight keys
self.res_blocks = {
i: ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
@@ -154,8 +155,8 @@ class UNetMidBlock3D(nn.Module):
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
for i in range(num_layers)
}
def __call__(
self,
@@ -164,8 +165,8 @@ class UNetMidBlock3D(nn.Module):
timestep: Optional[mx.array] = None,
generator: Optional[int] = None,
) -> mx.array:
for resnet in self.resnets:
for resnet in self.res_blocks.values():
x = resnet(x, causal=causal, generator=generator)
return x