Refactor model loading in generate.py to use dynamic model paths for the audio and video components. Simplify the weight-loading logic in LTX2TextEncoder so that it accommodates both monolithic and reformatted model structures. Introduce a check for existing model paths in the get_model_path function to enhance robustness.

This commit is contained in:
Prince Canuma
2026-03-09 15:51:21 +01:00
parent d1dd30cbac
commit 9f37dab076
5 changed files with 85 additions and 50 deletions

View File

@@ -497,14 +497,17 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict:
sanitized = {}
if "model.diffusion_model." not in weights:
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
if not has_raw_prefix:
return weights
for key, value in weights.items():
new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
if not key.startswith("model.diffusion_model."):
continue
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue
# Remove 'model.diffusion_model.' prefix
@@ -520,7 +523,6 @@ class LTXModel(nn.Module):
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
sanitized[new_key] = value
return sanitized

View File

@@ -646,36 +646,63 @@ class LTX2TextEncoder(nn.Module):
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
# Load transformer weights for feature extractor and connector
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
if transformer_files:
transformer_weights = mx.load(str(transformer_files[0]))
# Load transformer weights for feature extractor and connector.
# These weights are stored differently depending on the repo format:
# 1. Monolithic (Lightricks/LTX-2): single ltx-2-19b-*.safetensors at root
# with raw PyTorch key names (model.diffusion_model.* prefix)
# 2. Reformatted (prince-canuma/LTX-2-distilled): separate text_projections/
# directory with pre-sanitized keys (no prefix, already renamed)
transformer_weights = {}
is_reformatted = False
# Try reformatted layout first: text_projections/ subdirectory
text_proj_dir = model_path / "text_projections"
if text_proj_dir.is_dir():
is_reformatted = True
for sf in text_proj_dir.glob("*.safetensors"):
transformer_weights.update(mx.load(str(sf)))
# Fall back to monolithic layout: ltx-2-19*.safetensors at root
if not transformer_weights:
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
if transformer_files:
transformer_weights = mx.load(str(transformer_files[0]))
if transformer_weights:
# Load feature extractor (aggregate_embed)
if "text_embedding_projection.aggregate_embed.weight" in transformer_weights:
self.feature_extractor.aggregate_embed.weight = transformer_weights[
"text_embedding_projection.aggregate_embed.weight"
]
# Reformatted key: "aggregate_embed.weight"
# Monolithic key: "text_embedding_projection.aggregate_embed.weight"
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
if agg_key in transformer_weights:
self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key]
# Load video_embeddings_connector weights
connector_weights = {}
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
connector_weights[new_key] = value
if is_reformatted:
# Reformatted: keys are already sanitized with "video_embeddings_connector." prefix
for key, value in transformer_weights.items():
if key.startswith("video_embeddings_connector."):
new_key = key.replace("video_embeddings_connector.", "")
connector_weights[new_key] = value
else:
# Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
connector_weights[new_key] = value
if connector_weights:
# Map weight names to our structure
# Map weight names to our structure (only needed for monolithic/raw PyTorch keys)
mapped_weights = {}
for key, value in connector_weights.items():
new_key = key
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
# Map ff.net.2 -> ff.proj_out (output Linear)
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
# Map to_out.0 -> to_out (Sequential -> direct)
new_key = new_key.replace(".to_out.0.", ".to_out.")
if not is_reformatted:
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
# Map ff.net.2 -> ff.proj_out (output Linear)
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
# Map to_out.0 -> to_out (Sequential -> direct)
new_key = new_key.replace(".to_out.0.", ".to_out.")
mapped_weights[new_key] = value
self.video_embeddings_connector.load_weights(
@@ -688,22 +715,26 @@ class LTX2TextEncoder(nn.Module):
# Load audio_embeddings_connector weights (same structure as video connector)
audio_connector_weights = {}
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
audio_connector_weights[new_key] = value
if is_reformatted:
for key, value in transformer_weights.items():
if key.startswith("audio_embeddings_connector."):
new_key = key.replace("audio_embeddings_connector.", "")
audio_connector_weights[new_key] = value
else:
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
audio_connector_weights[new_key] = value
if audio_connector_weights:
# Map weight names to our structure (same as video connector)
mapped_audio_weights = {}
for key, value in audio_connector_weights.items():
new_key = key
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
# Map ff.net.2 -> ff.proj_out (output Linear)
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
# Map to_out.0 -> to_out (Sequential -> direct)
new_key = new_key.replace(".to_out.0.", ".to_out.")
if not is_reformatted:
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".to_out.0.", ".to_out.")
mapped_audio_weights[new_key] = value
self.audio_embeddings_connector.load_weights(
@@ -713,6 +744,9 @@ class LTX2TextEncoder(nn.Module):
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in audio_connector_weights:
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
else:
print("WARNING: No transformer weights found for text projection connectors. "
"Text conditioning will use uninitialized weights!")
# Load tokenizer
from transformers import AutoTokenizer