add audio
This commit is contained in:
@@ -109,6 +109,84 @@ def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
|
||||
|
||||
|
||||
def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
    """Locate and load the audio VAE weights of an LTX-2 model.

    Dedicated audio VAE files are probed first, in order; the first one that
    exists is loaded and returned as-is. If none exists, each known combined
    checkpoint that is present is loaded and reduced to the entries whose key
    contains "audio_vae"; the first non-empty selection is returned.

    Args:
        model_path: Path to LTX-2 model directory

    Returns:
        Dictionary of audio VAE weights

    Raises:
        FileNotFoundError: If no candidate file yields audio VAE weights.
    """
    standalone_files = (
        model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
        model_path / "audio_vae.safetensors",
    )
    combined_checkpoints = (
        "ltx-2-19b-distilled.safetensors",
        "ltx-2-19b-dev.safetensors",
    )

    # A dedicated file wins over anything embedded in a combined checkpoint.
    dedicated = next((p for p in standalone_files if p.exists()), None)
    if dedicated is not None:
        print(f"Loading audio VAE weights from {dedicated}...")
        return mx.load(str(dedicated))

    for checkpoint_name in combined_checkpoints:
        checkpoint = model_path / checkpoint_name
        if not checkpoint.exists():
            continue
        print(f"Loading audio VAE weights from {checkpoint.name}...")
        subset = {
            name: tensor
            for name, tensor in mx.load(str(checkpoint)).items()
            if "audio_vae" in name
        }
        # A checkpoint without audio VAE entries is not an error; try the next.
        if subset:
            return subset

    raise FileNotFoundError(f"Audio VAE weights not found in {model_path}")
|
||||
|
||||
|
||||
def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]:
    """Locate and load the vocoder weights of an LTX-2 model.

    Dedicated vocoder files are probed first, in order; the first one that
    exists is loaded and returned as-is. If none exists, each known combined
    checkpoint that is present is loaded and reduced to the entries whose key
    contains "vocoder"; the first non-empty selection is returned.

    Args:
        model_path: Path to LTX-2 model directory

    Returns:
        Dictionary of vocoder weights

    Raises:
        FileNotFoundError: If no candidate file yields vocoder weights.
    """
    dedicated_candidates = [
        model_path / "vocoder" / "diffusion_pytorch_model.safetensors",
        model_path / "vocoder.safetensors",
    ]
    for candidate in dedicated_candidates:
        if candidate.exists():
            print(f"Loading vocoder weights from {candidate}...")
            return mx.load(str(candidate))

    # No dedicated file: look for vocoder entries inside the combined checkpoints.
    bundled_candidates = [
        model_path / "ltx-2-19b-distilled.safetensors",
        model_path / "ltx-2-19b-dev.safetensors",
    ]
    for bundle in bundled_candidates:
        if bundle.exists():
            print(f"Loading vocoder weights from {bundle.name}...")
            everything = mx.load(str(bundle))
            selected = {}
            for name, tensor in everything.items():
                if "vocoder" in name:
                    selected[name] = tensor
            # An empty selection means this checkpoint has no vocoder; keep looking.
            if selected:
                return selected

    raise FileNotFoundError(f"Vocoder weights not found in {model_path}")
|
||||
|
||||
|
||||
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
|
||||
|
||||
@@ -213,6 +291,83 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Sanitize audio VAE weight names from PyTorch format to MLX format.

    Only two groups of keys are kept:

    * ``audio_vae.decoder.*`` -- the leading prefix is removed.
    * ``audio_vae.per_channel_statistics.*`` -- keys containing
      ``mean-of-means`` or ``std-of-means`` are renamed to
      ``per_channel_statistics._mean_of_means`` and
      ``per_channel_statistics._std_of_means``; other statistics are dropped.

    Every other key is skipped.

    Args:
        weights: Dictionary of weights with PyTorch naming

    Returns:
        Dictionary with MLX-compatible naming for audio VAE decoder
    """
    decoder_prefix = "audio_vae.decoder."
    stats_prefix = "audio_vae.per_channel_statistics."

    sanitized = {}
    for key, value in weights.items():
        if key.startswith(decoder_prefix):
            # Slice instead of str.replace so that only the leading prefix is
            # removed; replace() would also delete any later occurrence of the
            # same substring inside the key.
            new_key = key[len(decoder_prefix):]
        elif key.startswith(stats_prefix):
            if "mean-of-means" in key:
                new_key = "per_channel_statistics._mean_of_means"
            elif "std-of-means" in key:
                new_key = "per_channel_statistics._std_of_means"
            else:
                continue  # Skip other statistics keys
        else:
            continue  # Skip non-decoder keys

        # Conv2d weight layout conversion:
        #   PyTorch: (out_channels, in_channels, H, W)
        #   MLX:     (out_channels, H, W, in_channels)
        if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
            value = mx.transpose(value, (0, 2, 3, 1))

        sanitized[new_key] = value

    return sanitized
|
||||
|
||||
|
||||
def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Sanitize vocoder weight names from PyTorch format to MLX format.

    A leading ``vocoder.`` prefix is removed; keys without that prefix are
    kept under their original name. ModuleList indices (``ups.N``,
    ``resblocks.N``) are left unchanged. Three-dimensional ``weight`` tensors
    are transposed to the MLX layout.

    Args:
        weights: Dictionary of weights with PyTorch naming

    Returns:
        Dictionary with MLX-compatible naming for vocoder
    """
    prefix = "vocoder."

    sanitized = {}
    for key, value in weights.items():
        # Slice instead of str.replace so that only the leading prefix is
        # removed; replace() would also delete any later "vocoder." inside
        # the key.
        new_key = key[len(prefix):] if key.startswith(prefix) else key

        if "weight" in new_key and value.ndim == 3:
            if "ups" in new_key:
                # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
                value = mx.transpose(value, (1, 2, 0))
            else:
                # Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
                value = mx.transpose(value, (0, 2, 1))

        sanitized[new_key] = value

    return sanitized
|
||||
|
||||
|
||||
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize weight names from PyTorch format to MLX format.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user