Remove commented-out code and clean up text encoder initialization
This commit is contained in:
@@ -125,7 +125,6 @@ class LanguageModel(nn.Module):
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
# If the weight is float32, cast to bfloat16
|
||||
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
||||
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
|
||||
else:
|
||||
@@ -534,10 +533,8 @@ class LTX2TextEncoder(nn.Module):
|
||||
text_encoder_path = str(text_encoder_path / "text_encoder")
|
||||
|
||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||
print(f"Language model data type: {self.language_model.model.embed_tokens.weight.dtype}")
|
||||
|
||||
# Load transformer weights for feature extractor and connector
|
||||
|
||||
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
||||
if transformer_files:
|
||||
transformer_weights = mx.load(str(transformer_files[0]))
|
||||
|
||||
Reference in New Issue
Block a user