Cast LM weights to bfloat16
This commit is contained in:
@@ -122,13 +122,15 @@ class LanguageModel(nn.Module):
|
|||||||
|
|
||||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
prefix = "language_model."
|
prefix = "language_model."
|
||||||
sanitized = {
|
sanitized = {}
|
||||||
key[len(prefix):]: value
|
for key, value in weights.items():
|
||||||
for key, value in weights.items()
|
if key.startswith(prefix):
|
||||||
if key.startswith(prefix)
|
# If the weight is float32, cast to bfloat16
|
||||||
}
|
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
||||||
|
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
|
||||||
|
else:
|
||||||
|
sanitized[key[len(prefix):]] = value
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: str):
|
def from_pretrained(cls, model_path: str):
|
||||||
import json
|
import json
|
||||||
@@ -532,6 +534,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
text_encoder_path = str(text_encoder_path / "text_encoder")
|
text_encoder_path = str(text_encoder_path / "text_encoder")
|
||||||
|
|
||||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||||
|
print(f"Language model data type: {self.language_model.model.embed_tokens.weight.dtype}")
|
||||||
|
|
||||||
# Load transformer weights for feature extractor and connector
|
# Load transformer weights for feature extractor and connector
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user