diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 530aa13..4c78bb3 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -5,6 +5,7 @@ from pathlib import Path import mlx.core as mx import numpy as np from PIL import Image +from tqdm import tqdm # ANSI color codes class Colors: @@ -113,9 +114,10 @@ def denoise( text_embeddings: mx.array, transformer: LTXModel, sigmas: list, + verbose: bool = True, ) -> mx.array: """Run denoising loop.""" - for i in range(len(sigmas) - 1): + for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose): sigma, sigma_next = sigmas[i], sigmas[i + 1] b, c, f, h, w = latents.shape @@ -148,6 +150,7 @@ def denoise( def generate_video( model_repo: str, + text_encoder_repo: str, prompt: str, height: int = 512, width: int = 512, @@ -156,6 +159,10 @@ def generate_video( fps: int = 24, output_path: str = "output.mp4", save_frames: bool = False, + verbose: bool = True, + enhance_prompt: bool = False, + max_tokens: int = 512, + temperature: float = 0.7, ): """Generate video from text prompt. @@ -174,12 +181,19 @@ def generate_video( # Validate dimensions assert height % 64 == 0, f"Height must be divisible by 64, got {height}" assert width % 64 == 0, f"Width must be divisible by 64, got {width}" + + if num_frames % 8 != 1: + adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 + print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") + num_frames = adjusted_num_frames + print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' 
if len(prompt) > 80 else ''}{Colors.RESET}") # Get model path model_path = get_model_path(model_repo) + text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) # Calculate latent dimensions stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 @@ -191,10 +205,16 @@ def generate_video( # Load text encoder print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder(model_path=str(model_path)) - text_encoder.load(str(model_path)) + text_encoder = LTX2TextEncoder() + text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) mx.eval(text_encoder.parameters()) + # Optionally enhance the prompt + if enhance_prompt: + print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}") + prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) + print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' 
if len(prompt) > 150 else ''}{Colors.RESET}") + text_embeddings, _ = text_encoder(prompt) mx.eval(text_embeddings) @@ -236,7 +256,7 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS) + latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose) # Upsample latents print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") @@ -265,7 +285,7 @@ def generate_video( latents = noise * noise_scale + latents * (1 - noise_scale) mx.eval(latents) - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS) + latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose) del transformer mx.clear_cache() @@ -295,7 +315,7 @@ def generate_video( for frame in video_np: out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.release() - print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}") + print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") except Exception as e: print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") @@ -308,6 +328,7 @@ def generate_video( elapsed = time.time() - start_time print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! 
Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") + print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") return video_np @@ -361,7 +382,7 @@ Examples: help="Frames per second for output video (default: 24)" ) parser.add_argument( - "--output", "-o", + "--output-path", type=str, default="output.mp4", help="Output video path (default: output.mp4)" @@ -377,18 +398,38 @@ Examples: default="Lightricks/LTX-2", help="Model repository to use (default: Lightricks/LTX-2)" ) + parser.add_argument( + "--text-encoder-repo", + type=str, + default=None, + help="Text encoder repository to use (default: None)" + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Verbose output" + ) + parser.add_argument( + "--enhance-prompt", + action="store_true", + help="Enhance the prompt using Gemma before generation" + ) + parser.add_argument( + "--max-tokens", + type=int, + default=512, + help="Maximum number of tokens to generate (default: 512)" + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Temperature for prompt enhancement (default: 0.7)" + ) args = parser.parse_args() generate_video( - model_repo=args.model_repo, - prompt=args.prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - seed=args.seed, - fps=args.fps, - output_path=args.output, - save_frames=args.save_frames, + **vars(args) ) diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index 2b4bec2..dea4089 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -384,8 +384,9 @@ class LTXModel(nn.Module): video_config = config.get_video_config() audio_config = config.get_audio_config() - self.transformer_blocks = [ - BasicAVTransformerBlock( + + self.transformer_blocks = { + idx: BasicAVTransformerBlock( idx=idx, video=video_config, audio=audio_config, @@ -393,7 +394,7 @@ class LTXModel(nn.Module): norm_eps=config.norm_eps, ) for 
idx in range(config.num_layers) - ] + } def _process_transformer_blocks( self, @@ -401,7 +402,7 @@ class LTXModel(nn.Module): audio: Optional[TransformerArgs], ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: """Process through all transformer blocks.""" - for block in self.transformer_blocks: + for block in self.transformer_blocks.values(): video, audio = block(video=video, audio=audio) return video, audio diff --git a/mlx_video/models/ltx/prompts/gemma_i2v_system_prompt.txt b/mlx_video/models/ltx/prompts/gemma_i2v_system_prompt.txt new file mode 100644 index 0000000..0d67724 --- /dev/null +++ b/mlx_video/models/ltx/prompts/gemma_i2v_system_prompt.txt @@ -0,0 +1,30 @@ +You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image. + +#### Guidelines: +- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood. +- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene). +- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts. +- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements. +- Chronological flow: Use temporal connectors ("as," "then," "while"). +- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound"). 
+- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (e.g., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.") +- Style: Include visual style at beginning: "Style: