Add custom text encoder with quantization

Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
Prince Canuma
2026-01-13 22:56:51 +01:00
parent 01d895bc77
commit fc6ef20c1b
3 changed files with 87 additions and 85 deletions

View File

@@ -12,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model
@@ -22,26 +22,10 @@ from mlx_vlm.models.gemma3.config import TextConfig
class LanguageModel(nn.Module):
def __init__(self):
def __init__(self, config: TextConfig):
super().__init__()
# Create config matching LTX-2 text encoder requirements
self.config = TextConfig(
model_type="gemma3_text",
hidden_size=3840,
num_hidden_layers=48,
intermediate_size=15360,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=256,
rms_norm_eps=1e-6,
vocab_size=262208,
query_pre_attn_scalar=256,
rope_global_base_freq=1000000.0,
rope_local_base_freq=10000.0,
rope_traditional=False,
sliding_window=1024,
sliding_window_pattern=6,
)
self.config = config
# Create the Gemma3Model from mlx-vlm
self.model = Gemma3Model(self.config)
@@ -136,9 +120,44 @@ class LanguageModel(nn.Module):
return hidden_states, all_hidden_states
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
"""Load weights into the model."""
self.model.load_weights(weights, strict=strict)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
prefix = "language_model."
sanitized = {
key[len(prefix):]: value
for key, value in weights.items()
if key.startswith(prefix)
}
return sanitized
@classmethod
def from_pretrained(cls, model_path: str):
import json
weight_files = sorted(Path(model_path).glob("*.safetensors"))
config_file = Path(model_path) / "config.json"
config_dict = {}
if config_file.exists():
with open(config_file, "r") as f:
config_dict = json.load(f)
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
else:
raise ValueError(f"Config file not found at {model_path}")
quantization = config_dict.get("quantization", None)
weights = {}
for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights)
apply_quantization(model=language_model, weights=weights, quantization=quantization)
language_model.load_weights(list(weights.items()), strict=False)
return language_model
@@ -476,42 +495,19 @@ class GemmaFeaturesExtractor(nn.Module):
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith("base_text_encoder.language_model."):
new_key = key.replace("base_text_encoder.language_model.", "")
elif key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "")
elif key.startswith("language_model."):
new_key = key.replace("language_model.", "")
else:
continue
if new_key is None:
continue
sanitized[new_key] = value
return sanitized
class LTX2TextEncoder(nn.Module):
def __init__(
self,
model_path: str = "Lightricks/LTX-2",
hidden_dim: int = 3840,
num_layers: int = 49, # 48 transformer layers + 1 embedding
):
super().__init__()
self._model_path = model_path
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.model = LanguageModel()
self.language_model = None
# Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor(
@@ -530,37 +526,17 @@ class LTX2TextEncoder(nn.Module):
self.processor = None
def load(self, model_path: Optional[str] = None):
path = model_path or self._model_path
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
# Load Gemma weights from text_encoder subdirectory
if Path(path).is_dir():
text_encoder_path = Path(path) / "text_encoder"
if text_encoder_path.exists():
gemma_path = str(text_encoder_path)
else:
gemma_path = path
else:
gemma_path = path
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
all_weights = {}
for i, wf in enumerate(weight_files):
print(f" Loading weight file {i+1}/{len(weight_files)}...")
weights = mx.load(str(wf))
all_weights.update(weights)
# Sanitize and load Gemma weights
sanitized = sanitize_gemma3_weights(all_weights)
print(f" Sanitized Gemma weights: {len(sanitized)}")
self.model.load_weights(list(sanitized.items()), strict=False)
if Path(text_encoder_path / "text_encoder").is_dir():
text_encoder_path = str(text_encoder_path / "text_encoder")
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
# Load transformer weights for feature extractor and connector
transformer_path = Path(model_path or self._model_path)
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
if transformer_files:
print(f"Loading transformer weights for text pipeline...")
transformer_weights = mx.load(str(transformer_files[0]))
# Load feature extractor (aggregate_embed)
@@ -568,7 +544,7 @@ class LTX2TextEncoder(nn.Module):
self.feature_extractor.aggregate_embed.weight = transformer_weights[
"text_embedding_projection.aggregate_embed.weight"
]
print(" Loaded aggregate_embed weights")
# Load video_embeddings_connector weights
connector_weights = {}
@@ -589,20 +565,18 @@ class LTX2TextEncoder(nn.Module):
self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False
)
print(f" Loaded {len(connector_weights)} connector weights")
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
# Load tokenizer
from transformers import AutoTokenizer
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
tokenizer_path = model_path / "tokenizer"
if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
else:
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
# Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left"
@@ -627,7 +601,7 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left"