This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -51,7 +51,9 @@ def build_normalization_layer(
A normalization layer
"""
if normtype == NormType.GROUP:
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
return nn.GroupNorm(
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
)
if normtype == NormType.PIXEL:
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
# PyTorch uses dim=1 for channels-first format (B, C, H, W)