Refactor audio VAE directory structure and update related paths in conversion and loading functions
This commit is contained in:
@@ -15,7 +15,11 @@ or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular director
|
|||||||
│ └── encoder/ # Video VAE encoder
|
│ └── encoder/ # Video VAE encoder
|
||||||
│ ├── config.json
|
│ ├── config.json
|
||||||
│ └── model.safetensors
|
│ └── model.safetensors
|
||||||
├── audio_vae/ # Audio VAE decoder
|
├── audio_vae/
|
||||||
|
│ ├── decoder/ # Audio VAE decoder
|
||||||
|
│ │ ├── config.json
|
||||||
|
│ │ └── model.safetensors
|
||||||
|
│ └── encoder/ # Audio VAE encoder
|
||||||
│ ├── config.json
|
│ ├── config.json
|
||||||
│ └── model.safetensors
|
│ └── model.safetensors
|
||||||
├── vocoder/ # Audio vocoder
|
├── vocoder/ # Audio vocoder
|
||||||
@@ -622,9 +626,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
# 4. Audio VAE Decoder
|
# 4. Audio VAE Decoder
|
||||||
print(" [4/6] Audio VAE Decoder...")
|
print(" [4/6] Audio VAE Decoder...")
|
||||||
audio_decoder_weights = sanitize_audio_decoder(all_weights)
|
audio_decoder_weights = sanitize_audio_decoder(all_weights)
|
||||||
save_single(audio_decoder_weights, output_path / "audio_vae")
|
save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
|
||||||
config = infer_audio_vae_config(audio_decoder_weights)
|
config = infer_audio_vae_config(audio_decoder_weights)
|
||||||
save_config(config, output_path / "audio_vae")
|
save_config(config, output_path / "audio_vae" / "decoder")
|
||||||
a_params = sum(v.size for v in audio_decoder_weights.values())
|
a_params = sum(v.size for v in audio_decoder_weights.values())
|
||||||
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
|
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
|
||||||
|
|
||||||
|
|||||||
@@ -1360,7 +1360,7 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType):
|
|||||||
"""Load audio VAE decoder."""
|
"""Load audio VAE decoder."""
|
||||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder
|
from mlx_video.models.ltx_2.audio_vae import AudioDecoder
|
||||||
|
|
||||||
decoder = AudioDecoder.from_pretrained(model_path / "audio_vae")
|
decoder = AudioDecoder.from_pretrained(model_path / "audio_vae" / "decoder")
|
||||||
|
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|||||||
@@ -120,6 +120,8 @@ def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
|||||||
"""
|
"""
|
||||||
# Try different possible paths for audio VAE weights
|
# Try different possible paths for audio VAE weights
|
||||||
audio_vae_paths = [
|
audio_vae_paths = [
|
||||||
|
model_path / "audio_vae" / "decoder" / "model.safetensors",
|
||||||
|
model_path / "audio_vae" / "decoder" / "diffusion_pytorch_model.safetensors",
|
||||||
model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
|
model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
|
||||||
model_path / "audio_vae.safetensors",
|
model_path / "audio_vae.safetensors",
|
||||||
]
|
]
|
||||||
@@ -621,10 +623,10 @@ def convert_audio_encoder(
|
|||||||
source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.
|
source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to the audio_vae_encoder directory.
|
Path to the audio_vae/encoder directory.
|
||||||
"""
|
"""
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
encoder_dir = model_path / "audio_vae_encoder"
|
encoder_dir = model_path / "audio_vae" / "encoder"
|
||||||
|
|
||||||
if (encoder_dir / "model.safetensors").exists():
|
if (encoder_dir / "model.safetensors").exists():
|
||||||
return encoder_dir
|
return encoder_dir
|
||||||
@@ -643,7 +645,7 @@ def convert_audio_encoder(
|
|||||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||||
|
|
||||||
# Build config from the decoder config (same audio VAE architecture)
|
# Build config from the decoder config (same audio VAE architecture)
|
||||||
decoder_config_path = model_path / "audio_vae" / "config.json"
|
decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json"
|
||||||
if decoder_config_path.exists():
|
if decoder_config_path.exists():
|
||||||
with open(decoder_config_path) as f:
|
with open(decoder_config_path) as f:
|
||||||
dec_cfg = json.load(f)
|
dec_cfg = json.load(f)
|
||||||
|
|||||||
Reference in New Issue
Block a user