Refactor model loading in generate.py to use dynamic model paths for the audio and video components. Simplify the weight-loading logic in LTX2TextEncoder to accommodate both monolithic and reformatted model structures. Introduce a check for existing model paths in the get_model_path function to enhance robustness.
This commit is contained in:
@@ -646,36 +646,63 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||
|
||||
# Load transformer weights for feature extractor and connector
|
||||
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
||||
if transformer_files:
|
||||
transformer_weights = mx.load(str(transformer_files[0]))
|
||||
# Load transformer weights for feature extractor and connector.
|
||||
# These weights are stored differently depending on the repo format:
|
||||
# 1. Monolithic (Lightricks/LTX-2): single ltx-2-19b-*.safetensors at root
|
||||
# with raw PyTorch key names (model.diffusion_model.* prefix)
|
||||
# 2. Reformatted (prince-canuma/LTX-2-distilled): separate text_projections/
|
||||
# directory with pre-sanitized keys (no prefix, already renamed)
|
||||
transformer_weights = {}
|
||||
is_reformatted = False
|
||||
|
||||
# Try reformatted layout first: text_projections/ subdirectory
|
||||
text_proj_dir = model_path / "text_projections"
|
||||
if text_proj_dir.is_dir():
|
||||
is_reformatted = True
|
||||
for sf in text_proj_dir.glob("*.safetensors"):
|
||||
transformer_weights.update(mx.load(str(sf)))
|
||||
|
||||
# Fall back to monolithic layout: ltx-2-19*.safetensors at root
|
||||
if not transformer_weights:
|
||||
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
||||
if transformer_files:
|
||||
transformer_weights = mx.load(str(transformer_files[0]))
|
||||
|
||||
if transformer_weights:
|
||||
# Load feature extractor (aggregate_embed)
|
||||
if "text_embedding_projection.aggregate_embed.weight" in transformer_weights:
|
||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
||||
"text_embedding_projection.aggregate_embed.weight"
|
||||
]
|
||||
|
||||
# Reformatted key: "aggregate_embed.weight"
|
||||
# Monolithic key: "text_embedding_projection.aggregate_embed.weight"
|
||||
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
||||
if agg_key in transformer_weights:
|
||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[agg_key]
|
||||
|
||||
# Load video_embeddings_connector weights
|
||||
connector_weights = {}
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
|
||||
connector_weights[new_key] = value
|
||||
if is_reformatted:
|
||||
# Reformatted: keys are already sanitized with "video_embeddings_connector." prefix
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("video_embeddings_connector."):
|
||||
new_key = key.replace("video_embeddings_connector.", "")
|
||||
connector_weights[new_key] = value
|
||||
else:
|
||||
# Monolithic: keys have "model.diffusion_model.video_embeddings_connector." prefix
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
|
||||
connector_weights[new_key] = value
|
||||
|
||||
if connector_weights:
|
||||
# Map weight names to our structure
|
||||
# Map weight names to our structure (only needed for monolithic/raw PyTorch keys)
|
||||
mapped_weights = {}
|
||||
for key, value in connector_weights.items():
|
||||
new_key = key
|
||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
if not is_reformatted:
|
||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
mapped_weights[new_key] = value
|
||||
|
||||
self.video_embeddings_connector.load_weights(
|
||||
@@ -688,22 +715,26 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
# Load audio_embeddings_connector weights (same structure as video connector)
|
||||
audio_connector_weights = {}
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
|
||||
audio_connector_weights[new_key] = value
|
||||
if is_reformatted:
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("audio_embeddings_connector."):
|
||||
new_key = key.replace("audio_embeddings_connector.", "")
|
||||
audio_connector_weights[new_key] = value
|
||||
else:
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
|
||||
audio_connector_weights[new_key] = value
|
||||
|
||||
if audio_connector_weights:
|
||||
# Map weight names to our structure (same as video connector)
|
||||
mapped_audio_weights = {}
|
||||
for key, value in audio_connector_weights.items():
|
||||
new_key = key
|
||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
if not is_reformatted:
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
mapped_audio_weights[new_key] = value
|
||||
|
||||
self.audio_embeddings_connector.load_weights(
|
||||
@@ -713,6 +744,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||
if "learnable_registers" in audio_connector_weights:
|
||||
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
|
||||
else:
|
||||
print("WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!")
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
Reference in New Issue
Block a user