diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index acea787..168b7b9 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -125,7 +125,6 @@ class LanguageModel(nn.Module): sanitized = {} for key, value in weights.items(): if key.startswith(prefix): - # If the weight is float32, cast to bfloat16 if hasattr(value, "dtype") and value.dtype == mx.float32: sanitized[key[len(prefix):]] = value.astype(mx.bfloat16) else: @@ -534,10 +533,8 @@ class LTX2TextEncoder(nn.Module): text_encoder_path = str(text_encoder_path / "text_encoder") self.language_model = LanguageModel.from_pretrained(text_encoder_path) - print(f"Language model data type: {self.language_model.model.embed_tokens.weight.dtype}") # Load transformer weights for feature extractor and connector - transformer_files = list(model_path.glob("ltx-2-19*.safetensors")) if transformer_files: transformer_weights = mx.load(str(transformer_files[0]))