Add example usage to README and enhance console output in generate.py with ANSI colors

This commit is contained in:
Prince Canuma
2026-01-12 16:45:09 +01:00
parent 28417fe126
commit 4f6fc8252c
4 changed files with 35 additions and 12 deletions

1
LTX-2 Submodule

Submodule LTX-2 added at 628956009c

View File

@@ -115,6 +115,16 @@ mlx_video/
└── video_vae/ # VAE encoder/decoder └── video_vae/ # VAE encoder/decoder
``` ```
# Examples
Here's an example result generated by MLX-Video:
```sh
uv run mlx_video.generate --prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" -n 100 --width 768
```
<video src="https://github.com/Blaizzy/mlx-video/raw/main/examples/poodles.mp4" controls width="512"></video>
## License ## License
MIT MIT

BIN
examples/poodles.mp4 Normal file

Binary file not shown.

View File

@@ -6,6 +6,18 @@ import mlx.core as mx
import numpy as np import numpy as np
from PIL import Image from PIL import Image
# ANSI color codes
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality from mlx_video.models.ltx.transformer import Modality
@@ -163,8 +175,8 @@ def generate_video(
assert height % 64 == 0, f"Height must be divisible by 64, got {height}" assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}" assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
print(f"Generating {width}x{height} video with {num_frames} frames") print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
# Get model path # Get model path
model_path = get_model_path(model_repo) model_path = get_model_path(model_repo)
@@ -177,7 +189,7 @@ def generate_video(
mx.random.seed(seed) mx.random.seed(seed)
# Load text encoder # Load text encoder
print("Loading text encoder...") print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path)) text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path)) text_encoder.load(str(model_path))
@@ -190,7 +202,7 @@ def generate_video(
mx.clear_cache() mx.clear_cache()
# Load transformer # Load transformer
print("Loading transformer...") print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
@@ -216,7 +228,7 @@ def generate_video(
mx.eval(transformer.parameters()) mx.eval(transformer.parameters())
# Stage 1: Generate at half resolution # Stage 1: Generate at half resolution
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...") print(f"{Colors.YELLOW}Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed) mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents) mx.eval(latents)
@@ -227,7 +239,7 @@ def generate_video(
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
# Upsample latents # Upsample latents
print("Upsampling latents 2x...") print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters()) mx.eval(upsampler.parameters())
@@ -243,7 +255,7 @@ def generate_video(
mx.clear_cache() mx.clear_cache()
# Stage 2: Refine at full resolution # Stage 2: Refine at full resolution
print(f"Stage 2: Refining at {width}x{height} (3 steps)...") print(f"{Colors.YELLOW}Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
@@ -259,7 +271,7 @@ def generate_video(
mx.clear_cache() mx.clear_cache()
# Decode to video # Decode to video
print("Decoding video...") print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
video = vae_decoder(latents) video = vae_decoder(latents)
mx.eval(video) mx.eval(video)
mx.clear_cache() mx.clear_cache()
@@ -283,19 +295,19 @@ def generate_video(
for frame in video_np: for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release() out.release()
print(f"Saved video to {output_path}") print(f"{Colors.GREEN}Saved video to {output_path}{Colors.RESET}")
except Exception as e: except Exception as e:
print(f"Could not save video: {e}") print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}")
if save_frames: if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames" frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True) frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np): for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png") Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}") print(f"{Colors.GREEN}Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}")
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)") print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
return video_np return video_np