Add audio encoder sanitization and configuration inference to LTX-2 model conversion process; update conversion print statements for new encoder step
This commit is contained in:
@@ -189,6 +189,36 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize audio VAE encoder keys: strip prefix, transpose Conv2d."""
|
||||||
|
sanitized = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = None
|
||||||
|
|
||||||
|
if key.startswith(AUDIO_ENCODER_PREFIX):
|
||||||
|
new_key = key[len(AUDIO_ENCODER_PREFIX):]
|
||||||
|
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||||
|
if "mean-of-means" in key:
|
||||||
|
new_key = "per_channel_statistics.mean_of_means"
|
||||||
|
elif "std-of-means" in key:
|
||||||
|
new_key = "per_channel_statistics.std_of_means"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
elif key == "latents_mean":
|
||||||
|
new_key = "per_channel_statistics.mean_of_means"
|
||||||
|
elif key == "latents_std":
|
||||||
|
new_key = "per_channel_statistics.std_of_means"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||||
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
"""Sanitize vocoder keys: strip prefix, transpose Conv1d/ConvTranspose1d."""
|
"""Sanitize vocoder keys: strip prefix, transpose Conv1d/ConvTranspose1d."""
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
@@ -553,6 +583,31 @@ def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def infer_audio_encoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||||
|
"""Return audio encoder config (mirrors decoder but with encoder-specific fields)."""
|
||||||
|
return {
|
||||||
|
"attn_resolutions": [],
|
||||||
|
"attn_type": "vanilla",
|
||||||
|
"causality_axis": "height",
|
||||||
|
"ch": 128,
|
||||||
|
"ch_mult": [1, 2, 4],
|
||||||
|
"dropout": 0.0,
|
||||||
|
"in_channels": 2,
|
||||||
|
"double_z": True,
|
||||||
|
"is_causal": True,
|
||||||
|
"mel_bins": 64,
|
||||||
|
"mel_hop_length": 160,
|
||||||
|
"mid_block_add_attention": False,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"norm_type": "pixel",
|
||||||
|
"num_res_blocks": 2,
|
||||||
|
"resamp_with_conv": True,
|
||||||
|
"resolution": 256,
|
||||||
|
"sample_rate": 16000,
|
||||||
|
"z_channels": 8,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
|
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||||
"""Infer vocoder config from weights."""
|
"""Infer vocoder config from weights."""
|
||||||
# Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
|
# Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
|
||||||
@@ -597,7 +652,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
print("\nExtracting components...")
|
print("\nExtracting components...")
|
||||||
|
|
||||||
# 1. Transformer
|
# 1. Transformer
|
||||||
print(" [1/6] Transformer...")
|
print(" [1/7] Transformer...")
|
||||||
transformer_weights = sanitize_transformer(all_weights)
|
transformer_weights = sanitize_transformer(all_weights)
|
||||||
num_shards = save_sharded(transformer_weights, output_path / "transformer")
|
num_shards = save_sharded(transformer_weights, output_path / "transformer")
|
||||||
config = infer_transformer_config(transformer_weights)
|
config = infer_transformer_config(transformer_weights)
|
||||||
@@ -606,7 +661,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
||||||
|
|
||||||
# 2. VAE Decoder
|
# 2. VAE Decoder
|
||||||
print(" [2/6] VAE Decoder...")
|
print(" [2/7] VAE Decoder...")
|
||||||
vae_decoder_weights = sanitize_vae_decoder(all_weights)
|
vae_decoder_weights = sanitize_vae_decoder(all_weights)
|
||||||
save_single(vae_decoder_weights, output_path / "vae" / "decoder")
|
save_single(vae_decoder_weights, output_path / "vae" / "decoder")
|
||||||
config = infer_vae_decoder_config(vae_decoder_weights, variant)
|
config = infer_vae_decoder_config(vae_decoder_weights, variant)
|
||||||
@@ -615,7 +670,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
|
print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
|
||||||
|
|
||||||
# 3. VAE Encoder
|
# 3. VAE Encoder
|
||||||
print(" [3/6] VAE Encoder...")
|
print(" [3/7] VAE Encoder...")
|
||||||
vae_encoder_weights = sanitize_vae_encoder(all_weights)
|
vae_encoder_weights = sanitize_vae_encoder(all_weights)
|
||||||
save_single(vae_encoder_weights, output_path / "vae" / "encoder")
|
save_single(vae_encoder_weights, output_path / "vae" / "encoder")
|
||||||
config = infer_vae_encoder_config(vae_encoder_weights)
|
config = infer_vae_encoder_config(vae_encoder_weights)
|
||||||
@@ -624,7 +679,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
|
print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
|
||||||
|
|
||||||
# 4. Audio VAE Decoder
|
# 4. Audio VAE Decoder
|
||||||
print(" [4/6] Audio VAE Decoder...")
|
print(" [4/7] Audio VAE Decoder...")
|
||||||
audio_decoder_weights = sanitize_audio_decoder(all_weights)
|
audio_decoder_weights = sanitize_audio_decoder(all_weights)
|
||||||
save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
|
save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
|
||||||
config = infer_audio_vae_config(audio_decoder_weights)
|
config = infer_audio_vae_config(audio_decoder_weights)
|
||||||
@@ -632,8 +687,17 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
a_params = sum(v.size for v in audio_decoder_weights.values())
|
a_params = sum(v.size for v in audio_decoder_weights.values())
|
||||||
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
|
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
|
||||||
|
|
||||||
# 5. Vocoder
|
# 5. Audio VAE Encoder
|
||||||
print(" [5/6] Vocoder...")
|
print(" [5/7] Audio VAE Encoder...")
|
||||||
|
audio_encoder_weights = sanitize_audio_encoder(all_weights)
|
||||||
|
save_single(audio_encoder_weights, output_path / "audio_vae" / "encoder")
|
||||||
|
config = infer_audio_encoder_config(audio_encoder_weights)
|
||||||
|
save_config(config, output_path / "audio_vae" / "encoder")
|
||||||
|
ae_params = sum(v.size for v in audio_encoder_weights.values())
|
||||||
|
print(f" {len(audio_encoder_weights)} keys, {ae_params:,} params")
|
||||||
|
|
||||||
|
# 6. Vocoder
|
||||||
|
print(" [6/7] Vocoder...")
|
||||||
vocoder_weights = sanitize_vocoder(all_weights)
|
vocoder_weights = sanitize_vocoder(all_weights)
|
||||||
save_single(vocoder_weights, output_path / "vocoder")
|
save_single(vocoder_weights, output_path / "vocoder")
|
||||||
config = infer_vocoder_config(vocoder_weights)
|
config = infer_vocoder_config(vocoder_weights)
|
||||||
@@ -641,8 +705,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
v_params = sum(v.size for v in vocoder_weights.values())
|
v_params = sum(v.size for v in vocoder_weights.values())
|
||||||
print(f" {len(vocoder_weights)} keys, {v_params:,} params")
|
print(f" {len(vocoder_weights)} keys, {v_params:,} params")
|
||||||
|
|
||||||
# 6. Text Projections
|
# 7. Text Projections
|
||||||
print(" [6/6] Text Projections...")
|
print(" [7/7] Text Projections...")
|
||||||
text_proj_weights = extract_text_projections(all_weights)
|
text_proj_weights = extract_text_projections(all_weights)
|
||||||
tp_dir = output_path / "text_projections"
|
tp_dir = output_path / "text_projections"
|
||||||
tp_dir.mkdir(parents=True, exist_ok=True)
|
tp_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user