Add custom text encoder with quantization
Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
@@ -150,6 +150,7 @@ def denoise(
|
||||
|
||||
def generate_video(
|
||||
model_repo: str,
|
||||
text_encoder_repo: str,
|
||||
prompt: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
@@ -189,6 +190,7 @@ def generate_video(
|
||||
|
||||
# Get model path
|
||||
model_path = get_model_path(model_repo)
|
||||
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||
|
||||
# Calculate latent dimensions
|
||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||
@@ -200,8 +202,8 @@ def generate_video(
|
||||
# Load text encoder
|
||||
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
||||
text_encoder.load(str(model_path))
|
||||
text_encoder = LTX2TextEncoder()
|
||||
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||
mx.eval(text_encoder.parameters())
|
||||
|
||||
text_embeddings, _ = text_encoder(prompt)
|
||||
@@ -387,6 +389,12 @@ Examples:
|
||||
default="Lightricks/LTX-2",
|
||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text-encoder-repo",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Text encoder repository to use (default: None)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
@@ -396,6 +404,7 @@ Examples:
|
||||
|
||||
generate_video(
|
||||
model_repo=args.model_repo,
|
||||
text_encoder_repo=args.text_encoder_repo,
|
||||
prompt=args.prompt,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.utils import rms_norm
|
||||
from mlx_video.utils import rms_norm, apply_quantization
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
@@ -22,26 +22,10 @@ from mlx_vlm.models.gemma3.config import TextConfig
|
||||
class LanguageModel(nn.Module):
|
||||
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
# Create config matching LTX-2 text encoder requirements
|
||||
self.config = TextConfig(
|
||||
model_type="gemma3_text",
|
||||
hidden_size=3840,
|
||||
num_hidden_layers=48,
|
||||
intermediate_size=15360,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
head_dim=256,
|
||||
rms_norm_eps=1e-6,
|
||||
vocab_size=262208,
|
||||
query_pre_attn_scalar=256,
|
||||
rope_global_base_freq=1000000.0,
|
||||
rope_local_base_freq=10000.0,
|
||||
rope_traditional=False,
|
||||
sliding_window=1024,
|
||||
sliding_window_pattern=6,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# Create the Gemma3Model from mlx-vlm
|
||||
self.model = Gemma3Model(self.config)
|
||||
@@ -136,9 +120,44 @@ class LanguageModel(nn.Module):
|
||||
|
||||
return hidden_states, all_hidden_states
|
||||
|
||||
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
|
||||
"""Load weights into the model."""
|
||||
self.model.load_weights(weights, strict=strict)
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
prefix = "language_model."
|
||||
sanitized = {
|
||||
key[len(prefix):]: value
|
||||
for key, value in weights.items()
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str):
|
||||
import json
|
||||
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
||||
config_file = Path(model_path) / "config.json"
|
||||
config_dict = {}
|
||||
if config_file.exists():
|
||||
with open(config_file, "r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
|
||||
else:
|
||||
raise ValueError(f"Config file not found at {model_path}")
|
||||
|
||||
quantization = config_dict.get("quantization", None)
|
||||
weights = {}
|
||||
for i, wf in enumerate(weight_files):
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
if hasattr(language_model, "sanitize"):
|
||||
weights = language_model.sanitize(weights=weights)
|
||||
|
||||
|
||||
apply_quantization(model=language_model, weights=weights, quantization=quantization)
|
||||
|
||||
language_model.load_weights(list(weights.items()), strict=False)
|
||||
|
||||
return language_model
|
||||
|
||||
|
||||
|
||||
@@ -476,42 +495,19 @@ class GemmaFeaturesExtractor(nn.Module):
|
||||
|
||||
|
||||
|
||||
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = None
|
||||
|
||||
if key.startswith("base_text_encoder.language_model."):
|
||||
new_key = key.replace("base_text_encoder.language_model.", "")
|
||||
elif key.startswith("language_model.model."):
|
||||
new_key = key.replace("language_model.model.", "")
|
||||
elif key.startswith("language_model."):
|
||||
new_key = key.replace("language_model.", "")
|
||||
else:
|
||||
continue
|
||||
|
||||
if new_key is None:
|
||||
continue
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
class LTX2TextEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str = "Lightricks/LTX-2",
|
||||
hidden_dim: int = 3840,
|
||||
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||
):
|
||||
super().__init__()
|
||||
self._model_path = model_path
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
self.model = LanguageModel()
|
||||
self.language_model = None
|
||||
|
||||
# Feature extractor: 3840*49 -> 3840
|
||||
self.feature_extractor = GemmaFeaturesExtractor(
|
||||
@@ -530,37 +526,17 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
self.processor = None
|
||||
|
||||
def load(self, model_path: Optional[str] = None):
|
||||
path = model_path or self._model_path
|
||||
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||
|
||||
# Load Gemma weights from text_encoder subdirectory
|
||||
if Path(path).is_dir():
|
||||
text_encoder_path = Path(path) / "text_encoder"
|
||||
if text_encoder_path.exists():
|
||||
gemma_path = str(text_encoder_path)
|
||||
else:
|
||||
gemma_path = path
|
||||
else:
|
||||
gemma_path = path
|
||||
if Path(text_encoder_path / "text_encoder").is_dir():
|
||||
text_encoder_path = str(text_encoder_path / "text_encoder")
|
||||
|
||||
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
|
||||
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
|
||||
all_weights = {}
|
||||
for i, wf in enumerate(weight_files):
|
||||
print(f" Loading weight file {i+1}/{len(weight_files)}...")
|
||||
weights = mx.load(str(wf))
|
||||
all_weights.update(weights)
|
||||
|
||||
# Sanitize and load Gemma weights
|
||||
sanitized = sanitize_gemma3_weights(all_weights)
|
||||
print(f" Sanitized Gemma weights: {len(sanitized)}")
|
||||
self.model.load_weights(list(sanitized.items()), strict=False)
|
||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||
|
||||
# Load transformer weights for feature extractor and connector
|
||||
transformer_path = Path(model_path or self._model_path)
|
||||
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
|
||||
|
||||
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
||||
if transformer_files:
|
||||
print(f"Loading transformer weights for text pipeline...")
|
||||
transformer_weights = mx.load(str(transformer_files[0]))
|
||||
|
||||
# Load feature extractor (aggregate_embed)
|
||||
@@ -568,7 +544,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
||||
"text_embedding_projection.aggregate_embed.weight"
|
||||
]
|
||||
print(" Loaded aggregate_embed weights")
|
||||
|
||||
|
||||
# Load video_embeddings_connector weights
|
||||
connector_weights = {}
|
||||
@@ -589,20 +565,18 @@ class LTX2TextEncoder(nn.Module):
|
||||
self.video_embeddings_connector.load_weights(
|
||||
list(mapped_weights.items()), strict=False
|
||||
)
|
||||
print(f" Loaded {len(connector_weights)} connector weights")
|
||||
|
||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||
if "learnable_registers" in connector_weights:
|
||||
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
|
||||
tokenizer_path = model_path / "tokenizer"
|
||||
if tokenizer_path.exists():
|
||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||
else:
|
||||
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
|
||||
# Set left padding to match official LTX-2 text encoder
|
||||
self.processor.padding_side = "left"
|
||||
|
||||
@@ -627,7 +601,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
||||
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
|
||||
|
||||
concat_hidden = norm_and_concat_hidden_states(
|
||||
all_hidden_states, attention_mask, padding_side="left"
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
"""Utility functions for MLX Video."""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
|
||||
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||
if quantization is not None:
|
||||
def get_class_predicate(p, m):
|
||||
# Handle custom per layer quantizations
|
||||
if p in quantization:
|
||||
return quantization[p]
|
||||
if not hasattr(m, "to_quantized"):
|
||||
return False
|
||||
# Skip layers not divisible by 64
|
||||
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
|
||||
return False
|
||||
# Handle legacy models which may not have everything quantized
|
||||
return f"{p}.scales" in weights
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=quantization["group_size"],
|
||||
bits=quantization["bits"],
|
||||
mode=quantization.get("mode", "affine"),
|
||||
class_predicate=get_class_predicate,
|
||||
)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user