Add audio encoder sanitization and configuration inference to LTX-2 model conversion process; update conversion print statements for new encoder step
This commit is contained in:
@@ -189,6 +189,36 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio VAE encoder keys: strip prefix, transpose Conv2d."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_ENCODER_PREFIX):
|
||||
new_key = key[len(AUDIO_ENCODER_PREFIX):]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
elif key == "latents_mean":
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif key == "latents_std":
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
|
||||
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize vocoder keys: strip prefix, transpose Conv1d/ConvTranspose1d."""
|
||||
sanitized = {}
|
||||
@@ -553,6 +583,31 @@ def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def infer_audio_encoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Return audio encoder config (mirrors decoder but with encoder-specific fields)."""
|
||||
return {
|
||||
"attn_resolutions": [],
|
||||
"attn_type": "vanilla",
|
||||
"causality_axis": "height",
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"in_channels": 2,
|
||||
"double_z": True,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"mel_hop_length": 160,
|
||||
"mid_block_add_attention": False,
|
||||
"n_fft": 1024,
|
||||
"norm_type": "pixel",
|
||||
"num_res_blocks": 2,
|
||||
"resamp_with_conv": True,
|
||||
"resolution": 256,
|
||||
"sample_rate": 16000,
|
||||
"z_channels": 8,
|
||||
}
|
||||
|
||||
|
||||
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Infer vocoder config from weights."""
|
||||
# Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
|
||||
@@ -597,7 +652,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print("\nExtracting components...")
|
||||
|
||||
# 1. Transformer
|
||||
print(" [1/6] Transformer...")
|
||||
print(" [1/7] Transformer...")
|
||||
transformer_weights = sanitize_transformer(all_weights)
|
||||
num_shards = save_sharded(transformer_weights, output_path / "transformer")
|
||||
config = infer_transformer_config(transformer_weights)
|
||||
@@ -606,7 +661,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
||||
|
||||
# 2. VAE Decoder
|
||||
print(" [2/6] VAE Decoder...")
|
||||
print(" [2/7] VAE Decoder...")
|
||||
vae_decoder_weights = sanitize_vae_decoder(all_weights)
|
||||
save_single(vae_decoder_weights, output_path / "vae" / "decoder")
|
||||
config = infer_vae_decoder_config(vae_decoder_weights, variant)
|
||||
@@ -615,7 +670,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
|
||||
|
||||
# 3. VAE Encoder
|
||||
print(" [3/6] VAE Encoder...")
|
||||
print(" [3/7] VAE Encoder...")
|
||||
vae_encoder_weights = sanitize_vae_encoder(all_weights)
|
||||
save_single(vae_encoder_weights, output_path / "vae" / "encoder")
|
||||
config = infer_vae_encoder_config(vae_encoder_weights)
|
||||
@@ -624,7 +679,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
|
||||
|
||||
# 4. Audio VAE Decoder
|
||||
print(" [4/6] Audio VAE Decoder...")
|
||||
print(" [4/7] Audio VAE Decoder...")
|
||||
audio_decoder_weights = sanitize_audio_decoder(all_weights)
|
||||
save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
|
||||
config = infer_audio_vae_config(audio_decoder_weights)
|
||||
@@ -632,8 +687,17 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
a_params = sum(v.size for v in audio_decoder_weights.values())
|
||||
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
|
||||
|
||||
# 5. Vocoder
|
||||
print(" [5/6] Vocoder...")
|
||||
# 5. Audio VAE Encoder
|
||||
print(" [5/7] Audio VAE Encoder...")
|
||||
audio_encoder_weights = sanitize_audio_encoder(all_weights)
|
||||
save_single(audio_encoder_weights, output_path / "audio_vae" / "encoder")
|
||||
config = infer_audio_encoder_config(audio_encoder_weights)
|
||||
save_config(config, output_path / "audio_vae" / "encoder")
|
||||
ae_params = sum(v.size for v in audio_encoder_weights.values())
|
||||
print(f" {len(audio_encoder_weights)} keys, {ae_params:,} params")
|
||||
|
||||
# 6. Vocoder
|
||||
print(" [6/7] Vocoder...")
|
||||
vocoder_weights = sanitize_vocoder(all_weights)
|
||||
save_single(vocoder_weights, output_path / "vocoder")
|
||||
config = infer_vocoder_config(vocoder_weights)
|
||||
@@ -641,8 +705,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
v_params = sum(v.size for v in vocoder_weights.values())
|
||||
print(f" {len(vocoder_weights)} keys, {v_params:,} params")
|
||||
|
||||
# 6. Text Projections
|
||||
print(" [6/6] Text Projections...")
|
||||
# 7. Text Projections
|
||||
print(" [7/7] Text Projections...")
|
||||
text_proj_weights = extract_text_projections(all_weights)
|
||||
tp_dir = output_path / "text_projections"
|
||||
tp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user