Enhance video generation pipeline by integrating Rich for styled console output and progress tracking. Update dependencies in pyproject.toml to include Rich. Refactor print statements to use console methods for improved user experience during video and audio processing.

This commit is contained in:
Prince Canuma
2026-01-19 01:43:14 +01:00
parent cae11291a9
commit 0538af6554
3 changed files with 263 additions and 200 deletions

View File

@@ -7,20 +7,13 @@ from typing import Optional
import mlx.core as mx
import numpy as np
from PIL import Image
from tqdm import tqdm
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from rich.panel import Panel
from rich.status import Status
# ANSI color codes
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
# Rich console for styled output
console = Console()
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
@@ -198,8 +191,21 @@ def denoise(
if state is not None:
latents = state.latent
desc = "Denoising A/V" if enable_audio else "Denoising"
for i in tqdm(range(len(sigmas) - 1), desc=desc, disable=not verbose):
desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
num_steps = len(sigmas) - 1
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console,
disable=not verbose,
) as progress:
task = progress.add_task(desc, total=num_steps)
for i in range(num_steps):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
@@ -284,6 +290,8 @@ def denoise(
if enable_audio:
mx.eval(audio_latents)
progress.advance(task)
return latents, audio_latents if enable_audio else None
@@ -380,10 +388,10 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
subprocess.run(cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
console.print(f"[red]FFmpeg error: {e.stderr.decode()}[/]")
return False
except FileNotFoundError:
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
console.print("[red]FFmpeg not found. Please install ffmpeg.[/]")
return False
@@ -451,7 +459,7 @@ def generate_video(
if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}[/]")
num_frames = adjusted_num_frames
is_i2v = image is not None
@@ -459,16 +467,18 @@ def generate_video(
if audio:
mode_str += "+Audio"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
# Display header panel
header = f"[bold cyan]🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames[/]"
console.print(Panel(header, expand=False))
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
# Calculate audio frames if enabled
audio_frames = None
if audio:
audio_frames = compute_audio_frames(num_frames, fps)
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
# Get model path
model_path = get_model_path(model_repo)
@@ -482,17 +492,18 @@ def generate_video(
mx.random.seed(seed)
# Load text encoder
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"):
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder()
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
console.print("[green]✓[/] Text encoder loaded")
# Optionally enhance the prompt
if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
with console.status("[magenta]✨ Enhancing prompt...[/]", spinner="dots"):
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
# Get embeddings - with audio if enabled
if audio:
@@ -509,7 +520,8 @@ def generate_video(
mx.clear_cache()
# Load transformer
print(f"{Colors.BLUE}🤖 Loading transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}")
transformer_desc = "🤖 Loading transformer (A/V mode)..." if audio else "🤖 Loading transformer..."
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
@@ -550,12 +562,13 @@ def generate_video(
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
console.print("[green]✓[/] Transformer loaded")
# Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters())
@@ -564,20 +577,19 @@ def generate_video(
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder
mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Stage 1: Generate at half resolution
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)")
mx.random.seed(seed)
# Position grids stay float32 for RoPE precision
@@ -636,7 +648,7 @@ def generate_video(
)
# Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
@@ -650,9 +662,10 @@ def generate_video(
del upsampler
mx.clear_cache()
console.print("[green]✓[/] Latents upsampled")
# Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)")
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
@@ -717,7 +730,7 @@ def generate_video(
mx.clear_cache()
# Decode to video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
console.print("\n[blue]🎞️ Decoding video...[/]")
# Select tiling configuration
if tiling == "none":
@@ -735,7 +748,7 @@ def generate_video(
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
console.print(f"[yellow] Unknown tiling mode '{tiling}', using auto[/]")
tiling_config = TilingConfig.auto(height, width, num_frames)
# Save outputs
@@ -744,13 +757,21 @@ def generate_video(
# Stream mode: write frames as they're decoded
video_writer = None
stream_pbar = None
stream_progress = None
if stream and tiling_config is not None:
import cv2
fourcc = cv2.VideoWriter_fourcc(*'avc1')
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
stream_progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=console,
)
stream_progress.start()
stream_task = stream_progress.add_task("[cyan]Streaming frames[/]", total=num_frames)
def on_frames_ready(frames: mx.array, _start_idx: int):
"""Callback to write frames as they're finalized."""
@@ -763,17 +784,17 @@ def generate_video(
for frame in frames_np:
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
stream_pbar.update(1)
stream_progress.advance(stream_task)
else:
on_frames_ready = None
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
console.print(f"[dim] Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}[/]")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
console.print("[dim] Tiling: disabled[/]")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
@@ -781,9 +802,9 @@ def generate_video(
# Close progressive video writer if used
if video_writer is not None:
video_writer.release()
if stream_pbar is not None:
stream_pbar.close()
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
if stream_progress is not None:
stream_progress.stop()
console.print(f"[green]✅ Streamed video to[/] {output_path}")
# Still need video_np for save_frames option
video = mx.squeeze(video, axis=0)
video = mx.transpose(video, (1, 2, 3, 0))
@@ -815,14 +836,14 @@ def generate_video(
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
if not audio:
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
console.print(f"[green]✅ Saved video to[/] {output_path}")
except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
console.print(f"[red]❌ Could not save video: {e}[/]")
# Decode and save audio if enabled
audio_np = None
if audio and audio_latents is not None:
print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}")
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path)
vocoder = load_vocoder(model_path)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
@@ -839,21 +860,23 @@ def generate_video(
del audio_decoder, vocoder
mx.clear_cache()
console.print("[green]✓[/] Audio decoded")
# Save audio
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}")
console.print(f"[green]✅ Saved audio to[/] {audio_path}")
# Mux video and audio
print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}")
with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"):
temp_video_path = output_path.with_suffix('.temp.mp4')
if mux_video_audio(temp_video_path, audio_path, output_path):
print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}")
success = mux_video_audio(temp_video_path, audio_path, output_path)
if success:
console.print(f"[green]✅ Saved video with audio to[/] {output_path}")
temp_video_path.unlink()
else:
temp_video_path.rename(output_path)
print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}")
console.print(f"[yellow]⚠️ Saved video without audio to[/] {output_path}")
del vae_decoder
mx.clear_cache()
@@ -863,11 +886,14 @@ def generate_video(
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"{Colors.GREEN}✅ Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}")
console.print(f"[green]✅ Saved {len(video_np)} frames to {frames_dir}[/]")
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
console.print(Panel(
f"[bold green]🎉 Done![/] Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)\n"
f"[bold green]✨ Peak memory:[/] {mx.get_peak_memory() / (1024 ** 3):.2f}GB",
expand=False
))
if audio:
return video_np, audio_np

View File

@@ -19,7 +19,8 @@ dependencies = [
"tqdm",
"opencv-python>=4.12.0.88",
"Pillow>=10.3.0",
"mlx-vlm"
"mlx-vlm",
"rich>=14.2.0",
]
license = {text="MIT"}
authors = [

36
uv.lock generated
View File

@@ -635,6 +635,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
]
[[package]]
name = "markdown-it-py"
version = "4.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mdurl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
]
[[package]]
name = "markupsafe"
version = "3.0.3"
@@ -709,6 +721,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" },
]
[[package]]
name = "mdurl"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]
[[package]]
name = "mlx"
version = "0.30.1"
@@ -777,6 +798,7 @@ dependencies = [
{ name = "numpy" },
{ name = "opencv-python" },
{ name = "pillow" },
{ name = "rich" },
{ name = "safetensors" },
{ name = "tqdm" },
{ name = "transformers", extra = ["tokenizers"] },
@@ -796,6 +818,7 @@ requires-dist = [
{ name = "opencv-python", specifier = ">=4.12.0.88" },
{ name = "pillow", specifier = ">=10.3.0" },
{ name = "pytest", marker = "extra == 'dev'" },
{ name = "rich", specifier = ">=14.2.0" },
{ name = "safetensors" },
{ name = "tqdm" },
{ name = "transformers", extras = ["tokenizers"] },
@@ -1679,6 +1702,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
]
[[package]]
name = "rich"
version = "14.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markdown-it-py" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
]
[[package]]
name = "safetensors"
version = "0.7.0"