Add prompt enhancement feature to video generation

- Introduced `enhance_prompt`, `max_tokens`, and `temperature` parameters in `generate_video` function for improved prompt handling.
- Implemented prompt enhancement logic using the new `enhance_t2v` method in the text encoder.
- Added command-line arguments for prompt enhancement options.
- Created new system prompt files for T2V and I2V generation to guide the enhancement process.
This commit is contained in:
Prince Canuma
2026-01-15 14:31:00 +01:00
parent f5134fa172
commit 81daf3f67d
4 changed files with 320 additions and 26 deletions

View File

@@ -160,6 +160,9 @@ def generate_video(
output_path: str = "output.mp4", output_path: str = "output.mp4",
save_frames: bool = False, save_frames: bool = False,
verbose: bool = True, verbose: bool = True,
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
): ):
"""Generate video from text prompt. """Generate video from text prompt.
@@ -206,6 +209,12 @@ def generate_video(
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters()) mx.eval(text_encoder.parameters())
# Optionally enhance the prompt
if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt) text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings) mx.eval(text_embeddings)
@@ -373,7 +382,7 @@ Examples:
help="Frames per second for output video (default: 24)" help="Frames per second for output video (default: 24)"
) )
parser.add_argument( parser.add_argument(
"--output", "-o", "--output-path",
type=str, type=str,
default="output.mp4", default="output.mp4",
help="Output video path (default: output.mp4)" help="Output video path (default: output.mp4)"
@@ -400,20 +409,27 @@ Examples:
action="store_true", action="store_true",
help="Verbose output" help="Verbose output"
) )
parser.add_argument(
"--enhance-prompt",
action="store_true",
help="Enhance the prompt using Gemma before generation"
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Maximum number of tokens to generate (default: 512)"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Temperature for prompt enhancement (default: 0.7)"
)
args = parser.parse_args() args = parser.parse_args()
generate_video( generate_video(
model_repo=args.model_repo, **vars(args)
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
verbose=args.verbose,
) )

View File

@@ -0,0 +1,30 @@
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
#### Example output:
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.

View File

@@ -0,0 +1,40 @@
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
#### Example
Input: "A woman at a coffee shop talking on the phone"
Output:
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.

View File

@@ -1,16 +1,16 @@
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
with 0.999 correlation.
"""
import functools
import logging
import math import math
import re
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np
from mlx_video.utils import rms_norm, apply_quantization from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@@ -19,6 +19,19 @@ from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig from mlx_vlm.models.gemma3.config import TextConfig
# Path to system prompts
PROMPTS_DIR = Path(__file__).parent / "prompts"
def _load_system_prompt(prompt_name: str) -> str:
"""Load a system prompt from the prompts directory."""
prompt_path = PROMPTS_DIR / prompt_name
if prompt_path.exists():
with open(prompt_path, "r") as f:
return f.read()
raise FileNotFoundError(f"System prompt not found: {prompt_path}")
class LanguageModel(nn.Module): class LanguageModel(nn.Module):
@@ -57,9 +70,11 @@ class LanguageModel(nn.Module):
def __call__( def __call__(
self, self,
input_ids: mx.array, inputs: mx.array,
input_embeddings: Optional[mx.array] = None,
attention_mask: Optional[mx.array] = None, attention_mask: Optional[mx.array] = None,
output_hidden_states: bool = True, output_hidden_states: bool = False,
cache: Optional[List[mx.array]] = None,
) -> Tuple[mx.array, List[mx.array]]: ) -> Tuple[mx.array, List[mx.array]]:
"""Forward pass returning hidden states. """Forward pass returning hidden states.
@@ -71,10 +86,10 @@ class LanguageModel(nn.Module):
Returns: Returns:
Tuple of (final_hidden_states, list_of_all_hidden_states) Tuple of (final_hidden_states, list_of_all_hidden_states)
""" """
batch_size, seq_len = input_ids.shape batch_size, seq_len = inputs.shape
# Get embeddings # Get embeddings
h = self.model.embed_tokens(input_ids) h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
# Apply Gemma scaling # Apply Gemma scaling
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype) h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
@@ -83,11 +98,11 @@ class LanguageModel(nn.Module):
all_hidden_states = [h] if output_hidden_states else [] all_hidden_states = [h] if output_hidden_states else []
# Set up cache (all None for non-cached inference) # Set up cache (all None for non-cached inference)
cache = [None] * len(self.model.layers) if cache is None:
cache = [None] * len(self.model.layers)
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype) full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
sliding_mask = full_causal_mask sliding_mask = full_causal_mask
@@ -104,7 +119,7 @@ class LanguageModel(nn.Module):
else: else:
local_mask = sliding_mask local_mask = sliding_mask
h = layer(h, local_mask, None) h = layer(h, local_mask, cache[i])
mx.eval(h) mx.eval(h)
if output_hidden_states and i < num_layers - 1: if output_hidden_states and i < num_layers - 1:
@@ -118,7 +133,11 @@ class LanguageModel(nn.Module):
if output_hidden_states: if output_hidden_states:
all_hidden_states.append(hidden_states) all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states return hidden_states, all_hidden_states
else:
# Return logits
return self.model.embed_tokens.as_linear(hidden_states)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
prefix = "language_model." prefix = "language_model."
@@ -130,6 +149,24 @@ class LanguageModel(nn.Module):
else: else:
sanitized[key[len(prefix):]] = value sanitized[key[len(prefix):]] = value
return sanitized return sanitized
@property
def layers(self) -> List[nn.Module]:
return self.model.layers
def make_cache(self):
from mlx_vlm.models.cache import KVCache, RotatingKVCache
caches = []
for i in range(len(self.layers)):
if (
i % self.config.sliding_window_pattern
== self.config.sliding_window_pattern - 1
):
caches.append(KVCache())
else:
caches.append(RotatingKVCache(max_size=self.config.sliding_window))
return caches
@classmethod @classmethod
def from_pretrained(cls, model_path: str): def from_pretrained(cls, model_path: str):
import json import json
@@ -149,7 +186,7 @@ class LanguageModel(nn.Module):
for i, wf in enumerate(weight_files): for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf))) weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"): if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights) weights = language_model.sanitize(weights=weights)
@@ -601,7 +638,7 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"]) input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"]) attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True) _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
concat_hidden = norm_and_concat_hidden_states( concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left" all_hidden_states, attention_mask, padding_side="left"
@@ -623,6 +660,177 @@ class LTX2TextEncoder(nn.Module):
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
return self.encode(prompt, max_length) return self.encode(prompt, max_length)
@functools.cached_property
def default_t2v_system_prompt(self) -> str:
"""Load the default T2V system prompt."""
return _load_system_prompt("gemma_t2v_system_prompt.txt")
@functools.cached_property
def default_i2v_system_prompt(self) -> str:
"""Load the default I2V system prompt."""
return _load_system_prompt("gemma_i2v_system_prompt.txt")
def _clean_response(self, response: str) -> str:
"""Clean up the generated response."""
# Remove leading/trailing whitespace
response = response.strip()
# Remove any leading punctuation
response = re.sub(r'^[^\w\s]+', '', response)
return response
def _apply_chat_template(
self,
messages: List[Dict[str, str]],
) -> str:
"""Apply Gemma chat template to messages."""
# Gemma 3 chat template format
formatted = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
elif role == "user":
if isinstance(content, str):
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
elif isinstance(content, list):
# Handle multimodal content (image + text)
text_parts = [c["text"] for c in content if c.get("type") == "text"]
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
elif role == "assistant":
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
# Add generation prompt
formatted += "<start_of_turn>model\n"
return formatted
def enhance_t2v(
self,
prompt: str,
max_tokens: int = 512,
system_prompt: Optional[str] = None,
seed: int = 42,
verbose: bool = True,
**kwargs,
) -> str:
"""Enhance a text prompt for T2V generation using mlx-lm.
Args:
prompt: The original user prompt
max_new_tokens: Maximum number of tokens to generate
system_prompt: Optional custom system prompt
seed: Random seed for generation
Returns:
Enhanced prompt string
"""
from tqdm import tqdm
try:
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler
except ImportError:
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
return prompt
if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.")
system_prompt = system_prompt or self.default_t2v_system_prompt
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"user prompt: {prompt}"},
]
# Apply chat template
formatted = self._apply_chat_template(messages)
# Use mlx-lm generate with temperature sampling
mx.random.seed(seed)
# Tokenize
inputs = self.processor(
formatted,
return_tensors="np",
add_special_tokens=False,
)
input_ids = mx.array(inputs["input_ids"])
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 0.95), top_k=kwargs.get("top_k", -1))
logits_processors = make_logits_processors(
kwargs.get("logit_bias", None),
kwargs.get("repetition_penalty", 1.3),
kwargs.get("repetition_context_size", 20),
)
generated_token_count = 0
generated_tokens = []
for i, response in enumerate(
tqdm(
stream_generate(
self.language_model,
tokenizer=self.processor,
prompt=input_ids.squeeze(0),
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
),
total=max_tokens,
disable=not verbose,
)
):
next_token = mx.array([response.token])
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
generated_tokens.append(next_token.squeeze())
generated_token_count += 1
if i % 50 == 0:
mx.clear_cache()
# Check for EOS
if response.token == 1 or response.token == 107: # EOS tokens
break
# Decode only the new tokens
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
enhanced_prompt = self._clean_response(enhanced_prompt)
logging.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
def enhance_i2v(
self,
prompt: str,
image: Optional[mx.array] = None,
max_new_tokens: int = 512,
system_prompt: Optional[str] = None,
seed: int = 42,
) -> str:
"""Enhance a text prompt for I2V generation.
Args:
prompt: The original user prompt
image: Optional image tensor (not currently used)
max_new_tokens: Maximum number of tokens to generate
system_prompt: Optional custom system prompt
seed: Random seed for generation
Returns:
Enhanced prompt string
"""
# Use T2V enhancement with I2V system prompt
return self.enhance_t2v(
prompt,
max_new_tokens=max_new_tokens,
system_prompt=system_prompt or self.default_i2v_system_prompt,
seed=seed,
)
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder: def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder(model_path=model_path) encoder = LTX2TextEncoder(model_path=model_path)