Add custom text encoder with quantization

Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
Prince Canuma
2026-01-13 22:56:51 +01:00
parent 01d895bc77
commit fc6ef20c1b
3 changed files with 87 additions and 85 deletions

View File

@@ -150,6 +150,7 @@ def denoise(
def generate_video(
model_repo: str,
text_encoder_repo: str,
prompt: str,
height: int = 512,
width: int = 512,
@@ -189,6 +190,7 @@ def generate_video(
# Get model path
model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
@@ -200,8 +202,8 @@ def generate_video(
# Load text encoder
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
text_encoder = LTX2TextEncoder()
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
text_embeddings, _ = text_encoder(prompt)
@@ -317,7 +319,7 @@ def generate_video(
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np
@@ -387,6 +389,12 @@ Examples:
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
parser.add_argument(
"--text-encoder-repo",
type=str,
default=None,
help="Text encoder repository to use (default: None)"
)
parser.add_argument(
"--verbose",
action="store_true",
@@ -396,6 +404,7 @@ Examples:
generate_video(
model_repo=args.model_repo,
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt,
height=args.height,
width=args.width,

View File

@@ -12,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model
@@ -22,26 +22,10 @@ from mlx_vlm.models.gemma3.config import TextConfig
class LanguageModel(nn.Module):
def __init__(self):
def __init__(self, config: TextConfig):
super().__init__()
# Create config matching LTX-2 text encoder requirements
self.config = TextConfig(
model_type="gemma3_text",
hidden_size=3840,
num_hidden_layers=48,
intermediate_size=15360,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=256,
rms_norm_eps=1e-6,
vocab_size=262208,
query_pre_attn_scalar=256,
rope_global_base_freq=1000000.0,
rope_local_base_freq=10000.0,
rope_traditional=False,
sliding_window=1024,
sliding_window_pattern=6,
)
self.config = config
# Create the Gemma3Model from mlx-vlm
self.model = Gemma3Model(self.config)
@@ -136,9 +120,44 @@ class LanguageModel(nn.Module):
return hidden_states, all_hidden_states
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
"""Load weights into the model."""
self.model.load_weights(weights, strict=strict)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
prefix = "language_model."
sanitized = {
key[len(prefix):]: value
for key, value in weights.items()
if key.startswith(prefix)
}
return sanitized
@classmethod
def from_pretrained(cls, model_path: str):
import json
weight_files = sorted(Path(model_path).glob("*.safetensors"))
config_file = Path(model_path) / "config.json"
config_dict = {}
if config_file.exists():
with open(config_file, "r") as f:
config_dict = json.load(f)
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
else:
raise ValueError(f"Config file not found at {model_path}")
quantization = config_dict.get("quantization", None)
weights = {}
for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights)
apply_quantization(model=language_model, weights=weights, quantization=quantization)
language_model.load_weights(list(weights.items()), strict=False)
return language_model
@@ -476,42 +495,19 @@ class GemmaFeaturesExtractor(nn.Module):
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith("base_text_encoder.language_model."):
new_key = key.replace("base_text_encoder.language_model.", "")
elif key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "")
elif key.startswith("language_model."):
new_key = key.replace("language_model.", "")
else:
continue
if new_key is None:
continue
sanitized[new_key] = value
return sanitized
class LTX2TextEncoder(nn.Module):
def __init__(
self,
model_path: str = "Lightricks/LTX-2",
hidden_dim: int = 3840,
num_layers: int = 49, # 48 transformer layers + 1 embedding
):
super().__init__()
self._model_path = model_path
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.model = LanguageModel()
self.language_model = None
# Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor(
@@ -530,37 +526,17 @@ class LTX2TextEncoder(nn.Module):
self.processor = None
def load(self, model_path: Optional[str] = None):
path = model_path or self._model_path
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
# Load Gemma weights from text_encoder subdirectory
if Path(path).is_dir():
text_encoder_path = Path(path) / "text_encoder"
if text_encoder_path.exists():
gemma_path = str(text_encoder_path)
else:
gemma_path = path
else:
gemma_path = path
if Path(text_encoder_path / "text_encoder").is_dir():
text_encoder_path = str(text_encoder_path / "text_encoder")
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
all_weights = {}
for i, wf in enumerate(weight_files):
print(f" Loading weight file {i+1}/{len(weight_files)}...")
weights = mx.load(str(wf))
all_weights.update(weights)
# Sanitize and load Gemma weights
sanitized = sanitize_gemma3_weights(all_weights)
print(f" Sanitized Gemma weights: {len(sanitized)}")
self.model.load_weights(list(sanitized.items()), strict=False)
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
# Load transformer weights for feature extractor and connector
transformer_path = Path(model_path or self._model_path)
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
if transformer_files:
print(f"Loading transformer weights for text pipeline...")
transformer_weights = mx.load(str(transformer_files[0]))
# Load feature extractor (aggregate_embed)
@@ -568,7 +544,7 @@ class LTX2TextEncoder(nn.Module):
self.feature_extractor.aggregate_embed.weight = transformer_weights[
"text_embedding_projection.aggregate_embed.weight"
]
print(" Loaded aggregate_embed weights")
# Load video_embeddings_connector weights
connector_weights = {}
@@ -589,20 +565,18 @@ class LTX2TextEncoder(nn.Module):
self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False
)
print(f" Loaded {len(connector_weights)} connector weights")
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
# Load tokenizer
from transformers import AutoTokenizer
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
tokenizer_path = model_path / "tokenizer"
if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
else:
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
# Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left"
@@ -627,7 +601,7 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left"

View File

@@ -1,7 +1,4 @@
"""Utility functions for MLX Video."""
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
allow_patterns=["*.safetensors", "*.json"],
))
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
if quantization is not None:
def get_class_predicate(p, m):
# Handle custom per layer quantizations
if p in quantization:
return quantization[p]
if not hasattr(m, "to_quantized"):
return False
# Skip layers not divisible by 64
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
return False
# Handle legacy models which may not have everything quantized
return f"{p}.scales" in weights
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
mode=quantization.get("mode", "affine"),
class_predicate=get_class_predicate,
)
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: