Update LTX conversion script to support LTX-2.3 safetensors format. Enhance documentation and improve file matching logic for variant detection in local directories.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Convert LTX-2 safetensors to MLX directory layout.
|
||||
"""Convert LTX-2/2.3 safetensors to MLX directory layout.
|
||||
|
||||
Converts from the single-file format (e.g. Lightricks/LTX-2/ltx-2-19b-distilled.safetensors)
|
||||
to the modular directory structure:
|
||||
Converts from the single-file format (e.g. Lightricks/LTX-2/ltx-2-19b-distilled.safetensors
|
||||
or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular directory structure:
|
||||
|
||||
output/
|
||||
├── transformer/ # DiT transformer weights (sharded)
|
||||
@@ -27,7 +27,7 @@ to the modular directory structure:
|
||||
Usage:
|
||||
# From HF repo ID
|
||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2 --output LTX-2-dev --variant dev
|
||||
python -m mlx_video.models.ltx.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
||||
|
||||
# From local folder containing the monolithic safetensors
|
||||
python -m mlx_video.models.ltx.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
||||
@@ -46,111 +46,6 @@ from typing import Dict
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ─── Component configs ────────────────────────────────────────────────────────
|
||||
|
||||
TRANSFORMER_CONFIG = {
|
||||
"attention_head_dim": 128,
|
||||
"attention_type": "default",
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_caption_channels": 3840,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_out_channels": 128,
|
||||
"audio_positional_embedding_max_pos": [20],
|
||||
"av_ca_timestep_scale_multiplier": 1000,
|
||||
"caption_channels": 3840,
|
||||
"cross_attention_dim": 4096,
|
||||
"double_precision_rope": True,
|
||||
"in_channels": 128,
|
||||
"model_type": "ltx av model",
|
||||
"norm_eps": 1e-06,
|
||||
"num_attention_heads": 32,
|
||||
"num_layers": 48,
|
||||
"out_channels": 128,
|
||||
"positional_embedding_max_pos": [20, 2048, 2048],
|
||||
"positional_embedding_theta": 10000.0,
|
||||
"rope_type": "split",
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"use_middle_indices_grid": True,
|
||||
}
|
||||
|
||||
VAE_DECODER_CONFIG_DISTILLED = {
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"num_res_blocks": 2,
|
||||
"out_ch": 2,
|
||||
"resolution": 256,
|
||||
"timestep_conditioning": False,
|
||||
"z_channels": 8,
|
||||
}
|
||||
|
||||
VAE_DECODER_CONFIG_DEV = {
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"num_res_blocks": 2,
|
||||
"out_ch": 2,
|
||||
"resolution": 256,
|
||||
"timestep_conditioning": True,
|
||||
"z_channels": 8,
|
||||
}
|
||||
|
||||
VAE_ENCODER_CONFIG = {
|
||||
"convolution_dimensions": 3,
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
],
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"in_channels": 3,
|
||||
"latent_log_var": "uniform",
|
||||
"norm_layer": "pixel_norm",
|
||||
"out_channels": 128,
|
||||
"patch_size": 4,
|
||||
}
|
||||
|
||||
AUDIO_VAE_CONFIG = {
|
||||
"attn_resolutions": [],
|
||||
"attn_type": "vanilla",
|
||||
"causality_axis": "height",
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"give_pre_end": False,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"mel_hop_length": 160,
|
||||
"mid_block_add_attention": False,
|
||||
"norm_type": "pixel",
|
||||
"num_res_blocks": 2,
|
||||
"out_ch": 2,
|
||||
"resamp_with_conv": True,
|
||||
"resolution": 256,
|
||||
"sample_rate": 16000,
|
||||
"tanh_out": False,
|
||||
"z_channels": 8,
|
||||
}
|
||||
|
||||
VOCODER_CONFIG = {
|
||||
"output_sample_rate": 24000,
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"stereo": True,
|
||||
"upsample_initial_channel": 1024,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
}
|
||||
|
||||
|
||||
# ─── Key prefix routing ──────────────────────────────────────────────────────
|
||||
|
||||
TRANSFORMER_PREFIX = "model.diffusion_model."
|
||||
@@ -158,9 +53,10 @@ VAE_DECODER_PREFIX = "vae.decoder."
|
||||
VAE_ENCODER_PREFIX = "vae.encoder."
|
||||
VAE_STATS_PREFIX = "vae.per_channel_statistics."
|
||||
AUDIO_DECODER_PREFIX = "audio_vae.decoder."
|
||||
AUDIO_ENCODER_PREFIX = "audio_vae.encoder."
|
||||
AUDIO_STATS_PREFIX = "audio_vae.per_channel_statistics."
|
||||
VOCODER_PREFIX = "vocoder."
|
||||
TEXT_AGG_KEY = "text_embedding_projection.aggregate_embed.weight"
|
||||
TEXT_PROJ_PREFIX = "text_embedding_projection."
|
||||
VIDEO_CONNECTOR_PREFIX = "model.diffusion_model.video_embeddings_connector."
|
||||
AUDIO_CONNECTOR_PREFIX = "model.diffusion_model.audio_embeddings_connector."
|
||||
|
||||
@@ -320,12 +216,18 @@ def sanitize_connector_key(key: str) -> str:
|
||||
|
||||
|
||||
def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Extract text projection weights (aggregate_embed + connectors)."""
|
||||
"""Extract text projection weights (aggregate_embed + connectors).
|
||||
|
||||
Handles both LTX-2 (aggregate_embed.weight) and LTX-2.3
|
||||
(video_aggregate_embed.*, audio_aggregate_embed.*) formats.
|
||||
"""
|
||||
extracted = {}
|
||||
|
||||
# aggregate_embed
|
||||
if TEXT_AGG_KEY in weights:
|
||||
extracted["aggregate_embed.weight"] = weights[TEXT_AGG_KEY]
|
||||
# aggregate_embed weights (text_embedding_projection.*)
|
||||
for key, value in weights.items():
|
||||
if key.startswith(TEXT_PROJ_PREFIX):
|
||||
new_key = key[len(TEXT_PROJ_PREFIX):]
|
||||
extracted[new_key] = value
|
||||
|
||||
# video_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
@@ -432,10 +334,8 @@ def save_config(config: dict, output_dir: Path):
|
||||
|
||||
# ─── Source resolution ─────────────────────────────────────────────────────────
|
||||
|
||||
VARIANT_FILE_PATTERNS = {
|
||||
"distilled": "ltx-2-19b-distilled.safetensors",
|
||||
"dev": "ltx-2-19b-dev.safetensors",
|
||||
}
|
||||
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
|
||||
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
|
||||
|
||||
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
|
||||
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
|
||||
@@ -458,40 +358,49 @@ def resolve_source(source: str, variant: str) -> Path:
|
||||
if source_path.is_file():
|
||||
return source_path
|
||||
|
||||
# Local directory — look for the variant's safetensors file
|
||||
# Local directory — find the variant's safetensors file
|
||||
if source_path.is_dir():
|
||||
target = VARIANT_FILE_PATTERNS.get(variant)
|
||||
if target:
|
||||
candidate = source_path / target
|
||||
if candidate.is_file():
|
||||
return candidate
|
||||
matches = []
|
||||
for f in sorted(source_path.glob("ltx-*b-*.safetensors")):
|
||||
m = MONOLITHIC_PATTERN.match(f.name)
|
||||
if m and m.group("variant") == variant:
|
||||
matches.append(f)
|
||||
|
||||
# Fallback: glob for ltx-2-19b-*.safetensors
|
||||
matches = sorted(source_path.glob("ltx-2-19b-*.safetensors"))
|
||||
if matches:
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
# Multiple matches — pick by variant keyword
|
||||
for m in matches:
|
||||
if variant in m.name:
|
||||
return m
|
||||
return matches[0]
|
||||
|
||||
# Broader fallback
|
||||
all_mono = sorted(source_path.glob("ltx-*.safetensors"))
|
||||
for f in all_mono:
|
||||
if variant in f.name and MONOLITHIC_PATTERN.match(f.name):
|
||||
return f
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"No ltx-2-19b-*.safetensors found in {source_path}. "
|
||||
f"Expected {target} for variant '{variant}'."
|
||||
f"No monolithic *-{variant}.safetensors found in {source_path}. "
|
||||
f"Files found: {[f.name for f in all_mono]}"
|
||||
)
|
||||
|
||||
# HF repo ID — download via huggingface_hub
|
||||
if "/" in source and not source_path.exists():
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
|
||||
filename = VARIANT_FILE_PATTERNS.get(variant)
|
||||
if not filename:
|
||||
raise ValueError(f"Unknown variant '{variant}'. Expected 'distilled' or 'dev'.")
|
||||
# Find the right file in the repo
|
||||
repo_files = list_repo_files(source)
|
||||
target = None
|
||||
for f in repo_files:
|
||||
m = MONOLITHIC_PATTERN.match(f)
|
||||
if m and m.group("variant") == variant:
|
||||
target = f
|
||||
break
|
||||
|
||||
print(f"Downloading {filename} from {source}...")
|
||||
local_path = hf_hub_download(repo_id=source, filename=filename)
|
||||
if not target:
|
||||
raise FileNotFoundError(
|
||||
f"No *-{variant}.safetensors found in {source}. "
|
||||
f"Available: {[f for f in repo_files if f.endswith('.safetensors')]}"
|
||||
)
|
||||
|
||||
print(f"Downloading {target} from {source}...")
|
||||
local_path = hf_hub_download(repo_id=source, filename=target)
|
||||
return Path(local_path)
|
||||
|
||||
raise FileNotFoundError(
|
||||
@@ -499,6 +408,169 @@ def resolve_source(source: str, variant: str) -> Path:
|
||||
)
|
||||
|
||||
|
||||
# ─── Config inference ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Infer transformer config from weight shapes."""
|
||||
# Count transformer layers
|
||||
max_layer = -1
|
||||
for key in weights:
|
||||
if "transformer_blocks." in key:
|
||||
parts = key.split(".")
|
||||
try:
|
||||
idx = parts.index("transformer_blocks") + 1
|
||||
if idx < len(parts) and parts[idx].isdigit():
|
||||
max_layer = max(max_layer, int(parts[idx]))
|
||||
except ValueError:
|
||||
pass
|
||||
num_layers = max_layer + 1 if max_layer >= 0 else 48
|
||||
|
||||
# Detect cross_attention_dim from attn2.to_k (cross-attention input dim)
|
||||
cross_attention_dim = 4096
|
||||
for key, value in weights.items():
|
||||
if "transformer_blocks.0.attn2.to_k.weight" in key:
|
||||
cross_attention_dim = value.shape[-1]
|
||||
break
|
||||
|
||||
# Check for prompt_adaln_single (LTX-2.3 feature)
|
||||
has_prompt_adaln = any("prompt_adaln_single" in k for k in weights)
|
||||
|
||||
config = {
|
||||
"attention_head_dim": 128,
|
||||
"attention_type": "default",
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_caption_channels": 3840,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_out_channels": 128,
|
||||
"audio_positional_embedding_max_pos": [20],
|
||||
"av_ca_timestep_scale_multiplier": 1000,
|
||||
"caption_channels": 3840,
|
||||
"cross_attention_dim": cross_attention_dim,
|
||||
"double_precision_rope": True,
|
||||
"in_channels": 128,
|
||||
"model_type": "ltx av model",
|
||||
"norm_eps": 1e-06,
|
||||
"num_attention_heads": 32,
|
||||
"num_layers": num_layers,
|
||||
"out_channels": 128,
|
||||
"positional_embedding_max_pos": [20, 2048, 2048],
|
||||
"positional_embedding_theta": 10000.0,
|
||||
"rope_type": "split",
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"use_middle_indices_grid": True,
|
||||
}
|
||||
|
||||
if has_prompt_adaln:
|
||||
config["has_prompt_adaln"] = True
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
|
||||
"""Infer VAE decoder config from weights."""
|
||||
# Check for timestep conditioning keys
|
||||
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
|
||||
|
||||
# Count channel multipliers from up_blocks
|
||||
max_block = -1
|
||||
for key in weights:
|
||||
if "up_blocks." in key:
|
||||
parts = key.split(".")
|
||||
try:
|
||||
idx = parts.index("up_blocks") + 1
|
||||
if idx < len(parts) and parts[idx].isdigit():
|
||||
max_block = max(max_block, int(parts[idx]))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Default config
|
||||
config = {
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"num_res_blocks": 2,
|
||||
"out_ch": 2,
|
||||
"resolution": 256,
|
||||
"timestep_conditioning": has_timestep if has_timestep else (variant == "dev"),
|
||||
"z_channels": 8,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def infer_vae_encoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Return VAE encoder config (architecture is consistent across versions)."""
|
||||
return {
|
||||
"convolution_dimensions": 3,
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
],
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"in_channels": 3,
|
||||
"latent_log_var": "uniform",
|
||||
"norm_layer": "pixel_norm",
|
||||
"out_channels": 128,
|
||||
"patch_size": 4,
|
||||
}
|
||||
|
||||
|
||||
def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Return audio VAE config."""
|
||||
return {
|
||||
"attn_resolutions": [],
|
||||
"attn_type": "vanilla",
|
||||
"causality_axis": "height",
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"give_pre_end": False,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"mel_hop_length": 160,
|
||||
"mid_block_add_attention": False,
|
||||
"norm_type": "pixel",
|
||||
"num_res_blocks": 2,
|
||||
"out_ch": 2,
|
||||
"resamp_with_conv": True,
|
||||
"resolution": 256,
|
||||
"sample_rate": 16000,
|
||||
"tanh_out": False,
|
||||
"z_channels": 8,
|
||||
}
|
||||
|
||||
|
||||
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Infer vocoder config from weights."""
|
||||
# Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
|
||||
has_bwe = any(k.startswith("bwe_generator") for k in weights)
|
||||
|
||||
if has_bwe:
|
||||
return {
|
||||
"type": "bigvgan",
|
||||
"has_bwe_generator": True,
|
||||
}
|
||||
|
||||
return {
|
||||
"output_sample_rate": 24000,
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"stereo": True,
|
||||
"upsample_initial_channel": 1024,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
}
|
||||
|
||||
|
||||
# ─── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -524,7 +596,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(" [1/6] Transformer...")
|
||||
transformer_weights = sanitize_transformer(all_weights)
|
||||
num_shards = save_sharded(transformer_weights, output_path / "transformer")
|
||||
save_config(TRANSFORMER_CONFIG, output_path / "transformer")
|
||||
config = infer_transformer_config(transformer_weights)
|
||||
save_config(config, output_path / "transformer")
|
||||
t_params = sum(v.size for v in transformer_weights.values())
|
||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
||||
|
||||
@@ -532,8 +605,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(" [2/6] VAE Decoder...")
|
||||
vae_decoder_weights = sanitize_vae_decoder(all_weights)
|
||||
save_single(vae_decoder_weights, output_path / "vae" / "decoder")
|
||||
decoder_config = VAE_DECODER_CONFIG_DISTILLED if variant == "distilled" else VAE_DECODER_CONFIG_DEV
|
||||
save_config(decoder_config, output_path / "vae" / "decoder")
|
||||
config = infer_vae_decoder_config(vae_decoder_weights, variant)
|
||||
save_config(config, output_path / "vae" / "decoder")
|
||||
d_params = sum(v.size for v in vae_decoder_weights.values())
|
||||
print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
|
||||
|
||||
@@ -541,7 +614,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(" [3/6] VAE Encoder...")
|
||||
vae_encoder_weights = sanitize_vae_encoder(all_weights)
|
||||
save_single(vae_encoder_weights, output_path / "vae" / "encoder")
|
||||
save_config(VAE_ENCODER_CONFIG, output_path / "vae" / "encoder")
|
||||
config = infer_vae_encoder_config(vae_encoder_weights)
|
||||
save_config(config, output_path / "vae" / "encoder")
|
||||
e_params = sum(v.size for v in vae_encoder_weights.values())
|
||||
print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
|
||||
|
||||
@@ -549,7 +623,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(" [4/6] Audio VAE Decoder...")
|
||||
audio_decoder_weights = sanitize_audio_decoder(all_weights)
|
||||
save_single(audio_decoder_weights, output_path / "audio_vae")
|
||||
save_config(AUDIO_VAE_CONFIG, output_path / "audio_vae")
|
||||
config = infer_audio_vae_config(audio_decoder_weights)
|
||||
save_config(config, output_path / "audio_vae")
|
||||
a_params = sum(v.size for v in audio_decoder_weights.values())
|
||||
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
|
||||
|
||||
@@ -557,7 +632,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(" [5/6] Vocoder...")
|
||||
vocoder_weights = sanitize_vocoder(all_weights)
|
||||
save_single(vocoder_weights, output_path / "vocoder")
|
||||
save_config(VOCODER_CONFIG, output_path / "vocoder")
|
||||
config = infer_vocoder_config(vocoder_weights)
|
||||
save_config(config, output_path / "vocoder")
|
||||
v_params = sum(v.size for v in vocoder_weights.values())
|
||||
print(f" {len(vocoder_weights)} keys, {v_params:,} params")
|
||||
|
||||
@@ -626,15 +702,20 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
dest.symlink_to(real_path)
|
||||
print(f" {subdir}/: symlinked to {real_path}")
|
||||
elif is_hf_repo:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import list_repo_files, snapshot_download
|
||||
|
||||
print(f" {subdir}/: downloading from {source}...")
|
||||
snapshot_download(
|
||||
repo_id=source,
|
||||
allow_patterns=f"{subdir}/*",
|
||||
local_dir=str(output_path),
|
||||
)
|
||||
print(f" {subdir}/: done")
|
||||
# Only download if the subdir exists in the repo
|
||||
repo_files = list_repo_files(source)
|
||||
if any(f.startswith(f"{subdir}/") for f in repo_files):
|
||||
print(f" {subdir}/: downloading from {source}...")
|
||||
snapshot_download(
|
||||
repo_id=source,
|
||||
allow_patterns=f"{subdir}/*",
|
||||
local_dir=str(output_path),
|
||||
)
|
||||
print(f" {subdir}/: done")
|
||||
else:
|
||||
print(f" {subdir}/: not in repo, skipping")
|
||||
else:
|
||||
print(f" {subdir}/: not found in source, skipping")
|
||||
|
||||
@@ -660,9 +741,11 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(AUDIO_DECODER_PREFIX) or key.startswith(AUDIO_STATS_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(AUDIO_ENCODER_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(VOCODER_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key == TEXT_AGG_KEY:
|
||||
elif key.startswith(TEXT_PROJ_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(VIDEO_CONNECTOR_PREFIX) or key.startswith(AUDIO_CONNECTOR_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
@@ -676,13 +759,13 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert monolithic LTX-2 safetensors to modular MLX layout"
|
||||
description="Convert monolithic LTX-2/2.3 safetensors to modular MLX layout"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HF repo ID (e.g. Lightricks/LTX-2), local directory, or direct safetensors file path",
|
||||
help="HF repo ID (e.g. Lightricks/LTX-2, Lightricks/LTX-2.3), local directory, or direct safetensors file path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
|
||||
Reference in New Issue
Block a user