format
This commit is contained in:
@@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module):
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
@@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
out_channels_reduction_factor: int = 1,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
||||
def __call__(
|
||||
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
|
||||
) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
def _chunked_conv_depth_to_space(
|
||||
self, x: mx.array, causal: bool = True
|
||||
) -> mx.array:
|
||||
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||
|
||||
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||
|
||||
Reference in New Issue
Block a user