Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -355,6 +355,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
if "audio_vae." in weights:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
@@ -364,9 +367,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
# Map per-channel statistics
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics._mean_of_means"
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics._std_of_means"
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user