Add example usage to README and enhance console output in generate.py with ANSI colors

This commit is contained in:
Prince Canuma
2026-01-12 16:45:09 +01:00
parent 28417fe126
commit 4f6fc8252c
4 changed files with 35 additions and 12 deletions

View File

@@ -6,6 +6,18 @@ import mlx.core as mx
import numpy as np
from PIL import Image
# ANSI color codes
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
@@ -163,8 +175,8 @@ def generate_video(
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
print(f"Generating {width}x{height} video with {num_frames} frames")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
# Get model path
model_path = get_model_path(model_repo)
@@ -177,7 +189,7 @@ def generate_video(
mx.random.seed(seed)
# Load text encoder
print("Loading text encoder...")
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
@@ -190,7 +202,7 @@ def generate_video(
mx.clear_cache()
# Load transformer
print("Loading transformer...")
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
@@ -216,7 +228,7 @@ def generate_video(
mx.eval(transformer.parameters())
# Stage 1: Generate at half resolution
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
print(f"{Colors.YELLOW}Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
@@ -227,7 +239,7 @@ def generate_video(
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
# Upsample latents
print("Upsampling latents 2x...")
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
@@ -243,7 +255,7 @@ def generate_video(
mx.clear_cache()
# Stage 2: Refine at full resolution
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
print(f"{Colors.YELLOW}Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
@@ -259,7 +271,7 @@ def generate_video(
mx.clear_cache()
# Decode to video
print("Decoding video...")
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
@@ -283,19 +295,19 @@ def generate_video(
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"Saved video to {output_path}")
print(f"{Colors.GREEN}Saved video to {output_path}{Colors.RESET}")
except Exception as e:
print(f"Could not save video: {e}")
print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}")
print(f"{Colors.GREEN}Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}")
elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
return video_np