Add example usage to README and enhance console output in generate.py with ANSI colors
This commit is contained in:
1
LTX-2
Submodule
1
LTX-2
Submodule
Submodule LTX-2 added at 628956009c
10
README.md
10
README.md
@@ -115,6 +115,16 @@ mlx_video/
|
||||
└── video_vae/ # VAE encoder/decoder
|
||||
```
|
||||
|
||||
# Examples
|
||||
|
||||
Here's an example result generated by MLX-Video:
|
||||
|
||||
```sh
|
||||
uv run mlx_video.generate --prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" -n 100 --width 768
|
||||
```
|
||||
|
||||
<video src="https://github.com/Blaizzy/mlx-video/raw/main/examples/poodles.mp4" controls width="512"></video>
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
BIN
examples/poodles.mp4
Normal file
BIN
examples/poodles.mp4
Normal file
Binary file not shown.
@@ -6,6 +6,18 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# ANSI color codes
|
||||
class Colors:
|
||||
CYAN = "\033[96m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
MAGENTA = "\033[95m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
@@ -163,8 +175,8 @@ def generate_video(
|
||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||
|
||||
print(f"Generating {width}x{height} video with {num_frames} frames")
|
||||
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||
|
||||
# Get model path
|
||||
model_path = get_model_path(model_repo)
|
||||
@@ -177,7 +189,7 @@ def generate_video(
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Load text encoder
|
||||
print("Loading text encoder...")
|
||||
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
||||
text_encoder.load(str(model_path))
|
||||
@@ -190,7 +202,7 @@ def generate_video(
|
||||
mx.clear_cache()
|
||||
|
||||
# Load transformer
|
||||
print("Loading transformer...")
|
||||
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
|
||||
@@ -216,7 +228,7 @@ def generate_video(
|
||||
mx.eval(transformer.parameters())
|
||||
|
||||
# Stage 1: Generate at half resolution
|
||||
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
|
||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||
mx.random.seed(seed)
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||
mx.eval(latents)
|
||||
@@ -227,7 +239,7 @@ def generate_video(
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||
|
||||
# Upsample latents
|
||||
print("Upsampling latents 2x...")
|
||||
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
@@ -243,7 +255,7 @@ def generate_video(
|
||||
mx.clear_cache()
|
||||
|
||||
# Stage 2: Refine at full resolution
|
||||
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
|
||||
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
@@ -259,7 +271,7 @@ def generate_video(
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode to video
|
||||
print("Decoding video...")
|
||||
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
mx.clear_cache()
|
||||
@@ -283,19 +295,19 @@ def generate_video(
|
||||
for frame in video_np:
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
print(f"Saved video to {output_path}")
|
||||
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}")
|
||||
except Exception as e:
|
||||
print(f"Could not save video: {e}")
|
||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||
|
||||
if save_frames:
|
||||
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||
frames_dir.mkdir(exist_ok=True)
|
||||
for i, frame in enumerate(video_np):
|
||||
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
||||
print(f"Saved {len(video_np)} frames to {frames_dir}")
|
||||
print(f"{Colors.GREEN}✅ Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
|
||||
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
||||
|
||||
return video_np
|
||||
|
||||
|
||||
Reference in New Issue
Block a user