Files
mlx-video/mlx_video/models/ltx_2/audio_vae/normalization.py
Prince Canuma 17397da70c format
2026-03-18 17:40:05 +01:00

62 lines
2.0 KiB
Python

"""Normalization layers for audio VAE."""
from enum import Enum
import mlx.core as mx
import mlx.nn as nn
class NormType(Enum):
    """Selects which normalization layer to build.

    GROUP stands for group normalization (GroupNorm); PIXEL stands for the
    per-location RMS normalization.
    """

    GROUP = "group"
    PIXEL = "pixel"
class PixelNorm(nn.Module):
    """
    RMS normalization applied independently at every location.

    The input is divided by the root-mean-square of its values taken along a
    single axis; the reduction keeps that axis so the result broadcasts back
    against the input:

        y = x / sqrt(mean(x^2, axis=dim, keepdims=True) + eps)

    The layer only stores the two scalar settings below; it creates no weight
    or bias.
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Axis over which the mean of squares is taken (typically the
                channel axis).
            eps: Small constant added under the square root for numerical
                stability.
        """
        super().__init__()
        self.eps = eps
        self.dim = dim

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x`` divided by its RMS along the configured axis."""
        mean_of_squares = mx.mean(x**2, axis=self.dim, keepdims=True)
        return x / mx.sqrt(mean_of_squares + self.eps)
def build_normalization_layer(
    in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
    """
    Create a normalization layer based on the normalization type.

    Args:
        in_channels: Number of input channels (passed to GroupNorm as ``dims``).
        num_groups: Number of groups for group normalization; ignored for PIXEL.
        normtype: A ``NormType`` member, or its string value ("group" or "pixel").

    Returns:
        ``nn.GroupNorm`` (eps=1e-6, affine) for GROUP, or ``PixelNorm`` over the
        last axis (eps=1e-6) for PIXEL.

    Raises:
        ValueError: If ``normtype`` is neither a ``NormType`` member nor one of
            the member values.
    """
    try:
        # Enum lookup by value maps "group"/"pixel" to their members and returns
        # an existing member unchanged, so both spellings are accepted.
        kind = NormType(normtype)
    except ValueError:
        raise ValueError(f"Invalid normalization type: {normtype}") from None

    if kind is NormType.GROUP:
        # NOTE(review): this layer mirrors a PyTorch implementation (see the
        # channels-first remark below), so request PyTorch's channel grouping;
        # MLX's default grouping differs and would not match weights trained
        # with torch.nn.GroupNorm. Confirm against the checkpoint source.
        return nn.GroupNorm(
            num_groups=num_groups,
            dims=in_channels,
            eps=1e-6,
            affine=True,
            pytorch_compatible=True,
        )
    if kind is NormType.PIXEL:
        # For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
        # PyTorch uses dim=1 for channels-first format (B, C, H, W)
        return PixelNorm(dim=-1, eps=1e-6)
    # Guards against a NormType member added later without a matching branch.
    raise ValueError(f"Invalid normalization type: {normtype}")