diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 1d716ee..1ac4508 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -648,6 +648,7 @@ def load_audio_decoder(model_path: Path, pipeline: PipelineType): norm_type=NormType.PIXEL, causality_axis=CausalityAxis.HEIGHT, mel_bins=64, + mid_block_add_attention=False, # Config says no attention in mid block ) weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors") @@ -1277,7 +1278,7 @@ def generate_video( audio_waveform = vocoder(mel_spectrogram) mx.eval(audio_waveform) - audio_np = np.array(audio_waveform) + audio_np = np.array(audio_waveform.astype(mx.float32)) if audio_np.ndim == 3: audio_np = audio_np[0]