Add prompt enhancement feature to video generation

- Introduced `enhance_prompt`, `max_tokens`, and `temperature` parameters in `generate_video` function for improved prompt handling.
- Implemented prompt enhancement logic using the new `enhance_t2v` method in the text encoder.
- Added command-line arguments for prompt enhancement options.
- Created new system prompt files for T2V and I2V generation to guide the enhancement process.
This commit is contained in:
Prince Canuma
2026-01-15 14:31:00 +01:00
parent f5134fa172
commit 81daf3f67d
4 changed files with 320 additions and 26 deletions

View File

@@ -1,16 +1,16 @@
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
with 0.999 correlation.
"""
import functools
import logging
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@@ -19,6 +19,19 @@ from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig
# Path to system prompts
PROMPTS_DIR = Path(__file__).parent / "prompts"


def _load_system_prompt(prompt_name: str) -> str:
    """Load a system prompt from the prompts directory.

    Args:
        prompt_name: File name of the prompt, resolved relative to
            ``PROMPTS_DIR``.

    Returns:
        The full text content of the prompt file.

    Raises:
        FileNotFoundError: If ``PROMPTS_DIR / prompt_name`` does not exist.
    """
    prompt_path = PROMPTS_DIR / prompt_name
    if not prompt_path.exists():
        raise FileNotFoundError(f"System prompt not found: {prompt_path}")
    # Decode explicitly as UTF-8 so the loaded text does not depend on the
    # platform's locale encoding.
    return prompt_path.read_text(encoding="utf-8")
class LanguageModel(nn.Module):
@@ -57,9 +70,11 @@ class LanguageModel(nn.Module):
def __call__(
self,
input_ids: mx.array,
inputs: mx.array,
input_embeddings: Optional[mx.array] = None,
attention_mask: Optional[mx.array] = None,
output_hidden_states: bool = True,
output_hidden_states: bool = False,
cache: Optional[List[mx.array]] = None,
) -> Tuple[mx.array, List[mx.array]]:
"""Forward pass returning hidden states.
@@ -71,10 +86,10 @@ class LanguageModel(nn.Module):
Returns:
Tuple of (final_hidden_states, list_of_all_hidden_states)
"""
batch_size, seq_len = input_ids.shape
batch_size, seq_len = inputs.shape
# Get embeddings
h = self.model.embed_tokens(input_ids)
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
# Apply Gemma scaling
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
@@ -83,11 +98,11 @@ class LanguageModel(nn.Module):
all_hidden_states = [h] if output_hidden_states else []
# Set up cache (all None for non-cached inference)
cache = [None] * len(self.model.layers)
if cache is None:
cache = [None] * len(self.model.layers)
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
sliding_mask = full_causal_mask
@@ -104,7 +119,7 @@ class LanguageModel(nn.Module):
else:
local_mask = sliding_mask
h = layer(h, local_mask, None)
h = layer(h, local_mask, cache[i])
mx.eval(h)
if output_hidden_states and i < num_layers - 1:
@@ -118,7 +133,11 @@ class LanguageModel(nn.Module):
if output_hidden_states:
all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states
return hidden_states, all_hidden_states
else:
# Return logits
return self.model.embed_tokens.as_linear(hidden_states)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
prefix = "language_model."
@@ -130,6 +149,24 @@ class LanguageModel(nn.Module):
else:
sanitized[key[len(prefix):]] = value
return sanitized
@property
def layers(self) -> List[nn.Module]:
    """Expose the layer list of the wrapped ``self.model``."""
    wrapped_model = self.model
    return wrapped_model.layers
def make_cache(self):
    """Build one cache object per layer.

    A layer whose index satisfies
    ``index % sliding_window_pattern == sliding_window_pattern - 1`` gets a
    ``KVCache``; every other layer gets a ``RotatingKVCache`` whose
    ``max_size`` is ``config.sliding_window``.

    Returns:
        List of cache objects, in layer order.
    """
    from mlx_vlm.models.cache import KVCache, RotatingKVCache

    def _cache_for(layer_index):
        # Config values are read lazily, per layer, exactly when needed.
        pattern = self.config.sliding_window_pattern
        if layer_index % pattern == pattern - 1:
            return KVCache()
        return RotatingKVCache(max_size=self.config.sliding_window)

    return [_cache_for(layer_index) for layer_index in range(len(self.layers))]
@classmethod
def from_pretrained(cls, model_path: str):
import json
@@ -149,7 +186,7 @@ class LanguageModel(nn.Module):
for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights)
@@ -601,7 +638,7 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left"
@@ -623,6 +660,177 @@ class LTX2TextEncoder(nn.Module):
) -> Tuple[mx.array, mx.array]:
return self.encode(prompt, max_length)
@functools.cached_property
def default_t2v_system_prompt(self) -> str:
    """Default system prompt for text-to-video prompt enhancement.

    Read from ``gemma_t2v_system_prompt.txt`` through
    ``_load_system_prompt``; ``cached_property`` keeps the result after the
    first access.
    """
    prompt_file = "gemma_t2v_system_prompt.txt"
    return _load_system_prompt(prompt_file)
@functools.cached_property
def default_i2v_system_prompt(self) -> str:
    """Default system prompt for image-to-video prompt enhancement.

    Read from ``gemma_i2v_system_prompt.txt`` through
    ``_load_system_prompt``; ``cached_property`` keeps the result after the
    first access.
    """
    prompt_file = "gemma_i2v_system_prompt.txt"
    return _load_system_prompt(prompt_file)
def _clean_response(self, response: str) -> str:
    """Clean up the generated response.

    Strips surrounding whitespace, then removes any leading run(s) of
    punctuation together with the whitespace that follows them, so that
    list markers or stray quotes (e.g. ``"- text"``, ``"... - text"``) do
    not leave a leading space or a second marker behind.

    Args:
        response: Raw decoded model output.

    Returns:
        The cleaned text; may be empty if the input held only whitespace
        and punctuation.
    """
    # Remove leading/trailing whitespace
    response = response.strip()
    # Remove leading punctuation groups and the whitespace trailing each one
    return re.sub(r'^(?:[^\w\s]+\s*)+', '', response)
def _apply_chat_template(
    self,
    messages: List[Dict[str, str]],
) -> str:
    """Render chat messages into Gemma turn markup.

    Each message is a dict with ``"role"`` and ``"content"`` keys:

    * ``system`` messages are emitted as a ``user`` turn.
    * ``user`` messages are emitted as a ``user`` turn; when the content is
      a list, only the ``"text"`` of entries whose ``"type"`` is ``"text"``
      is kept, joined with single spaces. Content that is neither ``str``
      nor ``list`` produces no turn.
    * ``assistant`` messages are emitted as a ``model`` turn.
    * Any other role is skipped.

    The result always ends with an opening ``model`` turn so generation
    continues from there.
    """
    turns = []
    for message in messages:
        speaker = message["role"]
        body = message["content"]

        if speaker == "assistant":
            turns.append(f"<start_of_turn>model\n{body}<end_of_turn>\n")
            continue
        if speaker == "system":
            turns.append(f"<start_of_turn>user\n{body}<end_of_turn>\n")
            continue
        if speaker != "user":
            continue

        if isinstance(body, list):
            # Multimodal content (image + text): keep only the text parts.
            text_only = " ".join(
                [part["text"] for part in body if part.get("type") == "text"]
            )
            turns.append(f"<start_of_turn>user\n{text_only}<end_of_turn>\n")
        elif isinstance(body, str):
            turns.append(f"<start_of_turn>user\n{body}<end_of_turn>\n")

    # Generation prompt: leave a model turn open.
    turns.append("<start_of_turn>model\n")
    return "".join(turns)
def enhance_t2v(
    self,
    prompt: str,
    max_tokens: int = 512,
    system_prompt: Optional[str] = None,
    seed: int = 42,
    verbose: bool = True,
    **kwargs,
) -> str:
    """Enhance a text prompt for T2V generation using mlx-lm.

    Args:
        prompt: The original user prompt.
        max_tokens: Maximum number of tokens to generate.
        system_prompt: Optional custom system prompt; falls back to
            ``self.default_t2v_system_prompt`` when empty or None.
        seed: Random seed passed to ``mx.random.seed`` before sampling.
        verbose: When True, show a tqdm progress bar during generation.
        **kwargs: Optional sampling overrides, looked up by name with these
            defaults: ``temperature`` (0.7), ``top_p`` (0.95), ``top_k``
            (-1), ``logit_bias`` (None), ``repetition_penalty`` (1.3),
            ``repetition_context_size`` (20). Other keys are ignored.

    Returns:
        Enhanced prompt string, or ``prompt`` unchanged when mlx-lm cannot
        be imported.

    Raises:
        RuntimeError: If ``self.processor`` is None.
    """
    # Probe the optional dependency first so that a missing mlx-lm always
    # degrades to the unmodified prompt.
    try:
        from mlx_lm import stream_generate
        from mlx_lm.sample_utils import make_logits_processors, make_sampler
    except ImportError:
        logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
        return prompt

    from tqdm import tqdm

    if self.processor is None:
        raise RuntimeError("Model not loaded. Call load() first.")

    system_prompt = system_prompt or self.default_t2v_system_prompt

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"user prompt: {prompt}"},
    ]

    # Apply chat template
    formatted = self._apply_chat_template(messages)

    # Seed before sampling so a given seed reproduces the same output.
    mx.random.seed(seed)

    # Tokenize
    inputs = self.processor(
        formatted,
        return_tensors="np",
        add_special_tokens=False,
    )
    input_ids = mx.array(inputs["input_ids"])

    sampler = make_sampler(
        kwargs.get("temperature", 0.7),
        kwargs.get("top_p", 0.95),
        top_k=kwargs.get("top_k", -1),
    )
    logits_processors = make_logits_processors(
        kwargs.get("logit_bias", None),
        kwargs.get("repetition_penalty", 1.3),
        kwargs.get("repetition_context_size", 20),
    )

    # NOTE(review): hard-coded stop ids kept as-is; presumably the
    # tokenizer's EOS / end-of-turn ids — confirm against the Gemma 3
    # tokenizer actually loaded.
    stop_token_ids = (1, 107)

    token_stream = stream_generate(
        self.language_model,
        tokenizer=self.processor,
        prompt=input_ids.squeeze(0),
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
    )

    # Collect plain Python ints: that is all ``decode`` needs, and it avoids
    # re-concatenating the whole sequence on every step.
    generated_tokens = []
    for step, response in enumerate(
        tqdm(token_stream, total=max_tokens, disable=not verbose)
    ):
        generated_tokens.append(int(response.token))

        # Periodically release cached MLX buffers during long generations.
        if step % 50 == 0:
            mx.clear_cache()

        if response.token in stop_token_ids:
            break

    # Decode only the new tokens
    enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
    enhanced_prompt = self._clean_response(enhanced_prompt)
    logging.info(f"Enhanced prompt: {enhanced_prompt}")

    return enhanced_prompt
def enhance_i2v(
    self,
    prompt: str,
    image: "Optional[mx.array]" = None,
    max_new_tokens: int = 512,
    system_prompt: Optional[str] = None,
    seed: int = 42,
) -> str:
    """Enhance a text prompt for I2V generation.

    Delegates to ``enhance_t2v`` using the I2V system prompt.

    Args:
        prompt: The original user prompt.
        image: Optional image tensor (not currently used).
        max_new_tokens: Maximum number of tokens to generate; forwarded to
            ``enhance_t2v`` as its ``max_tokens`` argument.
        system_prompt: Optional custom system prompt; falls back to
            ``self.default_i2v_system_prompt`` when empty or None.
        seed: Random seed for generation.

    Returns:
        Enhanced prompt string.
    """
    # ``enhance_t2v`` names its limit ``max_tokens``. Passing
    # ``max_new_tokens=...`` would be swallowed by its ``**kwargs`` and the
    # caller's limit silently ignored, so map the name explicitly.
    return self.enhance_t2v(
        prompt,
        max_tokens=max_new_tokens,
        system_prompt=system_prompt or self.default_i2v_system_prompt,
        seed=seed,
    )
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder(model_path=model_path)