Add audio encoder sanitization and configuration inference to LTX-2 model conversion process; update conversion print statements for new encoder step

This commit is contained in:
Prince Canuma
2026-03-16 22:35:27 +01:00
parent 7a576bfbf4
commit f9880a0683

View File

@@ -189,6 +189,36 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
return sanitized
def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE encoder keys: strip prefix, transpose Conv2d."""
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith(AUDIO_ENCODER_PREFIX):
new_key = key[len(AUDIO_ENCODER_PREFIX):]
elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue
elif key == "latents_mean":
new_key = "per_channel_statistics.mean_of_means"
elif key == "latents_std":
new_key = "per_channel_statistics.std_of_means"
else:
continue
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize vocoder keys: strip prefix, transpose Conv1d/ConvTranspose1d."""
sanitized = {}
@@ -553,6 +583,31 @@ def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
}
def infer_audio_encoder_config(weights: Dict[str, mx.array]) -> dict:
"""Return audio encoder config (mirrors decoder but with encoder-specific fields)."""
return {
"attn_resolutions": [],
"attn_type": "vanilla",
"causality_axis": "height",
"ch": 128,
"ch_mult": [1, 2, 4],
"dropout": 0.0,
"in_channels": 2,
"double_z": True,
"is_causal": True,
"mel_bins": 64,
"mel_hop_length": 160,
"mid_block_add_attention": False,
"n_fft": 1024,
"norm_type": "pixel",
"num_res_blocks": 2,
"resamp_with_conv": True,
"resolution": 256,
"sample_rate": 16000,
"z_channels": 8,
}
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
"""Infer vocoder config from weights."""
# Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
@@ -597,7 +652,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
print("\nExtracting components...")
# 1. Transformer
print(" [1/6] Transformer...")
print(" [1/7] Transformer...")
transformer_weights = sanitize_transformer(all_weights)
num_shards = save_sharded(transformer_weights, output_path / "transformer")
config = infer_transformer_config(transformer_weights)
@@ -606,7 +661,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
# 2. VAE Decoder
print(" [2/6] VAE Decoder...")
print(" [2/7] VAE Decoder...")
vae_decoder_weights = sanitize_vae_decoder(all_weights)
save_single(vae_decoder_weights, output_path / "vae" / "decoder")
config = infer_vae_decoder_config(vae_decoder_weights, variant)
@@ -615,7 +670,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
# 3. VAE Encoder
print(" [3/6] VAE Encoder...")
print(" [3/7] VAE Encoder...")
vae_encoder_weights = sanitize_vae_encoder(all_weights)
save_single(vae_encoder_weights, output_path / "vae" / "encoder")
config = infer_vae_encoder_config(vae_encoder_weights)
@@ -624,7 +679,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
# 4. Audio VAE Decoder
print(" [4/6] Audio VAE Decoder...")
print(" [4/7] Audio VAE Decoder...")
audio_decoder_weights = sanitize_audio_decoder(all_weights)
save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
config = infer_audio_vae_config(audio_decoder_weights)
@@ -632,8 +687,17 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
a_params = sum(v.size for v in audio_decoder_weights.values())
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
# 5. Vocoder
print(" [5/6] Vocoder...")
# 5. Audio VAE Encoder
print(" [5/7] Audio VAE Encoder...")
audio_encoder_weights = sanitize_audio_encoder(all_weights)
save_single(audio_encoder_weights, output_path / "audio_vae" / "encoder")
config = infer_audio_encoder_config(audio_encoder_weights)
save_config(config, output_path / "audio_vae" / "encoder")
ae_params = sum(v.size for v in audio_encoder_weights.values())
print(f" {len(audio_encoder_weights)} keys, {ae_params:,} params")
# 6. Vocoder
print(" [6/7] Vocoder...")
vocoder_weights = sanitize_vocoder(all_weights)
save_single(vocoder_weights, output_path / "vocoder")
config = infer_vocoder_config(vocoder_weights)
@@ -641,8 +705,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
v_params = sum(v.size for v in vocoder_weights.values())
print(f" {len(vocoder_weights)} keys, {v_params:,} params")
# 6. Text Projections
print(" [6/6] Text Projections...")
# 7. Text Projections
print(" [7/7] Text Projections...")
text_proj_weights = extract_text_projections(all_weights)
tp_dir = output_path / "text_projections"
tp_dir.mkdir(parents=True, exist_ok=True)