diff --git a/LTX-2 b/LTX-2 new file mode 160000 index 0000000..6289560 --- /dev/null +++ b/LTX-2 @@ -0,0 +1 @@ +Subproject commit 628956009ca14446f01976ba4e861e6a9e210a93 diff --git a/README.md b/README.md index bfd8ead..42654b6 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,16 @@ mlx_video/ └── video_vae/ # VAE encoder/decoder ``` +# Examples + +Here's an example result generated by MLX-Video: + +```sh +uv run mlx_video.generate --prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" -n 100 --width 768 +``` + + + ## License MIT \ No newline at end of file diff --git a/examples/poodles.mp4 b/examples/poodles.mp4 new file mode 100644 index 0000000..26a306f Binary files /dev/null and b/examples/poodles.mp4 differ diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 0a945cc..530aa13 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -6,6 +6,18 @@ import mlx.core as mx import numpy as np from PIL import Image +# ANSI color codes +class Colors: + CYAN = "\033[96m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + MAGENTA = "\033[95m" + BOLD = "\033[1m" + DIM = "\033[2m" + RESET = "\033[0m" + from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.transformer import Modality @@ -163,8 +175,8 @@ def generate_video( assert height % 64 == 0, f"Height must be divisible by 64, got {height}" assert width % 64 == 0, f"Width must be divisible by 64, got {width}" - print(f"Generating {width}x{height} video with {num_frames} frames") - print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}") + print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") + print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") # Get model path model_path = get_model_path(model_repo) @@ -177,7 +189,7 @@ def generate_video( mx.random.seed(seed) # Load text encoder - print("Loading text encoder...") + print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") from mlx_video.models.ltx.text_encoder import LTX2TextEncoder text_encoder = LTX2TextEncoder(model_path=str(model_path)) text_encoder.load(str(model_path)) @@ -190,7 +202,7 @@ def generate_video( mx.clear_cache() # Load transformer - print("Loading transformer...") + print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}") raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) @@ -216,7 +228,7 @@ def generate_video( mx.eval(transformer.parameters()) # Stage 1: Generate at half resolution - print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...") + print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) mx.eval(latents) @@ -227,7 +239,7 @@ def generate_video( latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS) # Upsample latents - print("Upsampling latents 2x...") + print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) mx.eval(upsampler.parameters()) @@ -243,7 +255,7 @@ def generate_video( mx.clear_cache() # Stage 2: Refine at full resolution - print(f"Stage 2: Refining at {width}x{height} (3 steps)...") + print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) @@ -259,7 +271,7 @@ def generate_video( mx.clear_cache() # Decode to video - print("Decoding video...") + print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") video = vae_decoder(latents) mx.eval(video) mx.clear_cache() @@ -283,19 +295,19 @@ def generate_video( for frame in video_np: out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.release() - print(f"Saved video to {output_path}") + print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}") except Exception as e: - print(f"Could not save video: {e}") + print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") if save_frames: frames_dir = output_path.parent / f"{output_path.stem}_frames" frames_dir.mkdir(exist_ok=True) for i, frame in enumerate(video_np): Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png") - print(f"Saved {len(video_np)} frames to {frames_dir}") + print(f"{Colors.GREEN}✅ Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}") elapsed = time.time() - start_time - print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)") + print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") return video_np