Cast LM weights to bfloat16

This commit is contained in:
Prince Canuma
2026-01-13 23:30:26 +01:00
parent fc6ef20c1b
commit ea063f7550

View File

@@ -122,13 +122,15 @@ class LanguageModel(nn.Module):
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
prefix = "language_model." prefix = "language_model."
sanitized = { sanitized = {}
key[len(prefix):]: value for key, value in weights.items():
for key, value in weights.items() if key.startswith(prefix):
if key.startswith(prefix) # If the weight is float32, cast to bfloat16
} if hasattr(value, "dtype") and value.dtype == mx.float32:
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
else:
sanitized[key[len(prefix):]] = value
return sanitized return sanitized
@classmethod @classmethod
def from_pretrained(cls, model_path: str): def from_pretrained(cls, model_path: str):
import json import json
@@ -532,6 +534,7 @@ class LTX2TextEncoder(nn.Module):
text_encoder_path = str(text_encoder_path / "text_encoder") text_encoder_path = str(text_encoder_path / "text_encoder")
self.language_model = LanguageModel.from_pretrained(text_encoder_path) self.language_model = LanguageModel.from_pretrained(text_encoder_path)
print(f"Language model data type: {self.language_model.model.embed_tokens.weight.dtype}")
# Load transformer weights for feature extractor and connector # Load transformer weights for feature extractor and connector