Refactor comments and optimize key skipping logic in LTX-2 model conversion; improve clarity in code documentation
This commit is contained in:
@@ -714,7 +714,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
tp_params = sum(v.size for v in text_proj_weights.values())
|
tp_params = sum(v.size for v in text_proj_weights.values())
|
||||||
print(f" {len(text_proj_weights)} keys, {tp_params:,} params")
|
print(f" {len(text_proj_weights)} keys, {tp_params:,} params")
|
||||||
|
|
||||||
# 7. Copy upscaler files
|
# Copy upscaler files
|
||||||
print("\nCopying upscaler files...")
|
print("\nCopying upscaler files...")
|
||||||
source_dir = source_path.parent
|
source_dir = source_path.parent
|
||||||
is_hf_repo = "/" in source and not Path(source).exists()
|
is_hf_repo = "/" in source and not Path(source).exists()
|
||||||
@@ -755,7 +755,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
else:
|
else:
|
||||||
print(f" {upscaler_file}: not found, skipping")
|
print(f" {upscaler_file}: not found, skipping")
|
||||||
|
|
||||||
# 8. Link text_encoder and tokenizer directories
|
# Link text_encoder and tokenizer directories
|
||||||
print("\nLinking text encoder & tokenizer...")
|
print("\nLinking text encoder & tokenizer...")
|
||||||
for subdir in ["text_encoder", "tokenizer"]:
|
for subdir in ["text_encoder", "tokenizer"]:
|
||||||
dest = output_path / subdir
|
dest = output_path / subdir
|
||||||
@@ -793,32 +793,19 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
|||||||
+ len(vae_decoder_weights)
|
+ len(vae_decoder_weights)
|
||||||
+ len(vae_encoder_weights)
|
+ len(vae_encoder_weights)
|
||||||
+ len(audio_decoder_weights)
|
+ len(audio_decoder_weights)
|
||||||
|
+ len(audio_encoder_weights)
|
||||||
+ len(vocoder_weights)
|
+ len(vocoder_weights)
|
||||||
+ len(text_proj_weights)
|
+ len(text_proj_weights)
|
||||||
)
|
)
|
||||||
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
||||||
if all_converted < total_keys:
|
if all_converted < total_keys:
|
||||||
# Find unconverted keys
|
known_prefixes = (
|
||||||
converted_prefixes = set()
|
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX,
|
||||||
for key in all_weights:
|
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX,
|
||||||
if key.startswith(TRANSFORMER_PREFIX):
|
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX,
|
||||||
converted_prefixes.add(key)
|
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX,
|
||||||
elif key.startswith(VAE_DECODER_PREFIX) or key.startswith(VAE_STATS_PREFIX):
|
)
|
||||||
converted_prefixes.add(key)
|
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)]
|
||||||
elif key.startswith(VAE_ENCODER_PREFIX):
|
|
||||||
converted_prefixes.add(key)
|
|
||||||
elif key.startswith(AUDIO_DECODER_PREFIX) or key.startswith(AUDIO_STATS_PREFIX):
|
|
||||||
converted_prefixes.add(key)
|
|
||||||
elif key.startswith(AUDIO_ENCODER_PREFIX):
|
|
||||||
converted_prefixes.add(key)
|
|
||||||
elif key.startswith(VOCODER_PREFIX):
|
|
||||||
converted_prefixes.add(key)
|
|
||||||
elif key.startswith(TEXT_PROJ_PREFIX):
|
|
||||||
converted_prefixes.add(key)
|
|
||||||
elif key.startswith(VIDEO_CONNECTOR_PREFIX) or key.startswith(AUDIO_CONNECTOR_PREFIX):
|
|
||||||
converted_prefixes.add(key)
|
|
||||||
|
|
||||||
skipped = set(all_weights.keys()) - converted_prefixes
|
|
||||||
if skipped:
|
if skipped:
|
||||||
print(f" Skipped {len(skipped)} keys:")
|
print(f" Skipped {len(skipped)} keys:")
|
||||||
for k in sorted(skipped)[:20]:
|
for k in sorted(skipped)[:20]:
|
||||||
|
|||||||
Reference in New Issue
Block a user