Add custom text encoder with quantization
Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
@@ -150,6 +150,7 @@ def denoise(
|
|||||||
|
|
||||||
def generate_video(
|
def generate_video(
|
||||||
model_repo: str,
|
model_repo: str,
|
||||||
|
text_encoder_repo: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
height: int = 512,
|
height: int = 512,
|
||||||
width: int = 512,
|
width: int = 512,
|
||||||
@@ -189,6 +190,7 @@ def generate_video(
|
|||||||
|
|
||||||
# Get model path
|
# Get model path
|
||||||
model_path = get_model_path(model_repo)
|
model_path = get_model_path(model_repo)
|
||||||
|
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||||
|
|
||||||
# Calculate latent dimensions
|
# Calculate latent dimensions
|
||||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||||
@@ -200,8 +202,8 @@ def generate_video(
|
|||||||
# Load text encoder
|
# Load text encoder
|
||||||
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
||||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
text_encoder = LTX2TextEncoder()
|
||||||
text_encoder.load(str(model_path))
|
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||||
mx.eval(text_encoder.parameters())
|
mx.eval(text_encoder.parameters())
|
||||||
|
|
||||||
text_embeddings, _ = text_encoder(prompt)
|
text_embeddings, _ = text_encoder(prompt)
|
||||||
@@ -317,7 +319,7 @@ def generate_video(
|
|||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
||||||
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
||||||
|
|
||||||
return video_np
|
return video_np
|
||||||
|
|
||||||
@@ -387,6 +389,12 @@ Examples:
|
|||||||
default="Lightricks/LTX-2",
|
default="Lightricks/LTX-2",
|
||||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-repo",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Text encoder repository to use (default: None)"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose",
|
"--verbose",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -396,6 +404,7 @@ Examples:
|
|||||||
|
|
||||||
generate_video(
|
generate_video(
|
||||||
model_repo=args.model_repo,
|
model_repo=args.model_repo,
|
||||||
|
text_encoder_repo=args.text_encoder_repo,
|
||||||
prompt=args.prompt,
|
prompt=args.prompt,
|
||||||
height=args.height,
|
height=args.height,
|
||||||
width=args.width,
|
width=args.width,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.utils import rms_norm
|
from mlx_video.utils import rms_norm, apply_quantization
|
||||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||||
|
|
||||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||||
@@ -22,26 +22,10 @@ from mlx_vlm.models.gemma3.config import TextConfig
|
|||||||
class LanguageModel(nn.Module):
|
class LanguageModel(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, config: TextConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Create config matching LTX-2 text encoder requirements
|
# Create config matching LTX-2 text encoder requirements
|
||||||
self.config = TextConfig(
|
self.config = config
|
||||||
model_type="gemma3_text",
|
|
||||||
hidden_size=3840,
|
|
||||||
num_hidden_layers=48,
|
|
||||||
intermediate_size=15360,
|
|
||||||
num_attention_heads=16,
|
|
||||||
num_key_value_heads=8,
|
|
||||||
head_dim=256,
|
|
||||||
rms_norm_eps=1e-6,
|
|
||||||
vocab_size=262208,
|
|
||||||
query_pre_attn_scalar=256,
|
|
||||||
rope_global_base_freq=1000000.0,
|
|
||||||
rope_local_base_freq=10000.0,
|
|
||||||
rope_traditional=False,
|
|
||||||
sliding_window=1024,
|
|
||||||
sliding_window_pattern=6,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the Gemma3Model from mlx-vlm
|
# Create the Gemma3Model from mlx-vlm
|
||||||
self.model = Gemma3Model(self.config)
|
self.model = Gemma3Model(self.config)
|
||||||
@@ -136,9 +120,44 @@ class LanguageModel(nn.Module):
|
|||||||
|
|
||||||
return hidden_states, all_hidden_states
|
return hidden_states, all_hidden_states
|
||||||
|
|
||||||
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
"""Load weights into the model."""
|
prefix = "language_model."
|
||||||
self.model.load_weights(weights, strict=strict)
|
sanitized = {
|
||||||
|
key[len(prefix):]: value
|
||||||
|
for key, value in weights.items()
|
||||||
|
if key.startswith(prefix)
|
||||||
|
}
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_path: str):
|
||||||
|
import json
|
||||||
|
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
||||||
|
config_file = Path(model_path) / "config.json"
|
||||||
|
config_dict = {}
|
||||||
|
if config_file.exists():
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
|
||||||
|
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Config file not found at {model_path}")
|
||||||
|
|
||||||
|
quantization = config_dict.get("quantization", None)
|
||||||
|
weights = {}
|
||||||
|
for i, wf in enumerate(weight_files):
|
||||||
|
weights.update(mx.load(str(wf)))
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(language_model, "sanitize"):
|
||||||
|
weights = language_model.sanitize(weights=weights)
|
||||||
|
|
||||||
|
|
||||||
|
apply_quantization(model=language_model, weights=weights, quantization=quantization)
|
||||||
|
|
||||||
|
language_model.load_weights(list(weights.items()), strict=False)
|
||||||
|
|
||||||
|
return language_model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -476,42 +495,19 @@ class GemmaFeaturesExtractor(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|
||||||
sanitized = {}
|
|
||||||
|
|
||||||
for key, value in weights.items():
|
|
||||||
new_key = None
|
|
||||||
|
|
||||||
if key.startswith("base_text_encoder.language_model."):
|
|
||||||
new_key = key.replace("base_text_encoder.language_model.", "")
|
|
||||||
elif key.startswith("language_model.model."):
|
|
||||||
new_key = key.replace("language_model.model.", "")
|
|
||||||
elif key.startswith("language_model."):
|
|
||||||
new_key = key.replace("language_model.", "")
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if new_key is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
sanitized[new_key] = value
|
|
||||||
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
|
|
||||||
class LTX2TextEncoder(nn.Module):
|
class LTX2TextEncoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: str = "Lightricks/LTX-2",
|
|
||||||
hidden_dim: int = 3840,
|
hidden_dim: int = 3840,
|
||||||
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._model_path = model_path
|
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.model = LanguageModel()
|
self.language_model = None
|
||||||
|
|
||||||
# Feature extractor: 3840*49 -> 3840
|
# Feature extractor: 3840*49 -> 3840
|
||||||
self.feature_extractor = GemmaFeaturesExtractor(
|
self.feature_extractor = GemmaFeaturesExtractor(
|
||||||
@@ -530,37 +526,17 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
|
|
||||||
self.processor = None
|
self.processor = None
|
||||||
|
|
||||||
def load(self, model_path: Optional[str] = None):
|
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||||
path = model_path or self._model_path
|
|
||||||
|
|
||||||
# Load Gemma weights from text_encoder subdirectory
|
if Path(text_encoder_path / "text_encoder").is_dir():
|
||||||
if Path(path).is_dir():
|
text_encoder_path = str(text_encoder_path / "text_encoder")
|
||||||
text_encoder_path = Path(path) / "text_encoder"
|
|
||||||
if text_encoder_path.exists():
|
|
||||||
gemma_path = str(text_encoder_path)
|
|
||||||
else:
|
|
||||||
gemma_path = path
|
|
||||||
else:
|
|
||||||
gemma_path = path
|
|
||||||
|
|
||||||
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
|
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||||
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
|
|
||||||
all_weights = {}
|
|
||||||
for i, wf in enumerate(weight_files):
|
|
||||||
print(f" Loading weight file {i+1}/{len(weight_files)}...")
|
|
||||||
weights = mx.load(str(wf))
|
|
||||||
all_weights.update(weights)
|
|
||||||
|
|
||||||
# Sanitize and load Gemma weights
|
|
||||||
sanitized = sanitize_gemma3_weights(all_weights)
|
|
||||||
print(f" Sanitized Gemma weights: {len(sanitized)}")
|
|
||||||
self.model.load_weights(list(sanitized.items()), strict=False)
|
|
||||||
|
|
||||||
# Load transformer weights for feature extractor and connector
|
# Load transformer weights for feature extractor and connector
|
||||||
transformer_path = Path(model_path or self._model_path)
|
|
||||||
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
|
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
||||||
if transformer_files:
|
if transformer_files:
|
||||||
print(f"Loading transformer weights for text pipeline...")
|
|
||||||
transformer_weights = mx.load(str(transformer_files[0]))
|
transformer_weights = mx.load(str(transformer_files[0]))
|
||||||
|
|
||||||
# Load feature extractor (aggregate_embed)
|
# Load feature extractor (aggregate_embed)
|
||||||
@@ -568,7 +544,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
||||||
"text_embedding_projection.aggregate_embed.weight"
|
"text_embedding_projection.aggregate_embed.weight"
|
||||||
]
|
]
|
||||||
print(" Loaded aggregate_embed weights")
|
|
||||||
|
|
||||||
# Load video_embeddings_connector weights
|
# Load video_embeddings_connector weights
|
||||||
connector_weights = {}
|
connector_weights = {}
|
||||||
@@ -589,20 +565,18 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
self.video_embeddings_connector.load_weights(
|
self.video_embeddings_connector.load_weights(
|
||||||
list(mapped_weights.items()), strict=False
|
list(mapped_weights.items()), strict=False
|
||||||
)
|
)
|
||||||
print(f" Loaded {len(connector_weights)} connector weights")
|
|
||||||
|
|
||||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||||
if "learnable_registers" in connector_weights:
|
if "learnable_registers" in connector_weights:
|
||||||
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||||
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
|
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
|
tokenizer_path = model_path / "tokenizer"
|
||||||
if tokenizer_path.exists():
|
if tokenizer_path.exists():
|
||||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||||
else:
|
else:
|
||||||
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
|
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
|
||||||
# Set left padding to match official LTX-2 text encoder
|
# Set left padding to match official LTX-2 text encoder
|
||||||
self.processor.padding_side = "left"
|
self.processor.padding_side = "left"
|
||||||
|
|
||||||
@@ -627,7 +601,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
attention_mask = mx.array(inputs["attention_mask"])
|
attention_mask = mx.array(inputs["attention_mask"])
|
||||||
|
|
||||||
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
|
||||||
|
|
||||||
concat_hidden = norm_and_concat_hidden_states(
|
concat_hidden = norm_and_concat_hidden_states(
|
||||||
all_hidden_states, attention_mask, padding_side="left"
|
all_hidden_states, attention_mask, padding_side="left"
|
||||||
|
|||||||
@@ -1,7 +1,4 @@
|
|||||||
"""Utility functions for MLX Video."""
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
|
|||||||
allow_patterns=["*.safetensors", "*.json"],
|
allow_patterns=["*.safetensors", "*.json"],
|
||||||
))
|
))
|
||||||
|
|
||||||
|
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||||
|
if quantization is not None:
|
||||||
|
def get_class_predicate(p, m):
|
||||||
|
# Handle custom per layer quantizations
|
||||||
|
if p in quantization:
|
||||||
|
return quantization[p]
|
||||||
|
if not hasattr(m, "to_quantized"):
|
||||||
|
return False
|
||||||
|
# Skip layers not divisible by 64
|
||||||
|
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
|
||||||
|
return False
|
||||||
|
# Handle legacy models which may not have everything quantized
|
||||||
|
return f"{p}.scales" in weights
|
||||||
|
|
||||||
|
nn.quantize(
|
||||||
|
model,
|
||||||
|
group_size=quantization["group_size"],
|
||||||
|
bits=quantization["bits"],
|
||||||
|
mode=quantization.get("mode", "affine"),
|
||||||
|
class_predicate=get_class_predicate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||||
|
|||||||
Reference in New Issue
Block a user