Refactor model loading in generate.py to use dynamic model paths for audio and video components. Simplify weight loading logic in LTX2TextEncoder to accommodate both monolithic and reformatted model structures. Introduce a check for existing model paths in get_model_path function to enhance robustness.
This commit is contained in:
@@ -497,14 +497,17 @@ class LTXModel(nn.Module):
|
||||
|
||||
def sanitize(self, weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
if "model.diffusion_model." not in weights:
|
||||
|
||||
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
|
||||
if not has_raw_prefix:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
||||
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
|
||||
if not key.startswith("model.diffusion_model."):
|
||||
continue
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
@@ -520,7 +523,6 @@ class LTXModel(nn.Module):
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
Reference in New Issue
Block a user