Add custom text encoder with quantization

Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
Prince Canuma
2026-01-13 22:56:51 +01:00
parent 01d895bc77
commit fc6ef20c1b
3 changed files with 87 additions and 85 deletions

View File

@@ -150,6 +150,7 @@ def denoise(
def generate_video( def generate_video(
model_repo: str, model_repo: str,
text_encoder_repo: str,
prompt: str, prompt: str,
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
@@ -189,6 +190,7 @@ def generate_video(
# Get model path # Get model path
model_path = get_model_path(model_repo) model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Calculate latent dimensions # Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
@@ -200,8 +202,8 @@ def generate_video(
# Load text encoder # Load text encoder
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path)) text_encoder = LTX2TextEncoder()
text_encoder.load(str(model_path)) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters()) mx.eval(text_encoder.parameters())
text_embeddings, _ = text_encoder(prompt) text_embeddings, _ = text_encoder(prompt)
@@ -387,6 +389,12 @@ Examples:
default="Lightricks/LTX-2", default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)" help="Model repository to use (default: Lightricks/LTX-2)"
) )
parser.add_argument(
"--text-encoder-repo",
type=str,
default=None,
help="Text encoder repository to use (default: None)"
)
parser.add_argument( parser.add_argument(
"--verbose", "--verbose",
action="store_true", action="store_true",
@@ -396,6 +404,7 @@ Examples:
generate_video( generate_video(
model_repo=args.model_repo, model_repo=args.model_repo,
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt, prompt=args.prompt,
height=args.height, height=args.height,
width=args.width, width=args.width,

View File

@@ -12,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.utils import rms_norm from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model from mlx_vlm.models.gemma3.language import Gemma3Model
@@ -22,26 +22,10 @@ from mlx_vlm.models.gemma3.config import TextConfig
class LanguageModel(nn.Module): class LanguageModel(nn.Module):
def __init__(self): def __init__(self, config: TextConfig):
super().__init__() super().__init__()
# Create config matching LTX-2 text encoder requirements # Create config matching LTX-2 text encoder requirements
self.config = TextConfig( self.config = config
model_type="gemma3_text",
hidden_size=3840,
num_hidden_layers=48,
intermediate_size=15360,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=256,
rms_norm_eps=1e-6,
vocab_size=262208,
query_pre_attn_scalar=256,
rope_global_base_freq=1000000.0,
rope_local_base_freq=10000.0,
rope_traditional=False,
sliding_window=1024,
sliding_window_pattern=6,
)
# Create the Gemma3Model from mlx-vlm # Create the Gemma3Model from mlx-vlm
self.model = Gemma3Model(self.config) self.model = Gemma3Model(self.config)
@@ -136,9 +120,44 @@ class LanguageModel(nn.Module):
return hidden_states, all_hidden_states return hidden_states, all_hidden_states
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True): def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Load weights into the model.""" prefix = "language_model."
self.model.load_weights(weights, strict=strict) sanitized = {
key[len(prefix):]: value
for key, value in weights.items()
if key.startswith(prefix)
}
return sanitized
@classmethod
def from_pretrained(cls, model_path: str):
import json
weight_files = sorted(Path(model_path).glob("*.safetensors"))
config_file = Path(model_path) / "config.json"
config_dict = {}
if config_file.exists():
with open(config_file, "r") as f:
config_dict = json.load(f)
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
else:
raise ValueError(f"Config file not found at {model_path}")
quantization = config_dict.get("quantization", None)
weights = {}
for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights)
apply_quantization(model=language_model, weights=weights, quantization=quantization)
language_model.load_weights(list(weights.items()), strict=False)
return language_model
@@ -476,42 +495,19 @@ class GemmaFeaturesExtractor(nn.Module):
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith("base_text_encoder.language_model."):
new_key = key.replace("base_text_encoder.language_model.", "")
elif key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "")
elif key.startswith("language_model."):
new_key = key.replace("language_model.", "")
else:
continue
if new_key is None:
continue
sanitized[new_key] = value
return sanitized
class LTX2TextEncoder(nn.Module): class LTX2TextEncoder(nn.Module):
def __init__( def __init__(
self, self,
model_path: str = "Lightricks/LTX-2",
hidden_dim: int = 3840, hidden_dim: int = 3840,
num_layers: int = 49, # 48 transformer layers + 1 embedding num_layers: int = 49, # 48 transformer layers + 1 embedding
): ):
super().__init__() super().__init__()
self._model_path = model_path
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.num_layers = num_layers self.num_layers = num_layers
self.model = LanguageModel() self.language_model = None
# Feature extractor: 3840*49 -> 3840 # Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor( self.feature_extractor = GemmaFeaturesExtractor(
@@ -530,37 +526,17 @@ class LTX2TextEncoder(nn.Module):
self.processor = None self.processor = None
def load(self, model_path: Optional[str] = None): def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
path = model_path or self._model_path
# Load Gemma weights from text_encoder subdirectory if Path(text_encoder_path / "text_encoder").is_dir():
if Path(path).is_dir(): text_encoder_path = str(text_encoder_path / "text_encoder")
text_encoder_path = Path(path) / "text_encoder"
if text_encoder_path.exists():
gemma_path = str(text_encoder_path)
else:
gemma_path = path
else:
gemma_path = path
print(f"Loading Gemma 3 text encoder from {gemma_path}...") self.language_model = LanguageModel.from_pretrained(text_encoder_path)
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
all_weights = {}
for i, wf in enumerate(weight_files):
print(f" Loading weight file {i+1}/{len(weight_files)}...")
weights = mx.load(str(wf))
all_weights.update(weights)
# Sanitize and load Gemma weights
sanitized = sanitize_gemma3_weights(all_weights)
print(f" Sanitized Gemma weights: {len(sanitized)}")
self.model.load_weights(list(sanitized.items()), strict=False)
# Load transformer weights for feature extractor and connector # Load transformer weights for feature extractor and connector
transformer_path = Path(model_path or self._model_path)
transformer_files = list(transformer_path.glob("ltx-2*.safetensors")) transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
if transformer_files: if transformer_files:
print(f"Loading transformer weights for text pipeline...")
transformer_weights = mx.load(str(transformer_files[0])) transformer_weights = mx.load(str(transformer_files[0]))
# Load feature extractor (aggregate_embed) # Load feature extractor (aggregate_embed)
@@ -568,7 +544,7 @@ class LTX2TextEncoder(nn.Module):
self.feature_extractor.aggregate_embed.weight = transformer_weights[ self.feature_extractor.aggregate_embed.weight = transformer_weights[
"text_embedding_projection.aggregate_embed.weight" "text_embedding_projection.aggregate_embed.weight"
] ]
print(" Loaded aggregate_embed weights")
# Load video_embeddings_connector weights # Load video_embeddings_connector weights
connector_weights = {} connector_weights = {}
@@ -589,20 +565,18 @@ class LTX2TextEncoder(nn.Module):
self.video_embeddings_connector.load_weights( self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False list(mapped_weights.items()), strict=False
) )
print(f" Loaded {len(connector_weights)} connector weights")
# Manually load learnable_registers (it's a plain mx.array, not a parameter) # Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in connector_weights: if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"] self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
# Load tokenizer # Load tokenizer
from transformers import AutoTokenizer from transformers import AutoTokenizer
tokenizer_path = Path(model_path or self._model_path) / "tokenizer" tokenizer_path = model_path / "tokenizer"
if tokenizer_path.exists(): if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True) self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
else: else:
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True) self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
# Set left padding to match official LTX-2 text encoder # Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left" self.processor.padding_side = "left"
@@ -627,7 +601,7 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"]) input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"]) attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True) _, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
concat_hidden = norm_and_concat_hidden_states( concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left" all_hidden_states, attention_mask, padding_side="left"

View File

@@ -1,7 +1,4 @@
"""Utility functions for MLX Video."""
import math import math
from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
allow_patterns=["*.safetensors", "*.json"], allow_patterns=["*.safetensors", "*.json"],
)) ))
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
if quantization is not None:
def get_class_predicate(p, m):
# Handle custom per layer quantizations
if p in quantization:
return quantization[p]
if not hasattr(m, "to_quantized"):
return False
# Skip layers not divisible by 64
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
return False
# Handle legacy models which may not have everything quantized
return f"{p}.scales" in weights
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
mode=quantization.get("mode", "affine"),
class_predicate=get_class_predicate,
)
@partial(mx.compile, shapeless=True) @partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: