This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -49,7 +49,6 @@ from typing import Dict
import mlx.core as mx
# ─── Key prefix routing ──────────────────────────────────────────────────────
TRANSFORMER_PREFIX = "model.diffusion_model."
@@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue
new_key = key[len(TRANSFORMER_PREFIX):]
new_key = key[len(TRANSFORMER_PREFIX) :]
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
@@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
else:
continue
elif key.startswith(VAE_DECODER_PREFIX):
new_key = key[len(VAE_DECODER_PREFIX):]
new_key = key[len(VAE_DECODER_PREFIX) :]
else:
continue
@@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if value.dtype != mx.float32:
value = value.astype(mx.float32)
elif key.startswith(VAE_ENCODER_PREFIX):
new_key = key[len(VAE_ENCODER_PREFIX):]
new_key = key[len(VAE_ENCODER_PREFIX) :]
else:
continue
@@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
new_key = None
if key.startswith(AUDIO_DECODER_PREFIX):
new_key = key[len(AUDIO_DECODER_PREFIX):]
new_key = key[len(AUDIO_DECODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
@@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
new_key = None
if key.startswith(AUDIO_ENCODER_PREFIX):
new_key = key[len(AUDIO_ENCODER_PREFIX):]
new_key = key[len(AUDIO_ENCODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
@@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if not key.startswith(VOCODER_PREFIX):
continue
new_key = key[len(VOCODER_PREFIX):]
new_key = key[len(VOCODER_PREFIX) :]
# Handle Conv1d/ConvTranspose1d weight shape conversion
if "weight" in new_key and value.ndim == 3:
@@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array
# aggregate_embed weights (text_embedding_projection.*)
for key, value in weights.items():
if key.startswith(TEXT_PROJ_PREFIX):
new_key = key[len(TEXT_PROJ_PREFIX):]
new_key = key[len(TEXT_PROJ_PREFIX) :]
extracted[new_key] = value
# video_embeddings_connector
for key, value in weights.items():
if key.startswith(VIDEO_CONNECTOR_PREFIX):
suffix = key[len(VIDEO_CONNECTOR_PREFIX):]
suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value
# audio_embeddings_connector
for key, value in weights.items():
if key.startswith(AUDIO_CONNECTOR_PREFIX):
suffix = key[len(AUDIO_CONNECTOR_PREFIX):]
suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value
@@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path):
# ─── Source resolution ─────────────────────────────────────────────────────────
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
MONOLITHIC_PATTERN = re.compile(
r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
)
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$")
UPSCALER_PATTERN = re.compile(
r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
)
def resolve_source(source: str, variant: str) -> Path:
@@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
"""Infer VAE decoder config from weights."""
# Check for timestep conditioning keys
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
has_timestep = any(
"last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
)
# Count channel multipliers from up_blocks
max_block = -1
@@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
config = infer_transformer_config(transformer_weights)
save_config(config, output_path / "transformer")
t_params = sum(v.size for v in transformer_weights.values())
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
print(
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
)
# 2. VAE Decoder
print(" [2/7] VAE Decoder...")
@@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
]
else:
upscaler_files = [
f.name for f in source_dir.iterdir()
f.name
for f in source_dir.iterdir()
if f.is_file() and UPSCALER_PATTERN.match(f.name)
]
@@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
if all_converted < total_keys:
known_prefixes = (
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX,
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX,
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX,
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX,
TRANSFORMER_PREFIX,
VAE_DECODER_PREFIX,
VAE_ENCODER_PREFIX,
VAE_STATS_PREFIX,
AUDIO_DECODER_PREFIX,
AUDIO_ENCODER_PREFIX,
AUDIO_STATS_PREFIX,
VOCODER_PREFIX,
TEXT_PROJ_PREFIX,
VIDEO_CONNECTOR_PREFIX,
AUDIO_CONNECTOR_PREFIX,
)
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)]
skipped = [
k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
]
if skipped:
print(f" Skipped {len(skipped)} keys:")
for k in sorted(skipped)[:20]: