From f9880a068316b4b3689264e13eb8ef65b340e496 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Mon, 16 Mar 2026 22:35:27 +0100
Subject: [PATCH] Add audio encoder sanitization and configuration inference to LTX-2 model conversion process; update conversion print statements for new encoder step

---
Reviewer notes (commentary only; not part of the commit message or the diff):

  * sanitize_audio_encoder() builds a new dict from `weights`:
      - keys starting with AUDIO_ENCODER_PREFIX are kept with that prefix
        stripped;
      - keys starting with AUDIO_STATS_PREFIX are kept only when they contain
        "mean-of-means" or "std-of-means", and are renamed to
        "per_channel_statistics.mean_of_means" / ".std_of_means";
      - the bare keys "latents_mean" / "latents_std" are renamed to the same
        two "per_channel_statistics.*" names;
      - every other key is skipped.
    A value is transposed with axes (0, 2, 3, 1) only when it is 4-D and its
    new key contains "conv" (case-insensitive) and "weight".

  * infer_audio_encoder_config() returns a fixed dict literal; its `weights`
    argument is never read in the added code.

  * convert() gains a step 5 that calls sanitize_audio_encoder(), then
    save_single() and save_config() on output_path / "audio_vae" / "encoder".
    The old Vocoder and Text Projections steps become 6 and 7, and all
    progress labels change from N/6 to N/7.

  * NOTE(review): this patch was received with line breaks and horizontal
    whitespace collapsed. Line structure and indentation below were re-derived
    from the hunk headers and Python syntax; the sender's e-mail address is
    absent from the From: header; spacing inside string literals (for example
    the print() progress labels) is reproduced exactly as received and may
    differ from upstream. Verify against commit f9880a0 before applying.

 mlx_video/models/ltx_2/convert.py | 80 +++++++++++++++++++++++++++----
 1 file changed, 72 insertions(+), 8 deletions(-)

diff --git a/mlx_video/models/ltx_2/convert.py b/mlx_video/models/ltx_2/convert.py
index fabf04e..be02523 100644
--- a/mlx_video/models/ltx_2/convert.py
+++ b/mlx_video/models/ltx_2/convert.py
@@ -189,6 +189,36 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
     return sanitized
 
 
+def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
+    """Sanitize audio VAE encoder keys: strip prefix, transpose Conv2d."""
+    sanitized = {}
+    for key, value in weights.items():
+        new_key = None
+
+        if key.startswith(AUDIO_ENCODER_PREFIX):
+            new_key = key[len(AUDIO_ENCODER_PREFIX):]
+        elif key.startswith(AUDIO_STATS_PREFIX):
+            if "mean-of-means" in key:
+                new_key = "per_channel_statistics.mean_of_means"
+            elif "std-of-means" in key:
+                new_key = "per_channel_statistics.std_of_means"
+            else:
+                continue
+        elif key == "latents_mean":
+            new_key = "per_channel_statistics.mean_of_means"
+        elif key == "latents_std":
+            new_key = "per_channel_statistics.std_of_means"
+        else:
+            continue
+
+        # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
+        if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
+            value = mx.transpose(value, (0, 2, 3, 1))
+
+        sanitized[new_key] = value
+    return sanitized
+
+
 def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
     """Sanitize vocoder keys: strip prefix, transpose Conv1d/ConvTranspose1d."""
     sanitized = {}
@@ -553,6 +583,31 @@ def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
     }
 
 
+def infer_audio_encoder_config(weights: Dict[str, mx.array]) -> dict:
+    """Return audio encoder config (mirrors decoder but with encoder-specific fields)."""
+    return {
+        "attn_resolutions": [],
+        "attn_type": "vanilla",
+        "causality_axis": "height",
+        "ch": 128,
+        "ch_mult": [1, 2, 4],
+        "dropout": 0.0,
+        "in_channels": 2,
+        "double_z": True,
+        "is_causal": True,
+        "mel_bins": 64,
+        "mel_hop_length": 160,
+        "mid_block_add_attention": False,
+        "n_fft": 1024,
+        "norm_type": "pixel",
+        "num_res_blocks": 2,
+        "resamp_with_conv": True,
+        "resolution": 256,
+        "sample_rate": 16000,
+        "z_channels": 8,
+    }
+
+
 def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
     """Infer vocoder config from weights."""
     # Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
@@ -597,7 +652,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
     print("\nExtracting components...")
 
     # 1. Transformer
-    print(" [1/6] Transformer...")
+    print(" [1/7] Transformer...")
     transformer_weights = sanitize_transformer(all_weights)
     num_shards = save_sharded(transformer_weights, output_path / "transformer")
     config = infer_transformer_config(transformer_weights)
@@ -606,7 +661,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
     print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
 
     # 2. VAE Decoder
-    print(" [2/6] VAE Decoder...")
+    print(" [2/7] VAE Decoder...")
     vae_decoder_weights = sanitize_vae_decoder(all_weights)
     save_single(vae_decoder_weights, output_path / "vae" / "decoder")
     config = infer_vae_decoder_config(vae_decoder_weights, variant)
@@ -615,7 +670,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
     print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
 
     # 3. VAE Encoder
-    print(" [3/6] VAE Encoder...")
+    print(" [3/7] VAE Encoder...")
     vae_encoder_weights = sanitize_vae_encoder(all_weights)
     save_single(vae_encoder_weights, output_path / "vae" / "encoder")
     config = infer_vae_encoder_config(vae_encoder_weights)
@@ -624,7 +679,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
     print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
 
     # 4. Audio VAE Decoder
-    print(" [4/6] Audio VAE Decoder...")
+    print(" [4/7] Audio VAE Decoder...")
     audio_decoder_weights = sanitize_audio_decoder(all_weights)
     save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
     config = infer_audio_vae_config(audio_decoder_weights)
@@ -632,8 +687,17 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
     a_params = sum(v.size for v in audio_decoder_weights.values())
     print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
 
-    # 5. Vocoder
-    print(" [5/6] Vocoder...")
+    # 5. Audio VAE Encoder
+    print(" [5/7] Audio VAE Encoder...")
+    audio_encoder_weights = sanitize_audio_encoder(all_weights)
+    save_single(audio_encoder_weights, output_path / "audio_vae" / "encoder")
+    config = infer_audio_encoder_config(audio_encoder_weights)
+    save_config(config, output_path / "audio_vae" / "encoder")
+    ae_params = sum(v.size for v in audio_encoder_weights.values())
+    print(f" {len(audio_encoder_weights)} keys, {ae_params:,} params")
+
+    # 6. Vocoder
+    print(" [6/7] Vocoder...")
     vocoder_weights = sanitize_vocoder(all_weights)
     save_single(vocoder_weights, output_path / "vocoder")
     config = infer_vocoder_config(vocoder_weights)
@@ -641,8 +705,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
     v_params = sum(v.size for v in vocoder_weights.values())
     print(f" {len(vocoder_weights)} keys, {v_params:,} params")
 
-    # 6. Text Projections
-    print(" [6/6] Text Projections...")
+    # 7. Text Projections
+    print(" [7/7] Text Projections...")
     text_proj_weights = extract_text_projections(all_weights)
     tp_dir = output_path / "text_projections"
     tp_dir.mkdir(parents=True, exist_ok=True)