format
This commit is contained in:
@@ -1,25 +1,25 @@
|
||||
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||
|
||||
from mlx_video.utils import rms_norm, apply_quantization
|
||||
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
|
||||
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from mlx_vlm.models.gemma3.config import TextConfig
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
from mlx_video.utils import apply_quantization, rms_norm
|
||||
|
||||
# Path to system prompts
|
||||
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||
@@ -36,11 +36,10 @@ def _load_system_prompt(prompt_name: str) -> str:
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
|
||||
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
# Create config matching LTX-2 text encoder requirements
|
||||
self.config = config
|
||||
self.config = config
|
||||
|
||||
# Create the Gemma3Model from mlx-vlm
|
||||
self.model = Gemma3Model(self.config)
|
||||
@@ -51,7 +50,7 @@ class LanguageModel(nn.Module):
|
||||
attention_mask: Optional[mx.array],
|
||||
dtype: mx.Dtype,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
|
||||
|
||||
if attention_mask is not None:
|
||||
@@ -59,15 +58,25 @@ class LanguageModel(nn.Module):
|
||||
|
||||
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
||||
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
|
||||
mx.full(combined.shape, min_val, dtype=dtype))
|
||||
min_val = (
|
||||
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
)
|
||||
mask = mx.where(
|
||||
combined,
|
||||
mx.zeros(combined.shape, dtype=dtype),
|
||||
mx.full(combined.shape, min_val, dtype=dtype),
|
||||
)
|
||||
return mask[:, None, :, :]
|
||||
else:
|
||||
# No padding mask, just causal
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||
mx.full((seq_len, seq_len), min_val, dtype=dtype))
|
||||
min_val = (
|
||||
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
)
|
||||
mask = mx.where(
|
||||
causal_mask,
|
||||
mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||
mx.full((seq_len, seq_len), min_val, dtype=dtype),
|
||||
)
|
||||
return mask[None, None, :, :] # (1, 1, seq, seq)
|
||||
|
||||
def __call__(
|
||||
@@ -91,7 +100,11 @@ class LanguageModel(nn.Module):
|
||||
batch_size, seq_len = inputs.shape
|
||||
|
||||
# Get embeddings
|
||||
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
|
||||
h = (
|
||||
input_embeddings
|
||||
if input_embeddings is not None
|
||||
else self.model.embed_tokens(inputs)
|
||||
)
|
||||
|
||||
# Apply Gemma scaling
|
||||
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||
@@ -103,11 +116,12 @@ class LanguageModel(nn.Module):
|
||||
if cache is None:
|
||||
cache = [None] * len(self.model.layers)
|
||||
|
||||
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||
full_causal_mask = self._create_causal_mask_with_padding(
|
||||
seq_len, attention_mask, h.dtype
|
||||
)
|
||||
|
||||
sliding_mask = full_causal_mask
|
||||
|
||||
|
||||
num_layers = len(self.model.layers)
|
||||
for i, layer in enumerate(self.model.layers):
|
||||
is_global = (
|
||||
@@ -147,9 +161,9 @@ class LanguageModel(nn.Module):
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
||||
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
|
||||
sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16)
|
||||
else:
|
||||
sanitized[key[len(prefix):]] = value
|
||||
sanitized[key[len(prefix) :]] = value
|
||||
return sanitized
|
||||
|
||||
@property
|
||||
@@ -158,6 +172,7 @@ class LanguageModel(nn.Module):
|
||||
|
||||
def make_cache(self):
|
||||
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
||||
|
||||
caches = []
|
||||
for i in range(len(self.layers)):
|
||||
if (
|
||||
@@ -172,6 +187,7 @@ class LanguageModel(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str):
|
||||
import json
|
||||
|
||||
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
||||
config_file = Path(model_path) / "config.json"
|
||||
config_dict = {}
|
||||
@@ -179,7 +195,9 @@ class LanguageModel(nn.Module):
|
||||
with open(config_file, "r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
|
||||
language_model = cls(
|
||||
config=TextConfig.from_dict(config_dict["text_config"])
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Config file not found at {model_path}")
|
||||
|
||||
@@ -188,19 +206,18 @@ class LanguageModel(nn.Module):
|
||||
for i, wf in enumerate(weight_files):
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
if hasattr(language_model, "sanitize"):
|
||||
weights = language_model.sanitize(weights=weights)
|
||||
|
||||
|
||||
apply_quantization(model=language_model, weights=weights, quantization=quantization)
|
||||
apply_quantization(
|
||||
model=language_model, weights=weights, quantization=quantization
|
||||
)
|
||||
|
||||
language_model.load_weights(list(weights.items()), strict=False)
|
||||
|
||||
return language_model
|
||||
|
||||
|
||||
|
||||
class ConnectorAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Reshape to (B, H, T, D) for SPLIT RoPE
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
q = mx.reshape(
|
||||
q, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(
|
||||
k, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(
|
||||
v, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
if pe is not None:
|
||||
q = self._apply_split_rope(q, pe[0], pe[1])
|
||||
@@ -304,7 +327,7 @@ class ConnectorAttention(nn.Module):
|
||||
out2 = x2 * cos_freq + x1 * sin_freq
|
||||
|
||||
return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
|
||||
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
"""GELU-gated linear unit."""
|
||||
@@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module):
|
||||
|
||||
class ConnectorTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
num_heads: int = 30,
|
||||
head_dim: int = 128,
|
||||
has_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
self.attn1 = ConnectorAttention(
|
||||
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||
)
|
||||
self.ff = ConnectorFeedForward(dim)
|
||||
|
||||
def __call__(
|
||||
@@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
|
||||
|
||||
self.transformer_1d_blocks = {
|
||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
i: ConnectorTransformerBlock(
|
||||
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||
def _precompute_freqs_cis(
|
||||
self, seq_len: int, dtype: mx.Dtype
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
|
||||
|
||||
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
|
||||
@@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Binary mask: 1 for valid tokens, 0 for padded
|
||||
# attention_mask is additive: 0 for valid, large negative for padded
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(
|
||||
mx.int32
|
||||
) # (batch, seq)
|
||||
|
||||
# Tile registers to match sequence length, cast to hidden_states dtype
|
||||
num_tiles = seq_len // self.num_learnable_registers
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(
|
||||
dtype
|
||||
) # (seq_len, dim)
|
||||
|
||||
# Process each batch item (PyTorch uses advanced indexing)
|
||||
result_list = []
|
||||
@@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Extract valid tokens (where mask is 1)
|
||||
# Since we have left-padded input, valid tokens are at the end
|
||||
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
|
||||
valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim)
|
||||
|
||||
# Pad with zeros on the right to get back to seq_len
|
||||
pad_length = seq_len - num_valid
|
||||
if pad_length > 0:
|
||||
padding = mx.zeros((pad_length, dim), dtype=dtype)
|
||||
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
||||
adjusted = mx.concatenate(
|
||||
[valid_tokens, padding], axis=0
|
||||
) # (seq_len, dim)
|
||||
else:
|
||||
adjusted = valid_tokens
|
||||
|
||||
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
|
||||
flipped_mask = mx.concatenate([
|
||||
mx.ones((num_valid,), dtype=mx.int32),
|
||||
mx.zeros((pad_length,), dtype=mx.int32)
|
||||
], axis=0) # (seq,)
|
||||
flipped_mask = mx.concatenate(
|
||||
[
|
||||
mx.ones((num_valid,), dtype=mx.int32),
|
||||
mx.zeros((pad_length,), dtype=mx.int32),
|
||||
],
|
||||
axis=0,
|
||||
) # (seq,)
|
||||
|
||||
# Combine: valid tokens at front, registers at back
|
||||
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
|
||||
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
||||
combined = (
|
||||
flipped_mask_expanded * adjusted
|
||||
+ (1 - flipped_mask_expanded) * registers
|
||||
)
|
||||
result_list.append(combined)
|
||||
|
||||
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||
@@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Process through transformer blocks
|
||||
for i in range(len(self.transformer_1d_blocks)):
|
||||
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
|
||||
hidden_states = self.transformer_1d_blocks[i](
|
||||
hidden_states, attention_mask, freqs_cis
|
||||
)
|
||||
|
||||
# Final RMS norm
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
@@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module):
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
|
||||
def norm_and_concat_hidden_states(
|
||||
hidden_states: List[mx.array],
|
||||
attention_mask: mx.array,
|
||||
@@ -567,8 +615,12 @@ def norm_and_concat_hidden_states(
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
|
||||
x_for_min = mx.where(
|
||||
mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype)
|
||||
)
|
||||
x_for_max = mx.where(
|
||||
mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype)
|
||||
)
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
@@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms(
|
||||
dtype = encoded_text.dtype
|
||||
|
||||
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
|
||||
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L)
|
||||
variance = mx.mean(
|
||||
encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True
|
||||
) # (B, T, 1, L)
|
||||
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
|
||||
normed = normed.astype(dtype)
|
||||
|
||||
@@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
|
||||
class GemmaFeaturesExtractor(nn.Module):
|
||||
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
|
||||
|
||||
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
|
||||
def __init__(
|
||||
self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
|
||||
|
||||
@@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module):
|
||||
|
||||
if mode == "video":
|
||||
target_dim = self.video_aggregate_embed.weight.shape[0]
|
||||
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
return self.video_aggregate_embed(
|
||||
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||
)
|
||||
else:
|
||||
target_dim = self.audio_aggregate_embed.weight.shape[0]
|
||||
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
|
||||
|
||||
|
||||
return self.audio_aggregate_embed(
|
||||
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||
)
|
||||
|
||||
|
||||
class AudioEmbeddingsConnector(nn.Module):
|
||||
@@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module):
|
||||
video_output_dim = 4096
|
||||
audio_output_dim = 2048
|
||||
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
|
||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||
video_output_dim=video_output_dim,
|
||||
audio_output_dim=audio_output_dim,
|
||||
bias=True,
|
||||
@@ -728,37 +785,57 @@ class LTX2TextEncoder(nn.Module):
|
||||
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
|
||||
# config (nested under config.transformer.connector_positional_embedding_max_pos)
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
dim=video_output_dim,
|
||||
num_heads=32,
|
||||
head_dim=128,
|
||||
num_layers=8,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
has_gate_logits=True,
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
dim=audio_output_dim,
|
||||
num_heads=32,
|
||||
head_dim=64,
|
||||
num_layers=8,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
has_gate_logits=True,
|
||||
)
|
||||
else:
|
||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
|
||||
self.feature_extractor = GemmaFeaturesExtractor(
|
||||
feature_input_dim, hidden_dim
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1],
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1],
|
||||
)
|
||||
|
||||
self.processor = None
|
||||
|
||||
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||
def load(
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
text_encoder_path: Optional[str] = "google/gemma-3-12b-it",
|
||||
):
|
||||
|
||||
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
|
||||
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
|
||||
|
||||
|
||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||
|
||||
# Load transformer weights for feature extractor and connector.
|
||||
@@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
if transformer_weights:
|
||||
self._load_feature_extractors(transformer_weights, is_reformatted)
|
||||
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
|
||||
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
|
||||
self._load_connector(
|
||||
"video_embeddings_connector", transformer_weights, is_reformatted
|
||||
)
|
||||
self._load_connector(
|
||||
"audio_embeddings_connector", transformer_weights, is_reformatted
|
||||
)
|
||||
else:
|
||||
print("WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!")
|
||||
print(
|
||||
"WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!"
|
||||
)
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer_path = model_path / "tokenizer"
|
||||
if tokenizer_path.exists():
|
||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
str(tokenizer_path), trust_remote_code=True
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
text_encoder_path, trust_remote_code=True
|
||||
)
|
||||
except Exception:
|
||||
self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
"google/gemma-3-12b-it", trust_remote_code=True
|
||||
)
|
||||
# Set left padding to match official LTX-2 text encoder
|
||||
self.processor.padding_side = "left"
|
||||
|
||||
@@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module):
|
||||
submodule.bias = weights[b_key]
|
||||
else:
|
||||
# LTX-2: single aggregate_embed
|
||||
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
||||
agg_key = (
|
||||
"aggregate_embed.weight"
|
||||
if is_reformatted
|
||||
else "text_embedding_projection.aggregate_embed.weight"
|
||||
)
|
||||
if agg_key in weights:
|
||||
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
|
||||
|
||||
@@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module):
|
||||
prefix = f"{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
connector_weights[key[len(prefix):]] = value
|
||||
connector_weights[key[len(prefix) :]] = value
|
||||
else:
|
||||
mono_prefix = f"model.diffusion_model.{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mono_prefix):
|
||||
connector_weights[key[len(mono_prefix):]] = value
|
||||
connector_weights[key[len(mono_prefix) :]] = value
|
||||
|
||||
if not connector_weights:
|
||||
return
|
||||
@@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module):
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||
_, all_hidden_states = self.language_model(
|
||||
inputs=input_ids,
|
||||
input_embeddings=None,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
|
||||
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
|
||||
video_features = self.feature_extractor_v2(
|
||||
all_hidden_states, attention_mask, mode="video"
|
||||
)
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
additive_mask = (
|
||||
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
)
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
video_embeddings, _ = self.video_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
|
||||
audio_features = self.feature_extractor_v2(
|
||||
all_hidden_states, attention_mask, mode="audio"
|
||||
)
|
||||
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
|
||||
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||
audio_features, audio_mask
|
||||
)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
@@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
video_features = self.feature_extractor(concat_hidden)
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
additive_mask = (
|
||||
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
)
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
video_embeddings, _ = self.video_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
@@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
# Remove leading/trailing whitespace
|
||||
response = response.strip()
|
||||
# Remove any leading punctuation
|
||||
response = re.sub(r'^[^\w\s]+', '', response)
|
||||
response = re.sub(r"^[^\w\s]+", "", response)
|
||||
return response
|
||||
|
||||
def _apply_chat_template(
|
||||
@@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
elif isinstance(content, list):
|
||||
# Handle multimodal content (image + text)
|
||||
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
||||
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||
formatted += (
|
||||
f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||
)
|
||||
elif role == "assistant":
|
||||
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
||||
# Add generation prompt
|
||||
@@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
except ImportError:
|
||||
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
|
||||
logging.warning(
|
||||
"mlx-lm not available for prompt enhancement. Using original prompt."
|
||||
)
|
||||
return prompt
|
||||
|
||||
if self.processor is None:
|
||||
@@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module):
|
||||
)
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
|
||||
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1))
|
||||
sampler = make_sampler(
|
||||
kwargs.get("temperature", 0.7),
|
||||
kwargs.get("top_p", 1.0),
|
||||
top_k=kwargs.get("top_k", -1),
|
||||
)
|
||||
logits_processors = make_logits_processors(
|
||||
kwargs.get("logit_bias", None),
|
||||
kwargs.get("repetition_penalty", 1.3),
|
||||
@@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module):
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode only the new tokens
|
||||
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
||||
enhanced_prompt = self.processor.decode(
|
||||
generated_tokens, skip_special_tokens=True
|
||||
)
|
||||
|
||||
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
|
||||
def enhance_i2v(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||
encoder = LTX2TextEncoder()
|
||||
encoder.load(model_path=model_path)
|
||||
return encoder
|
||||
|
||||
|
||||
Reference in New Issue
Block a user