diff --git a/mlx_video/models/ltx_2/convert.py b/mlx_video/models/ltx_2/convert.py index be02523..bc6d239 100644 --- a/mlx_video/models/ltx_2/convert.py +++ b/mlx_video/models/ltx_2/convert.py @@ -714,7 +714,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): tp_params = sum(v.size for v in text_proj_weights.values()) print(f" {len(text_proj_weights)} keys, {tp_params:,} params") - # 7. Copy upscaler files + # Copy upscaler files print("\nCopying upscaler files...") source_dir = source_path.parent is_hf_repo = "/" in source and not Path(source).exists() @@ -755,7 +755,7 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): else: print(f" {upscaler_file}: not found, skipping") - # 8. Link text_encoder and tokenizer directories + # Link text_encoder and tokenizer directories print("\nLinking text encoder & tokenizer...") for subdir in ["text_encoder", "tokenizer"]: dest = output_path / subdir @@ -793,32 +793,19 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): + len(vae_decoder_weights) + len(vae_encoder_weights) + len(audio_decoder_weights) + + len(audio_encoder_weights) + len(vocoder_weights) + len(text_proj_weights) ) print(f"\nDone! Converted {all_converted}/{total_keys} keys") if all_converted < total_keys: - # Find unconverted keys - converted_prefixes = set() - for key in all_weights: - if key.startswith(TRANSFORMER_PREFIX): - converted_prefixes.add(key) - elif key.startswith(VAE_DECODER_PREFIX) or key.startswith(VAE_STATS_PREFIX): - converted_prefixes.add(key) - elif key.startswith(VAE_ENCODER_PREFIX): - converted_prefixes.add(key) - elif key.startswith(AUDIO_DECODER_PREFIX) or key.startswith(AUDIO_STATS_PREFIX): - converted_prefixes.add(key) - elif key.startswith(AUDIO_ENCODER_PREFIX): - converted_prefixes.add(key) - elif key.startswith(VOCODER_PREFIX): - converted_prefixes.add(key) - elif key.startswith(TEXT_PROJ_PREFIX): - converted_prefixes.add(key) - elif key.startswith(VIDEO_CONNECTOR_PREFIX) or key.startswith(AUDIO_CONNECTOR_PREFIX): - converted_prefixes.add(key) - - skipped = set(all_weights.keys()) - converted_prefixes + known_prefixes = ( + TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX, + VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX, + AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX, + VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX, + ) + skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)] if skipped: print(f" Skipped {len(skipped)} keys:") for k in sorted(skipped)[:20]: