Add prompt enhancement feature to video generation
- Introduced `enhance_prompt`, `max_tokens`, and `temperature` parameters in `generate_video` function for improved prompt handling. - Implemented prompt enhancement logic using the new `enhance_t2v` method in the text encoder. - Added command-line arguments for prompt enhancement options. - Created new system prompt files for T2V and I2V generation to guide the enhancement process.
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
|
||||
|
||||
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
|
||||
with 0.999 correlation.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.utils import rms_norm, apply_quantization
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
@@ -19,6 +19,19 @@ from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from mlx_vlm.models.gemma3.config import TextConfig
|
||||
|
||||
|
||||
# Path to system prompts
|
||||
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||
|
||||
|
||||
def _load_system_prompt(prompt_name: str) -> str:
|
||||
"""Load a system prompt from the prompts directory."""
|
||||
prompt_path = PROMPTS_DIR / prompt_name
|
||||
if prompt_path.exists():
|
||||
with open(prompt_path, "r") as f:
|
||||
return f.read()
|
||||
raise FileNotFoundError(f"System prompt not found: {prompt_path}")
|
||||
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
|
||||
|
||||
@@ -57,9 +70,11 @@ class LanguageModel(nn.Module):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: mx.array,
|
||||
inputs: mx.array,
|
||||
input_embeddings: Optional[mx.array] = None,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
output_hidden_states: bool = True,
|
||||
output_hidden_states: bool = False,
|
||||
cache: Optional[List[mx.array]] = None,
|
||||
) -> Tuple[mx.array, List[mx.array]]:
|
||||
"""Forward pass returning hidden states.
|
||||
|
||||
@@ -71,10 +86,10 @@ class LanguageModel(nn.Module):
|
||||
Returns:
|
||||
Tuple of (final_hidden_states, list_of_all_hidden_states)
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
batch_size, seq_len = inputs.shape
|
||||
|
||||
# Get embeddings
|
||||
h = self.model.embed_tokens(input_ids)
|
||||
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
|
||||
|
||||
# Apply Gemma scaling
|
||||
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||
@@ -83,11 +98,11 @@ class LanguageModel(nn.Module):
|
||||
all_hidden_states = [h] if output_hidden_states else []
|
||||
|
||||
# Set up cache (all None for non-cached inference)
|
||||
cache = [None] * len(self.model.layers)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.model.layers)
|
||||
|
||||
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||
|
||||
|
||||
sliding_mask = full_causal_mask
|
||||
|
||||
|
||||
@@ -104,7 +119,7 @@ class LanguageModel(nn.Module):
|
||||
else:
|
||||
local_mask = sliding_mask
|
||||
|
||||
h = layer(h, local_mask, None)
|
||||
h = layer(h, local_mask, cache[i])
|
||||
mx.eval(h)
|
||||
|
||||
if output_hidden_states and i < num_layers - 1:
|
||||
@@ -118,7 +133,11 @@ class LanguageModel(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
return hidden_states, all_hidden_states
|
||||
return hidden_states, all_hidden_states
|
||||
|
||||
else:
|
||||
# Return logits
|
||||
return self.model.embed_tokens.as_linear(hidden_states)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
prefix = "language_model."
|
||||
@@ -130,6 +149,24 @@ class LanguageModel(nn.Module):
|
||||
else:
|
||||
sanitized[key[len(prefix):]] = value
|
||||
return sanitized
|
||||
|
||||
@property
|
||||
def layers(self) -> List[nn.Module]:
|
||||
return self.model.layers
|
||||
|
||||
def make_cache(self):
|
||||
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
||||
caches = []
|
||||
for i in range(len(self.layers)):
|
||||
if (
|
||||
i % self.config.sliding_window_pattern
|
||||
== self.config.sliding_window_pattern - 1
|
||||
):
|
||||
caches.append(KVCache())
|
||||
else:
|
||||
caches.append(RotatingKVCache(max_size=self.config.sliding_window))
|
||||
return caches
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str):
|
||||
import json
|
||||
@@ -149,7 +186,7 @@ class LanguageModel(nn.Module):
|
||||
for i, wf in enumerate(weight_files):
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
|
||||
if hasattr(language_model, "sanitize"):
|
||||
weights = language_model.sanitize(weights=weights)
|
||||
|
||||
@@ -601,7 +638,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
|
||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||
|
||||
concat_hidden = norm_and_concat_hidden_states(
|
||||
all_hidden_states, attention_mask, padding_side="left"
|
||||
@@ -623,6 +660,177 @@ class LTX2TextEncoder(nn.Module):
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
return self.encode(prompt, max_length)
|
||||
|
||||
@functools.cached_property
|
||||
def default_t2v_system_prompt(self) -> str:
|
||||
"""Load the default T2V system prompt."""
|
||||
return _load_system_prompt("gemma_t2v_system_prompt.txt")
|
||||
|
||||
@functools.cached_property
|
||||
def default_i2v_system_prompt(self) -> str:
|
||||
"""Load the default I2V system prompt."""
|
||||
return _load_system_prompt("gemma_i2v_system_prompt.txt")
|
||||
|
||||
def _clean_response(self, response: str) -> str:
|
||||
"""Clean up the generated response."""
|
||||
# Remove leading/trailing whitespace
|
||||
response = response.strip()
|
||||
# Remove any leading punctuation
|
||||
response = re.sub(r'^[^\w\s]+', '', response)
|
||||
return response
|
||||
|
||||
def _apply_chat_template(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
) -> str:
|
||||
"""Apply Gemma chat template to messages."""
|
||||
# Gemma 3 chat template format
|
||||
formatted = ""
|
||||
for msg in messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
if role == "system":
|
||||
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
|
||||
elif role == "user":
|
||||
if isinstance(content, str):
|
||||
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
|
||||
elif isinstance(content, list):
|
||||
# Handle multimodal content (image + text)
|
||||
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
||||
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||
elif role == "assistant":
|
||||
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
||||
# Add generation prompt
|
||||
formatted += "<start_of_turn>model\n"
|
||||
return formatted
|
||||
|
||||
def enhance_t2v(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: int = 512,
|
||||
system_prompt: Optional[str] = None,
|
||||
seed: int = 42,
|
||||
verbose: bool = True,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Enhance a text prompt for T2V generation using mlx-lm.
|
||||
|
||||
Args:
|
||||
prompt: The original user prompt
|
||||
max_new_tokens: Maximum number of tokens to generate
|
||||
system_prompt: Optional custom system prompt
|
||||
seed: Random seed for generation
|
||||
|
||||
Returns:
|
||||
Enhanced prompt string
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
try:
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
except ImportError:
|
||||
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
|
||||
return prompt
|
||||
|
||||
if self.processor is None:
|
||||
raise RuntimeError("Model not loaded. Call load() first.")
|
||||
|
||||
system_prompt = system_prompt or self.default_t2v_system_prompt
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"user prompt: {prompt}"},
|
||||
]
|
||||
|
||||
# Apply chat template
|
||||
formatted = self._apply_chat_template(messages)
|
||||
|
||||
# Use mlx-lm generate with temperature sampling
|
||||
mx.random.seed(seed)
|
||||
|
||||
|
||||
# Tokenize
|
||||
inputs = self.processor(
|
||||
formatted,
|
||||
return_tensors="np",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
|
||||
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 0.95), top_k=kwargs.get("top_k", -1))
|
||||
logits_processors = make_logits_processors(
|
||||
kwargs.get("logit_bias", None),
|
||||
kwargs.get("repetition_penalty", 1.3),
|
||||
kwargs.get("repetition_context_size", 20),
|
||||
)
|
||||
|
||||
generated_token_count = 0
|
||||
generated_tokens = []
|
||||
for i, response in enumerate(
|
||||
tqdm(
|
||||
stream_generate(
|
||||
self.language_model,
|
||||
tokenizer=self.processor,
|
||||
prompt=input_ids.squeeze(0),
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
),
|
||||
total=max_tokens,
|
||||
disable=not verbose,
|
||||
)
|
||||
):
|
||||
next_token = mx.array([response.token])
|
||||
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
|
||||
generated_tokens.append(next_token.squeeze())
|
||||
generated_token_count += 1
|
||||
|
||||
if i % 50 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
# Check for EOS
|
||||
if response.token == 1 or response.token == 107: # EOS tokens
|
||||
break
|
||||
|
||||
|
||||
|
||||
# Decode only the new tokens
|
||||
|
||||
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
|
||||
def enhance_i2v(
|
||||
self,
|
||||
prompt: str,
|
||||
image: Optional[mx.array] = None,
|
||||
max_new_tokens: int = 512,
|
||||
system_prompt: Optional[str] = None,
|
||||
seed: int = 42,
|
||||
) -> str:
|
||||
"""Enhance a text prompt for I2V generation.
|
||||
|
||||
Args:
|
||||
prompt: The original user prompt
|
||||
image: Optional image tensor (not currently used)
|
||||
max_new_tokens: Maximum number of tokens to generate
|
||||
system_prompt: Optional custom system prompt
|
||||
seed: Random seed for generation
|
||||
|
||||
Returns:
|
||||
Enhanced prompt string
|
||||
"""
|
||||
# Use T2V enhancement with I2V system prompt
|
||||
return self.enhance_t2v(
|
||||
prompt,
|
||||
max_new_tokens=max_new_tokens,
|
||||
system_prompt=system_prompt or self.default_i2v_system_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
|
||||
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||
encoder = LTX2TextEncoder(model_path=model_path)
|
||||
|
||||
Reference in New Issue
Block a user