Add prompt enhancement feature to video generation
- Introduced `enhance_prompt`, `max_tokens`, and `temperature` parameters in the `generate_video` function for improved prompt handling.
- Implemented prompt enhancement logic using the new `enhance_t2v` method in the text encoder.
- Added command-line arguments for prompt enhancement options.
- Created new system prompt files for T2V and I2V generation to guide the enhancement process.
This commit is contained in:
@@ -160,6 +160,9 @@ def generate_video(
|
||||
output_path: str = "output.mp4",
|
||||
save_frames: bool = False,
|
||||
verbose: bool = True,
|
||||
enhance_prompt: bool = False,
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 0.7,
|
||||
):
|
||||
"""Generate video from text prompt.
|
||||
|
||||
@@ -206,6 +209,12 @@ def generate_video(
|
||||
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||
mx.eval(text_encoder.parameters())
|
||||
|
||||
# Optionally enhance the prompt
|
||||
if enhance_prompt:
|
||||
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
|
||||
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||
|
||||
text_embeddings, _ = text_encoder(prompt)
|
||||
mx.eval(text_embeddings)
|
||||
|
||||
@@ -373,7 +382,7 @@ Examples:
|
||||
help="Frames per second for output video (default: 24)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
"--output-path",
|
||||
type=str,
|
||||
default="output.mp4",
|
||||
help="Output video path (default: output.mp4)"
|
||||
@@ -400,20 +409,27 @@ Examples:
|
||||
action="store_true",
|
||||
help="Verbose output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enhance-prompt",
|
||||
action="store_true",
|
||||
help="Enhance the prompt using Gemma before generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum number of tokens to generate (default: 512)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Temperature for prompt enhancement (default: 0.7)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
model_repo=args.model_repo,
|
||||
text_encoder_repo=args.text_encoder_repo,
|
||||
prompt=args.prompt,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_frames=args.num_frames,
|
||||
seed=args.seed,
|
||||
fps=args.fps,
|
||||
output_path=args.output,
|
||||
save_frames=args.save_frames,
|
||||
verbose=args.verbose,
|
||||
**vars(args)
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user