From d028b239fb2f75fde955ae4c214b2857b40e80e3 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 10 Mar 2026 08:01:26 +0100 Subject: [PATCH] Update LTX conversion script to support LTX-2.3 safetensors format. Enhance documentation and improve file matching logic for variant detection in local directories. --- mlx_video/models/ltx/convert.py | 397 +++++++++++++++++++------------- 1 file changed, 240 insertions(+), 157 deletions(-) diff --git a/mlx_video/models/ltx/convert.py b/mlx_video/models/ltx/convert.py index b2564a2..330c11c 100644 --- a/mlx_video/models/ltx/convert.py +++ b/mlx_video/models/ltx/convert.py @@ -1,7 +1,7 @@ -"""Convert LTX-2 safetensors to MLX directory layout. +"""Convert LTX-2/2.3 safetensors to MLX directory layout. -Converts from the single-file format (e.g. Lightricks/LTX-2/ltx-2-19b-distilled.safetensors) -to the modular directory structure: +Converts from the single-file format (e.g. Lightricks/LTX-2/ltx-2-19b-distilled.safetensors +or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular directory structure: output/ ├── transformer/ # DiT transformer weights (sharded) @@ -27,7 +27,7 @@ to the modular directory structure: Usage: # From HF repo ID python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled - python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-dev --variant dev + python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled # From local folder containing the monolithic safetensors python -m mlx_video.models.ltx.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled @@ -46,111 +46,6 @@ from typing import Dict import mlx.core as mx -# ─── Component configs ──────────────────────────────────────────────────────── - -TRANSFORMER_CONFIG = { - "attention_head_dim": 128, - "attention_type": "default", - "audio_attention_head_dim": 64, - "audio_caption_channels": 3840, - "audio_cross_attention_dim": 2048, - "audio_in_channels": 128, - "audio_num_attention_heads": 32, - "audio_out_channels": 128, - "audio_positional_embedding_max_pos": [20], - "av_ca_timestep_scale_multiplier": 1000, - "caption_channels": 3840, - "cross_attention_dim": 4096, - "double_precision_rope": True, - "in_channels": 128, - "model_type": "ltx av model", - "norm_eps": 1e-06, - "num_attention_heads": 32, - "num_layers": 48, - "out_channels": 128, - "positional_embedding_max_pos": [20, 2048, 2048], - "positional_embedding_theta": 10000.0, - "rope_type": "split", - "timestep_scale_multiplier": 1000, - "use_middle_indices_grid": True, -} - -VAE_DECODER_CONFIG_DISTILLED = { - "ch": 128, - "ch_mult": [1, 2, 4], - "dropout": 0.0, - "num_res_blocks": 2, - "out_ch": 2, - "resolution": 256, - "timestep_conditioning": False, - "z_channels": 8, -} - -VAE_DECODER_CONFIG_DEV = { - "ch": 128, - "ch_mult": [1, 2, 4], - "dropout": 0.0, - "num_res_blocks": 2, - "out_ch": 2, - "resolution": 256, - "timestep_conditioning": True, - "z_channels": 8, -} - -VAE_ENCODER_CONFIG = { - "convolution_dimensions": 3, - "encoder_blocks": [ - ["res_x", {"num_layers": 4}], - ["compress_space_res", {"multiplier": 2}], - ["res_x", {"num_layers": 6}], - ["compress_time_res", {"multiplier": 2}], - ["res_x", {"num_layers": 6}], - ["compress_all_res", {"multiplier": 2}], - ["res_x", {"num_layers": 2}], - ["compress_all_res", {"multiplier": 2}], - ["res_x", {"num_layers": 2}], - ], - "encoder_spatial_padding_mode": "zeros", - "in_channels": 3, - "latent_log_var": "uniform", - "norm_layer": "pixel_norm", - "out_channels": 128, - "patch_size": 4, -} - -AUDIO_VAE_CONFIG = { - "attn_resolutions": [], - "attn_type": "vanilla", - "causality_axis": "height", - "ch": 128, - "ch_mult": [1, 2, 4], - "dropout": 0.0, - "give_pre_end": False, - "is_causal": True, - "mel_bins": 64, - "mel_hop_length": 160, - "mid_block_add_attention": False, - "norm_type": "pixel", - "num_res_blocks": 2, - "out_ch": 2, - "resamp_with_conv": True, - "resolution": 256, - "sample_rate": 16000, - "tanh_out": False, - "z_channels": 8, -} - -VOCODER_CONFIG = { - "output_sample_rate": 24000, - "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - "resblock_kernel_sizes": [3, 7, 11], - "stereo": True, - "upsample_initial_channel": 1024, - "upsample_kernel_sizes": [16, 15, 8, 4, 4], - "upsample_rates": [6, 5, 2, 2, 2], -} - - # ─── Key prefix routing ────────────────────────────────────────────────────── TRANSFORMER_PREFIX = "model.diffusion_model." @@ -158,9 +53,10 @@ VAE_DECODER_PREFIX = "vae.decoder." VAE_ENCODER_PREFIX = "vae.encoder." VAE_STATS_PREFIX = "vae.per_channel_statistics." AUDIO_DECODER_PREFIX = "audio_vae.decoder." +AUDIO_ENCODER_PREFIX = "audio_vae.encoder." AUDIO_STATS_PREFIX = "audio_vae.per_channel_statistics." VOCODER_PREFIX = "vocoder." -TEXT_AGG_KEY = "text_embedding_projection.aggregate_embed.weight" +TEXT_PROJ_PREFIX = "text_embedding_projection." VIDEO_CONNECTOR_PREFIX = "model.diffusion_model.video_embeddings_connector." AUDIO_CONNECTOR_PREFIX = "model.diffusion_model.audio_embeddings_connector." @@ -320,12 +216,18 @@ def sanitize_connector_key(key: str) -> str: def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - """Extract text projection weights (aggregate_embed + connectors).""" + """Extract text projection weights (aggregate_embed + connectors). + + Handles both LTX-2 (aggregate_embed.weight) and LTX-2.3 + (video_aggregate_embed.*, audio_aggregate_embed.*) formats. + """ extracted = {} - # aggregate_embed - if TEXT_AGG_KEY in weights: - extracted["aggregate_embed.weight"] = weights[TEXT_AGG_KEY] + # aggregate_embed weights (text_embedding_projection.*) + for key, value in weights.items(): + if key.startswith(TEXT_PROJ_PREFIX): + new_key = key[len(TEXT_PROJ_PREFIX):] + extracted[new_key] = value # video_embeddings_connector for key, value in weights.items(): @@ -432,10 +334,8 @@ def save_config(config: dict, output_dir: Path): # ─── Source resolution ───────────────────────────────────────────────────────── -VARIANT_FILE_PATTERNS = { - "distilled": "ltx-2-19b-distilled.safetensors", - "dev": "ltx-2-19b-dev.safetensors", -} +# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc. +MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?Pdistilled|dev)\.safetensors$") # Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors, # ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc. @@ -458,40 +358,49 @@ def resolve_source(source: str, variant: str) -> Path: if source_path.is_file(): return source_path - # Local directory — look for the variant's safetensors file + # Local directory — find the variant's safetensors file if source_path.is_dir(): - target = VARIANT_FILE_PATTERNS.get(variant) - if target: - candidate = source_path / target - if candidate.is_file(): - return candidate + matches = [] + for f in sorted(source_path.glob("ltx-*b-*.safetensors")): + m = MONOLITHIC_PATTERN.match(f.name) + if m and m.group("variant") == variant: + matches.append(f) - # Fallback: glob for ltx-2-19b-*.safetensors - matches = sorted(source_path.glob("ltx-2-19b-*.safetensors")) if matches: - if len(matches) == 1: - return matches[0] - # Multiple matches — pick by variant keyword - for m in matches: - if variant in m.name: - return m return matches[0] + # Broader fallback + all_mono = sorted(source_path.glob("ltx-*.safetensors")) + for f in all_mono: + if variant in f.name and MONOLITHIC_PATTERN.match(f.name): + return f + raise FileNotFoundError( - f"No ltx-2-19b-*.safetensors found in {source_path}. " - f"Expected {target} for variant '{variant}'." + f"No monolithic *-{variant}.safetensors found in {source_path}. " + f"Files found: {[f.name for f in all_mono]}" ) # HF repo ID — download via huggingface_hub if "/" in source and not source_path.exists(): - from huggingface_hub import hf_hub_download + from huggingface_hub import hf_hub_download, list_repo_files - filename = VARIANT_FILE_PATTERNS.get(variant) - if not filename: - raise ValueError(f"Unknown variant '{variant}'. Expected 'distilled' or 'dev'.") + # Find the right file in the repo + repo_files = list_repo_files(source) + target = None + for f in repo_files: + m = MONOLITHIC_PATTERN.match(f) + if m and m.group("variant") == variant: + target = f + break - print(f"Downloading {filename} from {source}...") - local_path = hf_hub_download(repo_id=source, filename=filename) + if not target: + raise FileNotFoundError( + f"No *-{variant}.safetensors found in {source}. " + f"Available: {[f for f in repo_files if f.endswith('.safetensors')]}" + ) + + print(f"Downloading {target} from {source}...") + local_path = hf_hub_download(repo_id=source, filename=target) return Path(local_path) raise FileNotFoundError( @@ -499,6 +408,169 @@ def resolve_source(source: str, variant: str) -> Path: ) +# ─── Config inference ───────────────────────────────────────────────────────── + + +def infer_transformer_config(weights: Dict[str, mx.array]) -> dict: + """Infer transformer config from weight shapes.""" + # Count transformer layers + max_layer = -1 + for key in weights: + if "transformer_blocks." in key: + parts = key.split(".") + try: + idx = parts.index("transformer_blocks") + 1 + if idx < len(parts) and parts[idx].isdigit(): + max_layer = max(max_layer, int(parts[idx])) + except ValueError: + pass + num_layers = max_layer + 1 if max_layer >= 0 else 48 + + # Detect cross_attention_dim from attn2.to_k (cross-attention input dim) + cross_attention_dim = 4096 + for key, value in weights.items(): + if "transformer_blocks.0.attn2.to_k.weight" in key: + cross_attention_dim = value.shape[-1] + break + + # Check for prompt_adaln_single (LTX-2.3 feature) + has_prompt_adaln = any("prompt_adaln_single" in k for k in weights) + + config = { + "attention_head_dim": 128, + "attention_type": "default", + "audio_attention_head_dim": 64, + "audio_caption_channels": 3840, + "audio_cross_attention_dim": 2048, + "audio_in_channels": 128, + "audio_num_attention_heads": 32, + "audio_out_channels": 128, + "audio_positional_embedding_max_pos": [20], + "av_ca_timestep_scale_multiplier": 1000, + "caption_channels": 3840, + "cross_attention_dim": cross_attention_dim, + "double_precision_rope": True, + "in_channels": 128, + "model_type": "ltx av model", + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": num_layers, + "out_channels": 128, + "positional_embedding_max_pos": [20, 2048, 2048], + "positional_embedding_theta": 10000.0, + "rope_type": "split", + "timestep_scale_multiplier": 1000, + "use_middle_indices_grid": True, + } + + if has_prompt_adaln: + config["has_prompt_adaln"] = True + + return config + + +def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict: + """Infer VAE decoder config from weights.""" + # Check for timestep conditioning keys + has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights) + + # Count channel multipliers from up_blocks + max_block = -1 + for key in weights: + if "up_blocks." in key: + parts = key.split(".") + try: + idx = parts.index("up_blocks") + 1 + if idx < len(parts) and parts[idx].isdigit(): + max_block = max(max_block, int(parts[idx])) + except ValueError: + pass + + # Default config + config = { + "ch": 128, + "ch_mult": [1, 2, 4], + "dropout": 0.0, + "num_res_blocks": 2, + "out_ch": 2, + "resolution": 256, + "timestep_conditioning": has_timestep if has_timestep else (variant == "dev"), + "z_channels": 8, + } + return config + + +def infer_vae_encoder_config(weights: Dict[str, mx.array]) -> dict: + """Return VAE encoder config (architecture is consistent across versions).""" + return { + "convolution_dimensions": 3, + "encoder_blocks": [ + ["res_x", {"num_layers": 4}], + ["compress_space_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_time_res", {"multiplier": 2}], + ["res_x", {"num_layers": 6}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}], + ["compress_all_res", {"multiplier": 2}], + ["res_x", {"num_layers": 2}], + ], + "encoder_spatial_padding_mode": "zeros", + "in_channels": 3, + "latent_log_var": "uniform", + "norm_layer": "pixel_norm", + "out_channels": 128, + "patch_size": 4, + } + + +def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict: + """Return audio VAE config.""" + return { + "attn_resolutions": [], + "attn_type": "vanilla", + "causality_axis": "height", + "ch": 128, + "ch_mult": [1, 2, 4], + "dropout": 0.0, + "give_pre_end": False, + "is_causal": True, + "mel_bins": 64, + "mel_hop_length": 160, + "mid_block_add_attention": False, + "norm_type": "pixel", + "num_res_blocks": 2, + "out_ch": 2, + "resamp_with_conv": True, + "resolution": 256, + "sample_rate": 16000, + "tanh_out": False, + "z_channels": 8, + } + + +def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict: + """Infer vocoder config from weights.""" + # Check for bwe_generator (LTX-2.3 BigVGAN vocoder) + has_bwe = any(k.startswith("bwe_generator") for k in weights) + + if has_bwe: + return { + "type": "bigvgan", + "has_bwe_generator": True, + } + + return { + "output_sample_rate": 24000, + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_kernel_sizes": [3, 7, 11], + "stereo": True, + "upsample_initial_channel": 1024, + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_rates": [6, 5, 2, 2, 2], + } + + # ─── Main ───────────────────────────────────────────────────────────────────── @@ -524,7 +596,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): print(" [1/6] Transformer...") transformer_weights = sanitize_transformer(all_weights) num_shards = save_sharded(transformer_weights, output_path / "transformer") - save_config(TRANSFORMER_CONFIG, output_path / "transformer") + config = infer_transformer_config(transformer_weights) + save_config(config, output_path / "transformer") t_params = sum(v.size for v in transformer_weights.values()) print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards") @@ -532,8 +605,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): print(" [2/6] VAE Decoder...") vae_decoder_weights = sanitize_vae_decoder(all_weights) save_single(vae_decoder_weights, output_path / "vae" / "decoder") - decoder_config = VAE_DECODER_CONFIG_DISTILLED if variant == "distilled" else VAE_DECODER_CONFIG_DEV - save_config(decoder_config, output_path / "vae" / "decoder") + config = infer_vae_decoder_config(vae_decoder_weights, variant) + save_config(config, output_path / "vae" / "decoder") d_params = sum(v.size for v in vae_decoder_weights.values()) print(f" {len(vae_decoder_weights)} keys, {d_params:,} params") @@ -541,7 +614,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): print(" [3/6] VAE Encoder...") vae_encoder_weights = sanitize_vae_encoder(all_weights) save_single(vae_encoder_weights, output_path / "vae" / "encoder") - save_config(VAE_ENCODER_CONFIG, output_path / "vae" / "encoder") + config = infer_vae_encoder_config(vae_encoder_weights) + save_config(config, output_path / "vae" / "encoder") e_params = sum(v.size for v in vae_encoder_weights.values()) print(f" {len(vae_encoder_weights)} keys, {e_params:,} params") @@ -549,7 +623,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): print(" [4/6] Audio VAE Decoder...") audio_decoder_weights = sanitize_audio_decoder(all_weights) save_single(audio_decoder_weights, output_path / "audio_vae") - save_config(AUDIO_VAE_CONFIG, output_path / "audio_vae") + config = infer_audio_vae_config(audio_decoder_weights) + save_config(config, output_path / "audio_vae") a_params = sum(v.size for v in audio_decoder_weights.values()) print(f" {len(audio_decoder_weights)} keys, {a_params:,} params") @@ -557,7 +632,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): print(" [5/6] Vocoder...") vocoder_weights = sanitize_vocoder(all_weights) save_single(vocoder_weights, output_path / "vocoder") - save_config(VOCODER_CONFIG, output_path / "vocoder") + config = infer_vocoder_config(vocoder_weights) + save_config(config, output_path / "vocoder") v_params = sum(v.size for v in vocoder_weights.values()) print(f" {len(vocoder_weights)} keys, {v_params:,} params") @@ -626,15 +702,20 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): dest.symlink_to(real_path) print(f" {subdir}/: symlinked to {real_path}") elif is_hf_repo: - from huggingface_hub import snapshot_download + from huggingface_hub import list_repo_files, snapshot_download - print(f" {subdir}/: downloading from {source}...") - snapshot_download( - repo_id=source, - allow_patterns=f"{subdir}/*", - local_dir=str(output_path), - ) - print(f" {subdir}/: done") + # Only download if the subdir exists in the repo + repo_files = list_repo_files(source) + if any(f.startswith(f"{subdir}/") for f in repo_files): + print(f" {subdir}/: downloading from {source}...") + snapshot_download( + repo_id=source, + allow_patterns=f"{subdir}/*", + local_dir=str(output_path), + ) + print(f" {subdir}/: done") + else: + print(f" {subdir}/: not in repo, skipping") else: print(f" {subdir}/: not found in source, skipping") @@ -660,9 +741,11 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): converted_prefixes.add(key) elif key.startswith(AUDIO_DECODER_PREFIX) or key.startswith(AUDIO_STATS_PREFIX): converted_prefixes.add(key) + elif key.startswith(AUDIO_ENCODER_PREFIX): + converted_prefixes.add(key) elif key.startswith(VOCODER_PREFIX): converted_prefixes.add(key) - elif key == TEXT_AGG_KEY: + elif key.startswith(TEXT_PROJ_PREFIX): converted_prefixes.add(key) elif key.startswith(VIDEO_CONNECTOR_PREFIX) or key.startswith(AUDIO_CONNECTOR_PREFIX): converted_prefixes.add(key) @@ -676,13 +759,13 @@ def convert(source: str, output_path: Path, variant: str = "distilled"): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Convert monolithic LTX-2 safetensors to modular MLX layout" + description="Convert monolithic LTX-2/2.3 safetensors to modular MLX layout" ) parser.add_argument( "--source", type=str, required=True, - help="HF repo ID (e.g. Lightricks/LTX-2), local directory, or direct safetensors file path", + help="HF repo ID (e.g. Lightricks/LTX-2, Lightricks/LTX-2.3), local directory, or direct safetensors file path", ) parser.add_argument( "--output",