Add prompt enhancement feature to video generation

- Introduced `enhance_prompt`, `max_tokens`, and `temperature` parameters in `generate_video` function for improved prompt handling.
- Implemented prompt enhancement logic using the new `enhance_t2v` method in the text encoder.
- Added command-line arguments for prompt enhancement options.
- Created new system prompt files for T2V and I2V generation to guide the enhancement process.
This commit is contained in:
Prince Canuma
2026-01-15 14:31:00 +01:00
parent f5134fa172
commit 81daf3f67d
4 changed files with 320 additions and 26 deletions

View File

@@ -160,6 +160,9 @@ def generate_video(
output_path: str = "output.mp4",
save_frames: bool = False,
verbose: bool = True,
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
):
"""Generate video from text prompt.
@@ -206,6 +209,12 @@ def generate_video(
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
# Optionally enhance the prompt
if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings)
@@ -373,7 +382,7 @@ Examples:
help="Frames per second for output video (default: 24)"
)
parser.add_argument(
"--output", "-o",
"--output-path",
type=str,
default="output.mp4",
help="Output video path (default: output.mp4)"
@@ -400,20 +409,27 @@ Examples:
action="store_true",
help="Verbose output"
)
parser.add_argument(
"--enhance-prompt",
action="store_true",
help="Enhance the prompt using Gemma before generation"
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Maximum number of tokens to generate (default: 512)"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Temperature for prompt enhancement (default: 0.7)"
)
args = parser.parse_args()
generate_video(
model_repo=args.model_repo,
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
verbose=args.verbose,
**vars(args)
)