format
This commit is contained in:
@@ -49,7 +49,6 @@ from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ─── Key prefix routing ──────────────────────────────────────────────────────
|
||||
|
||||
TRANSFORMER_PREFIX = "model.diffusion_model."
|
||||
@@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
new_key = key[len(TRANSFORMER_PREFIX):]
|
||||
new_key = key[len(TRANSFORMER_PREFIX) :]
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
@@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
else:
|
||||
continue
|
||||
elif key.startswith(VAE_DECODER_PREFIX):
|
||||
new_key = key[len(VAE_DECODER_PREFIX):]
|
||||
new_key = key[len(VAE_DECODER_PREFIX) :]
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if value.dtype != mx.float32:
|
||||
value = value.astype(mx.float32)
|
||||
elif key.startswith(VAE_ENCODER_PREFIX):
|
||||
new_key = key[len(VAE_ENCODER_PREFIX):]
|
||||
new_key = key[len(VAE_ENCODER_PREFIX) :]
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_DECODER_PREFIX):
|
||||
new_key = key[len(AUDIO_DECODER_PREFIX):]
|
||||
new_key = key[len(AUDIO_DECODER_PREFIX) :]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
@@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_ENCODER_PREFIX):
|
||||
new_key = key[len(AUDIO_ENCODER_PREFIX):]
|
||||
new_key = key[len(AUDIO_ENCODER_PREFIX) :]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
@@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if not key.startswith(VOCODER_PREFIX):
|
||||
continue
|
||||
|
||||
new_key = key[len(VOCODER_PREFIX):]
|
||||
new_key = key[len(VOCODER_PREFIX) :]
|
||||
|
||||
# Handle Conv1d/ConvTranspose1d weight shape conversion
|
||||
if "weight" in new_key and value.ndim == 3:
|
||||
@@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
||||
# aggregate_embed weights (text_embedding_projection.*)
|
||||
for key, value in weights.items():
|
||||
if key.startswith(TEXT_PROJ_PREFIX):
|
||||
new_key = key[len(TEXT_PROJ_PREFIX):]
|
||||
new_key = key[len(TEXT_PROJ_PREFIX) :]
|
||||
extracted[new_key] = value
|
||||
|
||||
# video_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(VIDEO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX):]
|
||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
|
||||
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
# audio_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(AUDIO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX):]
|
||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
|
||||
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
@@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path):
|
||||
# ─── Source resolution ─────────────────────────────────────────────────────────
|
||||
|
||||
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
|
||||
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
|
||||
MONOLITHIC_PATTERN = re.compile(
|
||||
r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
|
||||
)
|
||||
|
||||
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
|
||||
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
|
||||
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$")
|
||||
UPSCALER_PATTERN = re.compile(
|
||||
r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
|
||||
)
|
||||
|
||||
|
||||
def resolve_source(source: str, variant: str) -> Path:
|
||||
@@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
|
||||
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
|
||||
"""Infer VAE decoder config from weights."""
|
||||
# Check for timestep conditioning keys
|
||||
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
|
||||
has_timestep = any(
|
||||
"last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
|
||||
)
|
||||
|
||||
# Count channel multipliers from up_blocks
|
||||
max_block = -1
|
||||
@@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
config = infer_transformer_config(transformer_weights)
|
||||
save_config(config, output_path / "transformer")
|
||||
t_params = sum(v.size for v in transformer_weights.values())
|
||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
||||
print(
|
||||
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
|
||||
)
|
||||
|
||||
# 2. VAE Decoder
|
||||
print(" [2/7] VAE Decoder...")
|
||||
@@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
]
|
||||
else:
|
||||
upscaler_files = [
|
||||
f.name for f in source_dir.iterdir()
|
||||
f.name
|
||||
for f in source_dir.iterdir()
|
||||
if f.is_file() and UPSCALER_PATTERN.match(f.name)
|
||||
]
|
||||
|
||||
@@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
||||
if all_converted < total_keys:
|
||||
known_prefixes = (
|
||||
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX,
|
||||
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX,
|
||||
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX,
|
||||
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX,
|
||||
TRANSFORMER_PREFIX,
|
||||
VAE_DECODER_PREFIX,
|
||||
VAE_ENCODER_PREFIX,
|
||||
VAE_STATS_PREFIX,
|
||||
AUDIO_DECODER_PREFIX,
|
||||
AUDIO_ENCODER_PREFIX,
|
||||
AUDIO_STATS_PREFIX,
|
||||
VOCODER_PREFIX,
|
||||
TEXT_PROJ_PREFIX,
|
||||
VIDEO_CONNECTOR_PREFIX,
|
||||
AUDIO_CONNECTOR_PREFIX,
|
||||
)
|
||||
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)]
|
||||
skipped = [
|
||||
k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
|
||||
]
|
||||
if skipped:
|
||||
print(f" Skipped {len(skipped)} keys:")
|
||||
for k in sorted(skipped)[:20]:
|
||||
|
||||
Reference in New Issue
Block a user