Add image-to-video (I2V) conditioning support
- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation. - Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning. - Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`. - Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning. - Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
@@ -141,9 +141,10 @@ class UNetMidBlock3D(nn.Module):
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Create ResNet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock3D(
|
||||
# Create ResNet blocks - use dict for MLX parameter tracking
|
||||
# Named res_blocks to match PyTorch weight keys
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
@@ -154,8 +155,8 @@ class UNetMidBlock3D(nn.Module):
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -164,8 +165,8 @@ class UNetMidBlock3D(nn.Module):
|
||||
timestep: Optional[mx.array] = None,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
for resnet in self.resnets:
|
||||
|
||||
for resnet in self.res_blocks.values():
|
||||
x = resnet(x, causal=causal, generator=generator)
|
||||
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user