From dd573d53d20a25cb21ac6e72327ff27b5b9e3b35 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 16 Mar 2026 21:53:37 +0100 Subject: [PATCH] Refactor audio VAE directory structure and update related paths in conversion and loading functions --- mlx_video/models/ltx_2/convert.py | 14 +++++++++----- mlx_video/models/ltx_2/generate.py | 2 +- mlx_video/models/ltx_2/weight_loading.py | 8 +++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/mlx_video/models/ltx_2/convert.py b/mlx_video/models/ltx_2/convert.py index dadbcdd..fabf04e 100644 --- a/mlx_video/models/ltx_2/convert.py +++ b/mlx_video/models/ltx_2/convert.py @@ -15,9 +15,13 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director │ └── encoder/ # Video VAE encoder │ ├── config.json │ └── model.safetensors - ├── audio_vae/ # Audio VAE decoder - │ ├── config.json - │ └── model.safetensors + ├── audio_vae/ + │ ├── decoder/ # Audio VAE decoder + │ │ ├── config.json + │ │ └── model.safetensors + │ └── encoder/ # Audio VAE encoder + │ ├── config.json + │ └── model.safetensors ├── vocoder/ # Audio vocoder │ ├── config.json │ └── model.safetensors @@ -622,9 +626,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): # 4. Audio VAE Decoder print(" [4/6] Audio VAE Decoder...") audio_decoder_weights = sanitize_audio_decoder(all_weights) - save_single(audio_decoder_weights, output_path / "audio_vae") + save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder") config = infer_audio_vae_config(audio_decoder_weights) - save_config(config, output_path / "audio_vae") + save_config(config, output_path / "audio_vae" / "decoder") a_params = sum(v.size for v in audio_decoder_weights.values()) print(f" {len(audio_decoder_weights)} keys, {a_params:,} params") diff --git a/mlx_video/models/ltx_2/generate.py b/mlx_video/models/ltx_2/generate.py index 2ef7da3..6f49147 100644 --- a/mlx_video/models/ltx_2/generate.py +++ b/mlx_video/models/ltx_2/generate.py @@ -1360,7 +1360,7 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType): """Load audio VAE decoder.""" from mlx_video.models.ltx_2.audio_vae import AudioDecoder - decoder = AudioDecoder.from_pretrained(model_path / "audio_vae") + decoder = AudioDecoder.from_pretrained(model_path / "audio_vae" / "decoder") return decoder diff --git a/mlx_video/models/ltx_2/weight_loading.py b/mlx_video/models/ltx_2/weight_loading.py index 2a8d463..234bb08 100644 --- a/mlx_video/models/ltx_2/weight_loading.py +++ b/mlx_video/models/ltx_2/weight_loading.py @@ -120,6 +120,8 @@ def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]: """ # Try different possible paths for audio VAE weights audio_vae_paths = [ + model_path / "audio_vae" / "decoder" / "model.safetensors", + model_path / "audio_vae" / "decoder" / "diffusion_pytorch_model.safetensors", model_path / "audio_vae" / "diffusion_pytorch_model.safetensors", model_path / "audio_vae.safetensors", ] @@ -621,10 +623,10 @@ def convert_audio_encoder( source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors. Returns: - Path to the audio_vae_encoder directory. + Path to the audio_vae/encoder directory. """ model_path = Path(model_path) - encoder_dir = model_path / "audio_vae_encoder" + encoder_dir = model_path / "audio_vae" / "encoder" if (encoder_dir / "model.safetensors").exists(): return encoder_dir @@ -643,7 +645,7 @@ def convert_audio_encoder( from mlx_video.models.ltx_2.config import AudioEncoderModelConfig # Build config from the decoder config (same audio VAE architecture) - decoder_config_path = model_path / "audio_vae" / "config.json" + decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json" if decoder_config_path.exists(): with open(decoder_config_path) as f: dec_cfg = json.load(f)