From fc6ef20c1bf3c260baa15bea244dbcb7b3cfbe17 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 13 Jan 2026 22:56:51 +0100 Subject: [PATCH] Add custom text encoder with quantization Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com> --- mlx_video/generate.py | 15 ++- mlx_video/models/ltx/text_encoder.py | 132 +++++++++++---------------- mlx_video/utils.py | 25 ++++- 3 files changed, 87 insertions(+), 85 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 9167d0c..4bf889f 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -150,6 +150,7 @@ def denoise( def generate_video( model_repo: str, + text_encoder_repo: str, prompt: str, height: int = 512, width: int = 512, @@ -189,6 +190,7 @@ def generate_video( # Get model path model_path = get_model_path(model_repo) + text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) # Calculate latent dimensions stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 @@ -200,8 +202,8 @@ def generate_video( # Load text encoder print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder(model_path=str(model_path)) - text_encoder.load(str(model_path)) + text_encoder = LTX2TextEncoder() + text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) mx.eval(text_encoder.parameters()) text_embeddings, _ = text_encoder(prompt) @@ -317,7 +319,7 @@ def generate_video( elapsed = time.time() - start_time print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") - print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") + print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") return video_np @@ -387,6 +389,12 @@ Examples: default="Lightricks/LTX-2", help="Model repository to use (default: Lightricks/LTX-2)" ) + parser.add_argument( + "--text-encoder-repo", + type=str, + default=None, + help="Text encoder repository to use (default: None)" + ) parser.add_argument( "--verbose", action="store_true", @@ -396,6 +404,7 @@ Examples: generate_video( model_repo=args.model_repo, + text_encoder_repo=args.text_encoder_repo, prompt=args.prompt, height=args.height, width=args.width, diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 0f66eea..6934965 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -12,7 +12,7 @@ from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from mlx_video.utils import rms_norm +from mlx_video.utils import rms_norm, apply_quantization from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb from mlx_vlm.models.gemma3.language import Gemma3Model @@ -22,26 +22,10 @@ from mlx_vlm.models.gemma3.config import TextConfig class LanguageModel(nn.Module): - def __init__(self): + def __init__(self, config: TextConfig): super().__init__() # Create config matching LTX-2 text encoder requirements - self.config = TextConfig( - model_type="gemma3_text", - hidden_size=3840, - num_hidden_layers=48, - intermediate_size=15360, - num_attention_heads=16, - num_key_value_heads=8, - head_dim=256, - rms_norm_eps=1e-6, - vocab_size=262208, - query_pre_attn_scalar=256, - rope_global_base_freq=1000000.0, - rope_local_base_freq=10000.0, - rope_traditional=False, - sliding_window=1024, - sliding_window_pattern=6, - ) + self.config = config # Create the Gemma3Model from mlx-vlm self.model = Gemma3Model(self.config) @@ -136,9 +120,44 @@ class LanguageModel(nn.Module): return hidden_states, all_hidden_states - def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True): - """Load weights into the model.""" - self.model.load_weights(weights, strict=strict) + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + prefix = "language_model." + sanitized = { + key[len(prefix):]: value + for key, value in weights.items() + if key.startswith(prefix) + } + return sanitized + + @classmethod + def from_pretrained(cls, model_path: str): + import json + weight_files = sorted(Path(model_path).glob("*.safetensors")) + config_file = Path(model_path) / "config.json" + config_dict = {} + if config_file.exists(): + with open(config_file, "r") as f: + config_dict = json.load(f) + + language_model = cls(config=TextConfig.from_dict(config_dict["text_config"])) + else: + raise ValueError(f"Config file not found at {model_path}") + + quantization = config_dict.get("quantization", None) + weights = {} + for i, wf in enumerate(weight_files): + weights.update(mx.load(str(wf))) + + + if hasattr(language_model, "sanitize"): + weights = language_model.sanitize(weights=weights) + + + apply_quantization(model=language_model, weights=weights, quantization=quantization) + + language_model.load_weights(list(weights.items()), strict=False) + + return language_model @@ -476,42 +495,19 @@ class GemmaFeaturesExtractor(nn.Module): -def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - sanitized = {} - - for key, value in weights.items(): - new_key = None - - if key.startswith("base_text_encoder.language_model."): - new_key = key.replace("base_text_encoder.language_model.", "") - elif key.startswith("language_model.model."): - new_key = key.replace("language_model.model.", "") - elif key.startswith("language_model."): - new_key = key.replace("language_model.", "") - else: - continue - - if new_key is None: - continue - - sanitized[new_key] = value - - return sanitized class LTX2TextEncoder(nn.Module): def __init__( self, - model_path: str = "Lightricks/LTX-2", hidden_dim: int = 3840, num_layers: int = 49, # 48 transformer layers + 1 embedding ): super().__init__() - self._model_path = model_path self.hidden_dim = hidden_dim self.num_layers = num_layers - self.model = LanguageModel() + self.language_model = None # Feature extractor: 3840*49 -> 3840 self.feature_extractor = GemmaFeaturesExtractor( @@ -530,37 +526,17 @@ class LTX2TextEncoder(nn.Module): self.processor = None - def load(self, model_path: Optional[str] = None): - path = model_path or self._model_path + def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"): - # Load Gemma weights from text_encoder subdirectory - if Path(path).is_dir(): - text_encoder_path = Path(path) / "text_encoder" - if text_encoder_path.exists(): - gemma_path = str(text_encoder_path) - else: - gemma_path = path - else: - gemma_path = path - - print(f"Loading Gemma 3 text encoder from {gemma_path}...") - weight_files = sorted(Path(gemma_path).glob("*.safetensors")) - all_weights = {} - for i, wf in enumerate(weight_files): - print(f" Loading weight file {i+1}/{len(weight_files)}...") - weights = mx.load(str(wf)) - all_weights.update(weights) - - # Sanitize and load Gemma weights - sanitized = sanitize_gemma3_weights(all_weights) - print(f" Sanitized Gemma weights: {len(sanitized)}") - self.model.load_weights(list(sanitized.items()), strict=False) + if Path(text_encoder_path / "text_encoder").is_dir(): + text_encoder_path = str(text_encoder_path / "text_encoder") + + self.language_model = LanguageModel.from_pretrained(text_encoder_path) # Load transformer weights for feature extractor and connector - transformer_path = Path(model_path or self._model_path) - transformer_files = list(transformer_path.glob("ltx-2*.safetensors")) + + transformer_files = list(model_path.glob("ltx-2-19*.safetensors")) if transformer_files: - print(f"Loading transformer weights for text pipeline...") transformer_weights = mx.load(str(transformer_files[0])) # Load feature extractor (aggregate_embed) @@ -568,7 +544,7 @@ class LTX2TextEncoder(nn.Module): self.feature_extractor.aggregate_embed.weight = transformer_weights[ "text_embedding_projection.aggregate_embed.weight" ] - print(" Loaded aggregate_embed weights") + # Load video_embeddings_connector weights connector_weights = {} @@ -589,20 +565,18 @@ class LTX2TextEncoder(nn.Module): self.video_embeddings_connector.load_weights( list(mapped_weights.items()), strict=False ) - print(f" Loaded {len(connector_weights)} connector weights") # Manually load learnable_registers (it's a plain mx.array, not a parameter) if "learnable_registers" in connector_weights: self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"] - print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}") # Load tokenizer from transformers import AutoTokenizer - tokenizer_path = Path(model_path or self._model_path) / "tokenizer" + tokenizer_path = model_path / "tokenizer" if tokenizer_path.exists(): self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True) else: - self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True) + self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True) # Set left padding to match official LTX-2 text encoder self.processor.padding_side = "left" @@ -627,7 +601,7 @@ class LTX2TextEncoder(nn.Module): input_ids = mx.array(inputs["input_ids"]) attention_mask = mx.array(inputs["attention_mask"]) - _, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True) + _, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True) concat_hidden = norm_and_concat_hidden_states( all_hidden_states, attention_mask, padding_side="left" diff --git a/mlx_video/utils.py b/mlx_video/utils.py index 1ef162f..c6840bd 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -1,7 +1,4 @@ -"""Utility functions for MLX Video.""" - import math -from typing import Optional import mlx.core as mx import mlx.nn as nn @@ -22,6 +19,28 @@ def get_model_path(model_repo: str): allow_patterns=["*.safetensors", "*.json"], )) +def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict): + if quantization is not None: + def get_class_predicate(p, m): + # Handle custom per layer quantizations + if p in quantization: + return quantization[p] + if not hasattr(m, "to_quantized"): + return False + # Skip layers not divisible by 64 + if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0: + return False + # Handle legacy models which may not have everything quantized + return f"{p}.scales" in weights + + nn.quantize( + model, + group_size=quantization["group_size"], + bits=quantization["bits"], + mode=quantization.get("mode", "affine"), + class_predicate=get_class_predicate, + ) + @partial(mx.compile, shapeless=True) def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: