def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Select the language-model weights and normalise float32 entries.

    Only entries whose key starts with ``"language_model."`` are kept, and
    that prefix is stripped from the keys of the returned mapping. A value
    whose ``dtype`` equals ``mx.float32`` is cast to ``mx.bfloat16``; every
    other value (including objects that expose no ``dtype`` attribute) is
    passed through unchanged.

    Args:
        weights: Mapping of checkpoint parameter names to their values.

    Returns:
        A new dict keyed by the prefix-stripped names. ``weights`` itself
        is only read, never modified.
    """
    prefix = "language_model."
    sanitized = {}
    for key, value in weights.items():
        # Guard clause: entries outside the language-model namespace are dropped.
        if not key.startswith(prefix):
            continue
        # Only float32 values are converted; other dtypes are kept as stored.
        if hasattr(value, "dtype") and value.dtype == mx.float32:
            value = value.astype(mx.bfloat16)
        # Single assignment point, so the stripped key is computed once per entry.
        sanitized[key[len(prefix):]] = value
    return sanitized