Refactor audio VAE directory structure and update related paths in conversion and loading functions

This commit is contained in:
Prince Canuma
2026-03-16 21:53:37 +01:00
parent a6a6bb2166
commit dd573d53d2
3 changed files with 15 additions and 9 deletions

View File

@@ -15,7 +15,11 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director
│ └── encoder/ # Video VAE encoder │ └── encoder/ # Video VAE encoder
│ ├── config.json │ ├── config.json
│ └── model.safetensors │ └── model.safetensors
├── audio_vae/ # Audio VAE decoder ├── audio_vae/
│ ├── decoder/ # Audio VAE decoder
│ │ ├── config.json
│ │ └── model.safetensors
│ └── encoder/ # Audio VAE encoder
│ ├── config.json │ ├── config.json
│ └── model.safetensors │ └── model.safetensors
├── vocoder/ # Audio vocoder ├── vocoder/ # Audio vocoder
@@ -622,9 +626,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
# 4. Audio VAE Decoder # 4. Audio VAE Decoder
print(" [4/6] Audio VAE Decoder...") print(" [4/6] Audio VAE Decoder...")
audio_decoder_weights = sanitize_audio_decoder(all_weights) audio_decoder_weights = sanitize_audio_decoder(all_weights)
save_single(audio_decoder_weights, output_path / "audio_vae") save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
config = infer_audio_vae_config(audio_decoder_weights) config = infer_audio_vae_config(audio_decoder_weights)
save_config(config, output_path / "audio_vae") save_config(config, output_path / "audio_vae" / "decoder")
a_params = sum(v.size for v in audio_decoder_weights.values()) a_params = sum(v.size for v in audio_decoder_weights.values())
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params") print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")

View File

@@ -1360,7 +1360,7 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType):
"""Load audio VAE decoder.""" """Load audio VAE decoder."""
from mlx_video.models.ltx_2.audio_vae import AudioDecoder from mlx_video.models.ltx_2.audio_vae import AudioDecoder
decoder = AudioDecoder.from_pretrained(model_path / "audio_vae") decoder = AudioDecoder.from_pretrained(model_path / "audio_vae" / "decoder")
return decoder return decoder

View File

@@ -120,6 +120,8 @@ def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
""" """
# Try different possible paths for audio VAE weights # Try different possible paths for audio VAE weights
audio_vae_paths = [ audio_vae_paths = [
model_path / "audio_vae" / "decoder" / "model.safetensors",
model_path / "audio_vae" / "decoder" / "diffusion_pytorch_model.safetensors",
model_path / "audio_vae" / "diffusion_pytorch_model.safetensors", model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
model_path / "audio_vae.safetensors", model_path / "audio_vae.safetensors",
] ]
@@ -621,10 +623,10 @@ def convert_audio_encoder(
source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors. source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.
Returns: Returns:
Path to the audio_vae_encoder directory. Path to the audio_vae/encoder directory.
""" """
model_path = Path(model_path) model_path = Path(model_path)
encoder_dir = model_path / "audio_vae_encoder" encoder_dir = model_path / "audio_vae" / "encoder"
if (encoder_dir / "model.safetensors").exists(): if (encoder_dir / "model.safetensors").exists():
return encoder_dir return encoder_dir
@@ -643,7 +645,7 @@ def convert_audio_encoder(
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
# Build config from the decoder config (same audio VAE architecture) # Build config from the decoder config (same audio VAE architecture)
decoder_config_path = model_path / "audio_vae" / "config.json" decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json"
if decoder_config_path.exists(): if decoder_config_path.exists():
with open(decoder_config_path) as f: with open(decoder_config_path) as f:
dec_cfg = json.load(f) dec_cfg = json.load(f)