add audio

This commit is contained in:
Prince Canuma
2026-01-16 01:15:22 +01:00
parent 81daf3f67d
commit a658911f98
19 changed files with 2335 additions and 54 deletions

View File

@@ -109,6 +109,84 @@ def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load audio VAE weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of audio VAE weights
"""
# Try different possible paths for audio VAE weights
audio_vae_paths = [
model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
model_path / "audio_vae.safetensors",
]
# Also check in main model weights
main_paths = [
model_path / "ltx-2-19b-distilled.safetensors",
model_path / "ltx-2-19b-dev.safetensors",
]
for audio_path in audio_vae_paths:
if audio_path.exists():
print(f"Loading audio VAE weights from {audio_path}...")
return mx.load(str(audio_path))
# Check main model weights for audio_vae keys
for main_path in main_paths:
if main_path.exists():
print(f"Loading audio VAE weights from {main_path.name}...")
all_weights = mx.load(str(main_path))
# Filter to only audio_vae keys
audio_weights = {k: v for k, v in all_weights.items() if "audio_vae" in k}
if audio_weights:
return audio_weights
raise FileNotFoundError(f"Audio VAE weights not found in {model_path}")
def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load vocoder weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of vocoder weights
"""
# Try different possible paths for vocoder weights
vocoder_paths = [
model_path / "vocoder" / "diffusion_pytorch_model.safetensors",
model_path / "vocoder.safetensors",
]
# Also check in main model weights
main_paths = [
model_path / "ltx-2-19b-distilled.safetensors",
model_path / "ltx-2-19b-dev.safetensors",
]
for vocoder_path in vocoder_paths:
if vocoder_path.exists():
print(f"Loading vocoder weights from {vocoder_path}...")
return mx.load(str(vocoder_path))
# Check main model weights for vocoder keys
for main_path in main_paths:
if main_path.exists():
print(f"Loading vocoder weights from {main_path.name}...")
all_weights = mx.load(str(main_path))
# Filter to only vocoder keys
vocoder_weights = {k: v for k, v in all_weights.items() if "vocoder" in k}
if vocoder_weights:
return vocoder_weights
raise FileNotFoundError(f"Vocoder weights not found in {model_path}")
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
@@ -213,6 +291,83 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
return sanitized
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for audio VAE decoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle audio_vae.decoder weights
if key.startswith("audio_vae.decoder."):
new_key = key.replace("audio_vae.decoder.", "")
elif key.startswith("audio_vae.per_channel_statistics."):
# Map per-channel statistics
if "mean-of-means" in key:
new_key = "per_channel_statistics._mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics._std_of_means"
else:
continue # Skip other statistics keys
else:
continue # Skip non-decoder keys
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize vocoder weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for vocoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle vocoder weights
if key.startswith("vocoder."):
new_key = key.replace("vocoder.", "")
# Handle ModuleList indices -> dict keys
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
# Handle Conv1d weight shape conversion
# PyTorch: (out_channels, in_channels, kernel)
# MLX: (out_channels, kernel, in_channels)
if "weight" in new_key and value.ndim == 3:
if "ups" in new_key:
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = mx.transpose(value, (1, 2, 0))
else:
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = mx.transpose(value, (0, 2, 1))
sanitized[new_key] = value
return sanitized
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize weight names from PyTorch format to MLX format.