This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module):
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims: int,
@@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module):
out_channels_reduction_factor: int = 1,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
def __call__(
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
def _chunked_conv_depth_to_space(
self, x: mx.array, causal: bool = True
) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor.