Refactor model loading in generate.py to use dynamic model paths for audio and video components. Simplify weight loading logic in LTX2TextEncoder to accommodate both monolithic and reformatted model structures. Introduce a check for existing model paths in get_model_path function to enhance robustness.

This commit is contained in:
Prince Canuma
2026-03-09 15:51:21 +01:00
parent d1dd30cbac
commit 9f37dab076
5 changed files with 85 additions and 50 deletions

View File

@@ -497,14 +497,17 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict:
sanitized = {}
if "model.diffusion_model." not in weights:
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
if not has_raw_prefix:
return weights
for key, value in weights.items():
new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
if not key.startswith("model.diffusion_model."):
continue
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue
# Remove 'model.diffusion_model.' prefix
@@ -520,7 +523,6 @@ class LTXModel(nn.Module):
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
sanitized[new_key] = value
return sanitized