Fix LTX-2.3 audio: auto-detect the vocoder type and save audio at the vocoder's output sample rate

This commit is contained in:
Prince Canuma
2026-03-15 02:06:35 +01:00
parent eb0d1355e4
commit 53bae534e7
4 changed files with 649 additions and 130 deletions

View File

@@ -1012,13 +1012,14 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType):
return decoder
def load_vocoder(model_path: Path, pipeline: PipelineType):
"""Load vocoder for mel to waveform conversion."""
from mlx_video.models.ltx.audio_vae import Vocoder
def load_vocoder_model(model_path: Path, pipeline: PipelineType):
"""Load vocoder for mel to waveform conversion.
vocoder = Vocoder.from_pretrained(model_path / "vocoder")
Automatically detects HiFi-GAN (LTX-2) or BigVGAN+BWE (LTX-2.3).
"""
from mlx_video.models.ltx.audio_vae.vocoder import load_vocoder as _load_vocoder
return vocoder
return _load_vocoder(model_path / "vocoder")
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
@@ -1795,7 +1796,7 @@ def generate_video(
if audio and audio_latents is not None:
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path, pipeline)
vocoder = load_vocoder(model_path, pipeline)
vocoder = load_vocoder_model(model_path, pipeline)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents)
@@ -1809,12 +1810,15 @@ def generate_video(
if audio_np.ndim == 3:
audio_np = audio_np[0]
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
del audio_decoder, vocoder
mx.clear_cache()
console.print("[green]✓[/] Audio decoded")
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
save_audio(audio_np, audio_path, vocoder_sample_rate)
console.print(f"[green]✅ Saved audio to[/] {audio_path}")
with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"):