60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
"""Normalization layers for audio VAE."""
|
|
|
|
from enum import Enum
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
|
|
class NormType(Enum):
|
|
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
|
|
|
GROUP = "group"
|
|
PIXEL = "pixel"
|
|
|
|
|
|
class PixelNorm(nn.Module):
|
|
"""
|
|
Per-pixel (per-location) RMS normalization layer.
|
|
For each element along the chosen dimension, this layer normalizes the tensor
|
|
by the root-mean-square of its values across that dimension:
|
|
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
|
"""
|
|
|
|
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
|
"""
|
|
Args:
|
|
dim: Dimension along which to compute the RMS (typically channels).
|
|
eps: Small constant added for numerical stability.
|
|
"""
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.eps = eps
|
|
|
|
def __call__(self, x: mx.array) -> mx.array:
|
|
"""Apply RMS normalization along the configured dimension."""
|
|
mean_sq = mx.mean(x**2, axis=self.dim, keepdims=True)
|
|
rms = mx.sqrt(mean_sq + self.eps)
|
|
return x / rms
|
|
|
|
|
|
def build_normalization_layer(
|
|
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
|
) -> nn.Module:
|
|
"""
|
|
Create a normalization layer based on the normalization type.
|
|
Args:
|
|
in_channels: Number of input channels
|
|
num_groups: Number of groups for group normalization
|
|
normtype: Type of normalization: "group" or "pixel"
|
|
Returns:
|
|
A normalization layer
|
|
"""
|
|
if normtype == NormType.GROUP:
|
|
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
|
if normtype == NormType.PIXEL:
|
|
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
|
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
|
return PixelNorm(dim=-1, eps=1e-6)
|
|
raise ValueError(f"Invalid normalization type: {normtype}")
|