Add prompt enhancement feature to video generation
- Introduced `enhance_prompt`, `max_tokens`, and `temperature` parameters in `generate_video` function for improved prompt handling. - Implemented prompt enhancement logic using the new `enhance_t2v` method in the text encoder. - Added command-line arguments for prompt enhancement options. - Created new system prompt files for T2V and I2V generation to guide the enhancement process.
This commit is contained in:
@@ -160,6 +160,9 @@ def generate_video(
|
|||||||
output_path: str = "output.mp4",
|
output_path: str = "output.mp4",
|
||||||
save_frames: bool = False,
|
save_frames: bool = False,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
|
enhance_prompt: bool = False,
|
||||||
|
max_tokens: int = 512,
|
||||||
|
temperature: float = 0.7,
|
||||||
):
|
):
|
||||||
"""Generate video from text prompt.
|
"""Generate video from text prompt.
|
||||||
|
|
||||||
@@ -206,6 +209,12 @@ def generate_video(
|
|||||||
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||||
mx.eval(text_encoder.parameters())
|
mx.eval(text_encoder.parameters())
|
||||||
|
|
||||||
|
# Optionally enhance the prompt
|
||||||
|
if enhance_prompt:
|
||||||
|
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
|
||||||
|
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||||
|
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
text_embeddings, _ = text_encoder(prompt)
|
text_embeddings, _ = text_encoder(prompt)
|
||||||
mx.eval(text_embeddings)
|
mx.eval(text_embeddings)
|
||||||
|
|
||||||
@@ -373,7 +382,7 @@ Examples:
|
|||||||
help="Frames per second for output video (default: 24)"
|
help="Frames per second for output video (default: 24)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output", "-o",
|
"--output-path",
|
||||||
type=str,
|
type=str,
|
||||||
default="output.mp4",
|
default="output.mp4",
|
||||||
help="Output video path (default: output.mp4)"
|
help="Output video path (default: output.mp4)"
|
||||||
@@ -400,20 +409,27 @@ Examples:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Verbose output"
|
help="Verbose output"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enhance-prompt",
|
||||||
|
action="store_true",
|
||||||
|
help="Enhance the prompt using Gemma before generation"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Maximum number of tokens to generate (default: 512)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.7,
|
||||||
|
help="Temperature for prompt enhancement (default: 0.7)"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
generate_video(
|
generate_video(
|
||||||
model_repo=args.model_repo,
|
**vars(args)
|
||||||
text_encoder_repo=args.text_encoder_repo,
|
|
||||||
prompt=args.prompt,
|
|
||||||
height=args.height,
|
|
||||||
width=args.width,
|
|
||||||
num_frames=args.num_frames,
|
|
||||||
seed=args.seed,
|
|
||||||
fps=args.fps,
|
|
||||||
output_path=args.output,
|
|
||||||
save_frames=args.save_frames,
|
|
||||||
verbose=args.verbose,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
30
mlx_video/models/ltx/prompts/gemma_i2v_system_prompt.txt
Normal file
30
mlx_video/models/ltx/prompts/gemma_i2v_system_prompt.txt
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
|
||||||
|
|
||||||
|
#### Guidelines:
|
||||||
|
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
|
||||||
|
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
|
||||||
|
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
|
||||||
|
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||||
|
- Chronological flow: Use temporal connectors ("as," "then," "while").
|
||||||
|
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
|
||||||
|
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
|
||||||
|
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
|
||||||
|
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
|
||||||
|
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
|
||||||
|
|
||||||
|
#### Important notes:
|
||||||
|
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
|
||||||
|
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
|
||||||
|
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||||
|
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
|
||||||
|
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
|
||||||
|
- Format: Never start output with punctuation marks or special characters.
|
||||||
|
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||||
|
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
|
||||||
|
|
||||||
|
#### Output Format (Strict):
|
||||||
|
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
|
||||||
|
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||||
|
|
||||||
|
#### Example output:
|
||||||
|
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
|
||||||
40
mlx_video/models/ltx/prompts/gemma_t2v_system_prompt.txt
Normal file
40
mlx_video/models/ltx/prompts/gemma_t2v_system_prompt.txt
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||||
|
|
||||||
|
#### Guidelines
|
||||||
|
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
|
||||||
|
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
|
||||||
|
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
|
||||||
|
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||||
|
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
|
||||||
|
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
|
||||||
|
- Speech (only when requested):
|
||||||
|
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
|
||||||
|
- Specify language if not English and accent if relevant.
|
||||||
|
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
|
||||||
|
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
|
||||||
|
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
|
||||||
|
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
|
||||||
|
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
|
||||||
|
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
|
||||||
|
|
||||||
|
#### Important notes:
|
||||||
|
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
|
||||||
|
- Camera motion: DO NOT invent camera motion unless requested by the user.
|
||||||
|
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
|
||||||
|
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||||
|
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
|
||||||
|
- Format: DO NOT start your response with special characters.
|
||||||
|
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||||
|
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
|
||||||
|
|
||||||
|
#### Output Format (Strict):
|
||||||
|
- Single continuous paragraph in natural language (English).
|
||||||
|
- NO titles, headings, prefaces, code fences, or Markdown.
|
||||||
|
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||||
|
|
||||||
|
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
Input: "A woman at a coffee shop talking on the phone"
|
||||||
|
Output:
|
||||||
|
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.
|
||||||
@@ -1,16 +1,16 @@
|
|||||||
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
|
|
||||||
|
|
||||||
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
|
|
||||||
with 0.999 correlation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from mlx_video.utils import rms_norm, apply_quantization
|
from mlx_video.utils import rms_norm, apply_quantization
|
||||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||||
@@ -19,6 +19,19 @@ from mlx_vlm.models.gemma3.language import Gemma3Model
|
|||||||
from mlx_vlm.models.gemma3.config import TextConfig
|
from mlx_vlm.models.gemma3.config import TextConfig
|
||||||
|
|
||||||
|
|
||||||
|
# Path to system prompts
|
||||||
|
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_system_prompt(prompt_name: str) -> str:
|
||||||
|
"""Load a system prompt from the prompts directory."""
|
||||||
|
prompt_path = PROMPTS_DIR / prompt_name
|
||||||
|
if prompt_path.exists():
|
||||||
|
with open(prompt_path, "r") as f:
|
||||||
|
return f.read()
|
||||||
|
raise FileNotFoundError(f"System prompt not found: {prompt_path}")
|
||||||
|
|
||||||
|
|
||||||
class LanguageModel(nn.Module):
|
class LanguageModel(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@@ -57,9 +70,11 @@ class LanguageModel(nn.Module):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids: mx.array,
|
inputs: mx.array,
|
||||||
|
input_embeddings: Optional[mx.array] = None,
|
||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
output_hidden_states: bool = True,
|
output_hidden_states: bool = False,
|
||||||
|
cache: Optional[List[mx.array]] = None,
|
||||||
) -> Tuple[mx.array, List[mx.array]]:
|
) -> Tuple[mx.array, List[mx.array]]:
|
||||||
"""Forward pass returning hidden states.
|
"""Forward pass returning hidden states.
|
||||||
|
|
||||||
@@ -71,10 +86,10 @@ class LanguageModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (final_hidden_states, list_of_all_hidden_states)
|
Tuple of (final_hidden_states, list_of_all_hidden_states)
|
||||||
"""
|
"""
|
||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = inputs.shape
|
||||||
|
|
||||||
# Get embeddings
|
# Get embeddings
|
||||||
h = self.model.embed_tokens(input_ids)
|
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
|
||||||
|
|
||||||
# Apply Gemma scaling
|
# Apply Gemma scaling
|
||||||
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||||
@@ -83,11 +98,11 @@ class LanguageModel(nn.Module):
|
|||||||
all_hidden_states = [h] if output_hidden_states else []
|
all_hidden_states = [h] if output_hidden_states else []
|
||||||
|
|
||||||
# Set up cache (all None for non-cached inference)
|
# Set up cache (all None for non-cached inference)
|
||||||
cache = [None] * len(self.model.layers)
|
if cache is None:
|
||||||
|
cache = [None] * len(self.model.layers)
|
||||||
|
|
||||||
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||||
|
|
||||||
|
|
||||||
sliding_mask = full_causal_mask
|
sliding_mask = full_causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -104,7 +119,7 @@ class LanguageModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
local_mask = sliding_mask
|
local_mask = sliding_mask
|
||||||
|
|
||||||
h = layer(h, local_mask, None)
|
h = layer(h, local_mask, cache[i])
|
||||||
mx.eval(h)
|
mx.eval(h)
|
||||||
|
|
||||||
if output_hidden_states and i < num_layers - 1:
|
if output_hidden_states and i < num_layers - 1:
|
||||||
@@ -118,7 +133,11 @@ class LanguageModel(nn.Module):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states.append(hidden_states)
|
all_hidden_states.append(hidden_states)
|
||||||
|
|
||||||
return hidden_states, all_hidden_states
|
return hidden_states, all_hidden_states
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Return logits
|
||||||
|
return self.model.embed_tokens.as_linear(hidden_states)
|
||||||
|
|
||||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
prefix = "language_model."
|
prefix = "language_model."
|
||||||
@@ -130,6 +149,24 @@ class LanguageModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
sanitized[key[len(prefix):]] = value
|
sanitized[key[len(prefix):]] = value
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self) -> List[nn.Module]:
|
||||||
|
return self.model.layers
|
||||||
|
|
||||||
|
def make_cache(self):
|
||||||
|
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
||||||
|
caches = []
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if (
|
||||||
|
i % self.config.sliding_window_pattern
|
||||||
|
== self.config.sliding_window_pattern - 1
|
||||||
|
):
|
||||||
|
caches.append(KVCache())
|
||||||
|
else:
|
||||||
|
caches.append(RotatingKVCache(max_size=self.config.sliding_window))
|
||||||
|
return caches
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: str):
|
def from_pretrained(cls, model_path: str):
|
||||||
import json
|
import json
|
||||||
@@ -601,7 +638,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
attention_mask = mx.array(inputs["attention_mask"])
|
attention_mask = mx.array(inputs["attention_mask"])
|
||||||
|
|
||||||
_, all_hidden_states = self.language_model(input_ids, attention_mask, output_hidden_states=True)
|
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||||
|
|
||||||
concat_hidden = norm_and_concat_hidden_states(
|
concat_hidden = norm_and_concat_hidden_states(
|
||||||
all_hidden_states, attention_mask, padding_side="left"
|
all_hidden_states, attention_mask, padding_side="left"
|
||||||
@@ -623,6 +660,177 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
return self.encode(prompt, max_length)
|
return self.encode(prompt, max_length)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def default_t2v_system_prompt(self) -> str:
|
||||||
|
"""Load the default T2V system prompt."""
|
||||||
|
return _load_system_prompt("gemma_t2v_system_prompt.txt")
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def default_i2v_system_prompt(self) -> str:
|
||||||
|
"""Load the default I2V system prompt."""
|
||||||
|
return _load_system_prompt("gemma_i2v_system_prompt.txt")
|
||||||
|
|
||||||
|
def _clean_response(self, response: str) -> str:
|
||||||
|
"""Clean up the generated response."""
|
||||||
|
# Remove leading/trailing whitespace
|
||||||
|
response = response.strip()
|
||||||
|
# Remove any leading punctuation
|
||||||
|
response = re.sub(r'^[^\w\s]+', '', response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _apply_chat_template(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Apply Gemma chat template to messages."""
|
||||||
|
# Gemma 3 chat template format
|
||||||
|
formatted = ""
|
||||||
|
for msg in messages:
|
||||||
|
role = msg["role"]
|
||||||
|
content = msg["content"]
|
||||||
|
if role == "system":
|
||||||
|
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
|
||||||
|
elif role == "user":
|
||||||
|
if isinstance(content, str):
|
||||||
|
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
|
||||||
|
elif isinstance(content, list):
|
||||||
|
# Handle multimodal content (image + text)
|
||||||
|
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
||||||
|
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||||
|
elif role == "assistant":
|
||||||
|
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
||||||
|
# Add generation prompt
|
||||||
|
formatted += "<start_of_turn>model\n"
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
def enhance_t2v(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: int = 512,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
seed: int = 42,
|
||||||
|
verbose: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""Enhance a text prompt for T2V generation using mlx-lm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The original user prompt
|
||||||
|
max_new_tokens: Maximum number of tokens to generate
|
||||||
|
system_prompt: Optional custom system prompt
|
||||||
|
seed: Random seed for generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enhanced prompt string
|
||||||
|
"""
|
||||||
|
from tqdm import tqdm
|
||||||
|
try:
|
||||||
|
from mlx_lm import stream_generate
|
||||||
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
|
except ImportError:
|
||||||
|
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
if self.processor is None:
|
||||||
|
raise RuntimeError("Model not loaded. Call load() first.")
|
||||||
|
|
||||||
|
system_prompt = system_prompt or self.default_t2v_system_prompt
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": f"user prompt: {prompt}"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply chat template
|
||||||
|
formatted = self._apply_chat_template(messages)
|
||||||
|
|
||||||
|
# Use mlx-lm generate with temperature sampling
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
# Tokenize
|
||||||
|
inputs = self.processor(
|
||||||
|
formatted,
|
||||||
|
return_tensors="np",
|
||||||
|
add_special_tokens=False,
|
||||||
|
)
|
||||||
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
|
|
||||||
|
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 0.95), top_k=kwargs.get("top_k", -1))
|
||||||
|
logits_processors = make_logits_processors(
|
||||||
|
kwargs.get("logit_bias", None),
|
||||||
|
kwargs.get("repetition_penalty", 1.3),
|
||||||
|
kwargs.get("repetition_context_size", 20),
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_token_count = 0
|
||||||
|
generated_tokens = []
|
||||||
|
for i, response in enumerate(
|
||||||
|
tqdm(
|
||||||
|
stream_generate(
|
||||||
|
self.language_model,
|
||||||
|
tokenizer=self.processor,
|
||||||
|
prompt=input_ids.squeeze(0),
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
sampler=sampler,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
),
|
||||||
|
total=max_tokens,
|
||||||
|
disable=not verbose,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
next_token = mx.array([response.token])
|
||||||
|
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
|
||||||
|
generated_tokens.append(next_token.squeeze())
|
||||||
|
generated_token_count += 1
|
||||||
|
|
||||||
|
if i % 50 == 0:
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Check for EOS
|
||||||
|
if response.token == 1 or response.token == 107: # EOS tokens
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Decode only the new tokens
|
||||||
|
|
||||||
|
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||||
|
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
||||||
|
|
||||||
|
return enhanced_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def enhance_i2v(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image: Optional[mx.array] = None,
|
||||||
|
max_new_tokens: int = 512,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
seed: int = 42,
|
||||||
|
) -> str:
|
||||||
|
"""Enhance a text prompt for I2V generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The original user prompt
|
||||||
|
image: Optional image tensor (not currently used)
|
||||||
|
max_new_tokens: Maximum number of tokens to generate
|
||||||
|
system_prompt: Optional custom system prompt
|
||||||
|
seed: Random seed for generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enhanced prompt string
|
||||||
|
"""
|
||||||
|
# Use T2V enhancement with I2V system prompt
|
||||||
|
return self.enhance_t2v(
|
||||||
|
prompt,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
system_prompt=system_prompt or self.default_i2v_system_prompt,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||||
encoder = LTX2TextEncoder(model_path=model_path)
|
encoder = LTX2TextEncoder(model_path=model_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user