This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,25 +1,25 @@
import functools
import logging
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig
from mlx_vlm.models.gemma3.language import Gemma3Model
from rich.console import Console
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
)
from mlx_video.utils import apply_quantization, rms_norm
# Directory holding the system prompts: the "prompts" folder located next to
# this module on disk (resolved relative to __file__, not the working directory).
PROMPTS_DIR = Path(__file__).parent / "prompts"
@@ -36,11 +36,10 @@ def _load_system_prompt(prompt_name: str) -> str:
class LanguageModel(nn.Module):
def __init__(self, config: TextConfig):
    """Wrap an mlx-vlm ``Gemma3Model`` built from ``config``.

    Args:
        config: Gemma 3 text configuration (``TextConfig`` from mlx-vlm).
            It is stored unchanged on ``self.config`` and passed straight
            to ``Gemma3Model``.
    """
    super().__init__()

    # Store the caller-supplied config as-is; nothing is derived or
    # rewritten here.
    self.config = config

    # Create the Gemma3Model from mlx-vlm
    self.model = Gemma3Model(self.config)
@@ -51,7 +50,7 @@ class LanguageModel(nn.Module):
attention_mask: Optional[mx.array],
dtype: mx.Dtype,
) -> mx.array:
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
if attention_mask is not None:
@@ -59,15 +58,25 @@ class LanguageModel(nn.Module):
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
mx.full(combined.shape, min_val, dtype=dtype))
min_val = (
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
)
mask = mx.where(
combined,
mx.zeros(combined.shape, dtype=dtype),
mx.full(combined.shape, min_val, dtype=dtype),
)
return mask[:, None, :, :]
else:
# No padding mask, just causal
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
mx.full((seq_len, seq_len), min_val, dtype=dtype))
min_val = (
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
)
mask = mx.where(
causal_mask,
mx.zeros((seq_len, seq_len), dtype=dtype),
mx.full((seq_len, seq_len), min_val, dtype=dtype),
)
return mask[None, None, :, :] # (1, 1, seq, seq)
def __call__(
@@ -91,7 +100,11 @@ class LanguageModel(nn.Module):
batch_size, seq_len = inputs.shape
# Get embeddings
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
h = (
input_embeddings
if input_embeddings is not None
else self.model.embed_tokens(inputs)
)
# Apply Gemma scaling
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
@@ -103,11 +116,12 @@ class LanguageModel(nn.Module):
if cache is None:
cache = [None] * len(self.model.layers)
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
full_causal_mask = self._create_causal_mask_with_padding(
seq_len, attention_mask, h.dtype
)
sliding_mask = full_causal_mask
num_layers = len(self.model.layers)
for i, layer in enumerate(self.model.layers):
is_global = (
@@ -147,9 +161,9 @@ class LanguageModel(nn.Module):
for key, value in weights.items():
if key.startswith(prefix):
if hasattr(value, "dtype") and value.dtype == mx.float32:
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16)
else:
sanitized[key[len(prefix):]] = value
sanitized[key[len(prefix) :]] = value
return sanitized
@property
@@ -158,6 +172,7 @@ class LanguageModel(nn.Module):
def make_cache(self):
from mlx_vlm.models.cache import KVCache, RotatingKVCache
caches = []
for i in range(len(self.layers)):
if (
@@ -172,6 +187,7 @@ class LanguageModel(nn.Module):
@classmethod
def from_pretrained(cls, model_path: str):
import json
weight_files = sorted(Path(model_path).glob("*.safetensors"))
config_file = Path(model_path) / "config.json"
config_dict = {}
@@ -179,7 +195,9 @@ class LanguageModel(nn.Module):
with open(config_file, "r") as f:
config_dict = json.load(f)
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
language_model = cls(
config=TextConfig.from_dict(config_dict["text_config"])
)
else:
raise ValueError(f"Config file not found at {model_path}")
@@ -188,19 +206,18 @@ class LanguageModel(nn.Module):
for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights)
apply_quantization(model=language_model, weights=weights, quantization=quantization)
apply_quantization(
model=language_model, weights=weights, quantization=quantization
)
language_model.load_weights(list(weights.items()), strict=False)
return language_model
class ConnectorAttention(nn.Module):
def __init__(
@@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module):
k = self.k_norm(k)
# Reshape to (B, H, T, D) for SPLIT RoPE
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
q = mx.reshape(
q, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
k = mx.reshape(
k, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
v = mx.reshape(
v, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
if pe is not None:
q = self._apply_split_rope(q, pe[0], pe[1])
@@ -304,7 +327,7 @@ class ConnectorAttention(nn.Module):
out2 = x2 * cos_freq + x1 * sin_freq
return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
class GEGLU(nn.Module):
"""GELU-gated linear unit."""
@@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module):
class ConnectorTransformerBlock(nn.Module):
def __init__(
    self,
    dim: int = 3840,
    num_heads: int = 30,
    head_dim: int = 128,
    has_gate_logits: bool = False,
):
    """Build one connector block: an attention layer and a feed-forward layer.

    Args:
        dim: First positional argument given to both ``ConnectorAttention``
            and ``ConnectorFeedForward``.
        num_heads: Second positional argument given to ``ConnectorAttention``.
        head_dim: Third positional argument given to ``ConnectorAttention``.
        has_gate_logits: Forwarded by keyword to ``ConnectorAttention``;
            the feed-forward layer does not receive it.
    """
    super().__init__()
    # Attribute names ``attn1`` and ``ff`` are kept exactly as in the
    # original (presumably they must line up with checkpoint keys —
    # confirm against the weight loader before renaming).
    self.attn1 = ConnectorAttention(
        dim, num_heads, head_dim, has_gate_logits=has_gate_logits
    )
    self.ff = ConnectorFeedForward(dim)
def __call__(
@@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module):
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
self.transformer_1d_blocks = {
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
i: ConnectorTransformerBlock(
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
)
for i in range(num_layers)
}
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
def _precompute_freqs_cis(
self, seq_len: int, dtype: mx.Dtype
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
@@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module):
# Binary mask: 1 for valid tokens, 0 for padded
# attention_mask is additive: 0 for valid, large negative for padded
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(
mx.int32
) # (batch, seq)
# Tile registers to match sequence length, cast to hidden_states dtype
num_tiles = seq_len // self.num_learnable_registers
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(
dtype
) # (seq_len, dim)
# Process each batch item (PyTorch uses advanced indexing)
result_list = []
@@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module):
# Extract valid tokens (where mask is 1)
# Since we have left-padded input, valid tokens are at the end
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim)
# Pad with zeros on the right to get back to seq_len
pad_length = seq_len - num_valid
if pad_length > 0:
padding = mx.zeros((pad_length, dim), dtype=dtype)
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
adjusted = mx.concatenate(
[valid_tokens, padding], axis=0
) # (seq_len, dim)
else:
adjusted = valid_tokens
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
flipped_mask = mx.concatenate([
mx.ones((num_valid,), dtype=mx.int32),
mx.zeros((pad_length,), dtype=mx.int32)
], axis=0) # (seq,)
flipped_mask = mx.concatenate(
[
mx.ones((num_valid,), dtype=mx.int32),
mx.zeros((pad_length,), dtype=mx.int32),
],
axis=0,
) # (seq,)
# Combine: valid tokens at front, registers at back
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
combined = (
flipped_mask_expanded * adjusted
+ (1 - flipped_mask_expanded) * registers
)
result_list.append(combined)
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
@@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module):
# Process through transformer blocks
for i in range(len(self.transformer_1d_blocks)):
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
hidden_states = self.transformer_1d_blocks[i](
hidden_states, attention_mask, freqs_cis
)
# Final RMS norm
hidden_states = rms_norm(hidden_states)
@@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module):
return hidden_states, attention_mask
def norm_and_concat_hidden_states(
hidden_states: List[mx.array],
attention_mask: mx.array,
@@ -567,8 +615,12 @@ def norm_and_concat_hidden_states(
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
x_for_min = mx.where(
mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype)
)
x_for_max = mx.where(
mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype)
)
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min
@@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms(
dtype = encoded_text.dtype
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L)
variance = mx.mean(
encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True
) # (B, T, 1, L)
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
normed = normed.astype(dtype)
@@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
class GemmaFeaturesExtractor(nn.Module):
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
def __init__(
self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False
):
super().__init__()
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
@@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module):
if mode == "video":
target_dim = self.video_aggregate_embed.weight.shape[0]
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
return self.video_aggregate_embed(
_rescale_norm(normed, target_dim, self.embedding_dim)
)
else:
target_dim = self.audio_aggregate_embed.weight.shape[0]
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
return self.audio_aggregate_embed(
_rescale_norm(normed, target_dim, self.embedding_dim)
)
class AudioEmbeddingsConnector(nn.Module):
@@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module):
video_output_dim = 4096
audio_output_dim = 2048
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
video_output_dim=video_output_dim,
audio_output_dim=audio_output_dim,
bias=True,
@@ -728,37 +785,57 @@ class LTX2TextEncoder(nn.Module):
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
# config (nested under config.transformer.connector_positional_embedding_max_pos)
self.video_embeddings_connector = Embeddings1DConnector(
dim=video_output_dim, num_heads=32, head_dim=128,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[4096], has_gate_logits=True,
dim=video_output_dim,
num_heads=32,
head_dim=128,
num_layers=8,
num_learnable_registers=128,
positional_embedding_max_pos=[4096],
has_gate_logits=True,
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=audio_output_dim, num_heads=32, head_dim=64,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[4096], has_gate_logits=True,
dim=audio_output_dim,
num_heads=32,
head_dim=64,
num_layers=8,
num_learnable_registers=128,
positional_embedding_max_pos=[4096],
has_gate_logits=True,
)
else:
# LTX-2: shared feature extractor, 3840-dim connectors
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
self.feature_extractor = GemmaFeaturesExtractor(
feature_input_dim, hidden_dim
)
self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
dim=hidden_dim,
num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[1],
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
dim=hidden_dim,
num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[1],
)
self.processor = None
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
def load(
self,
model_path: Optional[str] = None,
text_encoder_path: Optional[str] = "google/gemma-3-12b-it",
):
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
# Load transformer weights for feature extractor and connector.
@@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module):
if transformer_weights:
self._load_feature_extractors(transformer_weights, is_reformatted)
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
self._load_connector(
"video_embeddings_connector", transformer_weights, is_reformatted
)
self._load_connector(
"audio_embeddings_connector", transformer_weights, is_reformatted
)
else:
print("WARNING: No transformer weights found for text projection connectors. "
"Text conditioning will use uninitialized weights!")
print(
"WARNING: No transformer weights found for text projection connectors. "
"Text conditioning will use uninitialized weights!"
)
# Load tokenizer
from transformers import AutoTokenizer
tokenizer_path = model_path / "tokenizer"
if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(
str(tokenizer_path), trust_remote_code=True
)
else:
try:
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(
text_encoder_path, trust_remote_code=True
)
except Exception:
self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(
"google/gemma-3-12b-it", trust_remote_code=True
)
# Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left"
@@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module):
submodule.bias = weights[b_key]
else:
# LTX-2: single aggregate_embed
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
agg_key = (
"aggregate_embed.weight"
if is_reformatted
else "text_embedding_projection.aggregate_embed.weight"
)
if agg_key in weights:
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
@@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module):
prefix = f"{name}."
for key, value in weights.items():
if key.startswith(prefix):
connector_weights[key[len(prefix):]] = value
connector_weights[key[len(prefix) :]] = value
else:
mono_prefix = f"model.diffusion_model.{name}."
for key, value in weights.items():
if key.startswith(mono_prefix):
connector_weights[key[len(mono_prefix):]] = value
connector_weights[key[len(mono_prefix) :]] = value
if not connector_weights:
return
@@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
_, all_hidden_states = self.language_model(
inputs=input_ids,
input_embeddings=None,
attention_mask=attention_mask,
output_hidden_states=True,
)
if self.has_prompt_adaln:
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
video_features = self.feature_extractor_v2(
all_hidden_states, attention_mask, mode="video"
)
additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
additive_mask = (
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
)
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
video_embeddings, _ = self.video_embeddings_connector(
video_features, additive_mask
)
if return_audio_embeddings:
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
audio_features = self.feature_extractor_v2(
all_hidden_states, attention_mask, mode="audio"
)
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
audio_embeddings, _ = self.audio_embeddings_connector(
audio_features, audio_mask
)
return video_embeddings, audio_embeddings
else:
return video_embeddings, attention_mask
@@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module):
video_features = self.feature_extractor(concat_hidden)
additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
additive_mask = (
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
)
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
video_embeddings, _ = self.video_embeddings_connector(
video_features, additive_mask
)
if return_audio_embeddings:
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
audio_embeddings, _ = self.audio_embeddings_connector(
video_features, additive_mask
)
return video_embeddings, audio_embeddings
else:
return video_embeddings, attention_mask
@@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module):
# Remove leading/trailing whitespace
response = response.strip()
# Remove any leading punctuation
response = re.sub(r'^[^\w\s]+', '', response)
response = re.sub(r"^[^\w\s]+", "", response)
return response
def _apply_chat_template(
@@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module):
elif isinstance(content, list):
# Handle multimodal content (image + text)
text_parts = [c["text"] for c in content if c.get("type") == "text"]
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
formatted += (
f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
)
elif role == "assistant":
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
# Add generation prompt
@@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module):
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler
except ImportError:
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
logging.warning(
"mlx-lm not available for prompt enhancement. Using original prompt."
)
return prompt
if self.processor is None:
@@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module):
)
input_ids = mx.array(inputs["input_ids"])
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1))
sampler = make_sampler(
kwargs.get("temperature", 0.7),
kwargs.get("top_p", 1.0),
top_k=kwargs.get("top_k", -1),
)
logits_processors = make_logits_processors(
kwargs.get("logit_bias", None),
kwargs.get("repetition_penalty", 1.3),
@@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module):
mx.clear_cache()
# Decode only the new tokens
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
enhanced_prompt = self.processor.decode(
generated_tokens, skip_special_tokens=True
)
enhanced_prompt = self._clean_response(enhanced_prompt)
logging.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
def enhance_i2v(
self,
prompt: str,
@@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder()
encoder.load(model_path=model_path)
return encoder