Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
self.latent_channels = latent_channels
# Initialize buffers - will be loaded from weights
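# NOTE(review): keep these names public (no leading underscore) — MLX appears to treat underscore-prefixed members as private and skip them during weight loading; confirm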
self._std_of_means = mx.ones((latent_channels,))
self._mean_of_means = mx.zeros((latent_channels,))
self.std_of_means = mx.ones((latent_channels,))
self.mean_of_means = mx.zeros((latent_channels,))
def un_normalize(self, x: mx.array) -> mx.array:
    """Denormalize a latent representation.

    Computes ``x * std_of_means + mean_of_means`` after casting both
    statistics to ``x.dtype``, so the result keeps the dtype of ``x``.

    Args:
        x: Normalized latent tensor.

    Returns:
        The denormalized tensor.
    """
    # Read only the public buffers. The legacy underscore-prefixed
    # attributes were read first and then immediately overwritten, which
    # was dead code and a hard dependency on names the refactor removes.
    # NOTE(review): the statistics are 1-D (latent_channels,), so plain
    # broadcasting aligns them with the last axis of x; channel-first
    # inputs would presumably need reshaping by the caller — confirm.
    std = self.std_of_means.astype(x.dtype)
    mean = self.mean_of_means.astype(x.dtype)
    return (x * std) + mean
def normalize(self, x: mx.array) -> mx.array:
    """Normalize a latent representation.

    Computes ``(x - mean_of_means) / std_of_means`` after casting both
    statistics to ``x.dtype``, so the result keeps the dtype of ``x``.

    Args:
        x: Latent tensor to normalize.

    Returns:
        The normalized tensor.
    """
    # Read only the public buffers. The legacy underscore-prefixed
    # attributes were read first and then immediately overwritten, which
    # was dead code and a hard dependency on names the refactor removes.
    # NOTE(review): the statistics are 1-D (latent_channels,), so plain
    # broadcasting aligns them with the last axis of x — confirm layout.
    std = self.std_of_means.astype(x.dtype)
    mean = self.mean_of_means.astype(x.dtype)
    return (x - mean) / std