This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -44,7 +44,7 @@ class ResnetBlock3D(nn.Module):
timestep_conditioning: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
out_channels = out_channels or in_channels
@@ -96,7 +96,7 @@ class ResnetBlock3D(nn.Module):
causal: bool = True,
generator: Optional[int] = None,
) -> mx.array:
residual = x
# First block
@@ -136,7 +136,7 @@ class UNetMidBlock3D(nn.Module):
attention_head_dim: Optional[int] = None,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.num_layers = num_layers