format
This commit is contained in:
@@ -51,7 +51,9 @@ def build_normalization_layer(
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
||||
return nn.GroupNorm(
|
||||
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
if normtype == NormType.PIXEL:
|
||||
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||
|
||||
Reference in New Issue
Block a user