from enum import Enum from typing import Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn class PaddingModeType(Enum): ZEROS = "zeros" REFLECT = "reflect" def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array: """Apply reflect padding to spatial dimensions of a 5D tensor. Args: x: Input tensor of shape (B, D, H, W, C) - channels last pad_h: Padding for height dimension pad_w: Padding for width dimension Returns: Padded tensor """ if pad_h == 0 and pad_w == 0: return x # Height padding (axis 2) if pad_h > 0: # Get reflection indices - exclude boundary top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][ :, :, ::-1, :, : ] # Flip bottom portion x = mx.concatenate([top_pad, x, bottom_pad], axis=2) # Width padding (axis 3) if pad_w > 0: left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion right_pad = x[:, :, :, -pad_w - 1 : -1, :][ :, :, :, ::-1, : ] # Flip right portion x = mx.concatenate([left_pad, x, right_pad], axis=3) return x def make_conv_nd( dims: int, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, ...]], stride: Union[int, Tuple[int, ...]] = 1, padding: Union[int, Tuple[int, ...], str] = 0, causal: bool = False, spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ) -> nn.Module: if dims == 2: return CausalConv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, causal=causal, spatial_padding_mode=spatial_padding_mode, ) elif dims == 3: return CausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, causal=causal, spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"Unsupported number of dimensions: {dims}") class CausalConv3d(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int], str] = 0, causal: bool = False, spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ): super().__init__() self.causal = causal self.spatial_padding_mode = spatial_padding_mode # Normalize kernel_size and stride to tuples if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(stride, int): stride = (stride, stride, stride) self.kernel_size = kernel_size self.stride = stride self.time_kernel_size = kernel_size[0] # Calculate spatial padding (temporal is handled separately via frame replication) height_pad = kernel_size[1] // 2 width_pad = kernel_size[2] // 2 self.spatial_padding = (height_pad, width_pad) # Create the base convolution (without padding, we'll handle it manually) self.conv = nn.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0, # We handle padding manually bias=True, ) def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array: use_causal = causal if causal is not None else self.causal # Apply temporal padding via frame replication # Only apply if kernel_size > 1 if self.time_kernel_size > 1: if use_causal: # Causal: replicate first frame kernel_size-1 times at the beginning first_frame_pad = mx.repeat( x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2 ) x = mx.concatenate([first_frame_pad, x], axis=2) else: # Non-causal: replicate first frame at start, last frame at end pad_size = (self.time_kernel_size - 1) // 2 if pad_size > 0: first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2) last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2) x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2) # Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C) x = mx.transpose(x, (0, 2, 3, 4, 1)) # Apply spatial padding pad_h, pad_w = self.spatial_padding if pad_h > 0 or pad_w > 0: if self.spatial_padding_mode == PaddingModeType.REFLECT: # Use reflect padding for spatial dimensions x = reflect_pad_2d(x, pad_h, pad_w) else: # Use zero padding for spatial dimensions pad_width = [ (0, 0), # Batch (0, 0), # D (temporal - already padded) (pad_h, pad_h), # H (pad_w, pad_w), # W (0, 0), # C ] x = mx.pad(x, pad_width) # Apply convolution with chunking for large tensors # Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial x = self._chunked_conv3d(x) # Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W) x = mx.transpose(x, (0, 4, 1, 2, 3)) return x def _chunked_conv3d(self, x: mx.array) -> mx.array: """Apply conv3d in temporal chunks to work around MLX bug with large tensors. Args: x: Input tensor of shape (B, D, H, W, C) in channels-last format Returns: Output tensor after conv3d """ b, d, h, w, c = x.shape total_elements = d * h * w * c max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk if total_elements <= max_safe_elements: return self.conv(x) elements_per_frame = h * w * c max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame) chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk kernel_t = self.time_kernel_size overlap = kernel_t - 1 expected_output_frames = d - overlap outputs = [] out_idx = 0 # Process chunks in_start = 0 while out_idx < expected_output_frames: remaining = expected_output_frames - out_idx out_frames_this_chunk = min(chunk_size, remaining) in_frames_needed = out_frames_this_chunk + overlap in_end = min(in_start + in_frames_needed, d) chunk = x[:, in_start:in_end, :, :, :] chunk_out = self.conv(chunk) mx.eval(chunk_out) outputs.append(chunk_out) out_idx += chunk_out.shape[1] in_start += chunk_out.shape[1] # Concatenate all chunks if len(outputs) == 1: return outputs[0] return mx.concatenate(outputs, axis=1) class CausalConv2d(nn.Module): """2D convolution with optional causal padding.""" def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int], str] = 0, causal: bool = False, spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ): """Initialize CausalConv2d.""" super().__init__() self.causal = causal self.spatial_padding_mode = spatial_padding_mode # Normalize kernel_size and stride if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) if isinstance(stride, int): stride = (stride, stride) self.kernel_size = kernel_size self.stride = stride # Calculate padding if isinstance(padding, str) and padding == "same": self.padding = ( (kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2, ) elif isinstance(padding, int): self.padding = (padding, padding) else: self.padding = padding self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=True, ) def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array: """Forward pass.""" # Transpose to channels last: (B, C, H, W) -> (B, H, W, C) x = mx.transpose(x, (0, 2, 3, 1)) # Apply padding pad_h, pad_w = self.padding if pad_h != 0 or pad_w != 0: pad_width = [ (0, 0), # Batch (pad_h, pad_h), # H (pad_w, pad_w), # W (0, 0), # C ] x = mx.pad(x, pad_width) x = self.conv(x) # Transpose back: (B, H, W, C) -> (B, C, H, W) x = mx.transpose(x, (0, 3, 1, 2)) return x