Enhance the video generation pipeline by integrating Rich for styled console output and progress tracking. Update the dependencies in pyproject.toml to include Rich (>=14.2.0). Replace the tqdm progress bars with Rich progress displays and status spinners, and refactor the ANSI-colored print statements to use Rich console methods, for an improved user experience during video and audio processing.

This commit is contained in:
Prince Canuma
2026-01-19 01:43:14 +01:00
parent cae11291a9
commit 0538af6554
3 changed files with 263 additions and 200 deletions

View File

@@ -7,20 +7,13 @@ from typing import Optional
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from rich.panel import Panel
from rich.status import Status
# Rich console for styled output
# ANSI color codes console = Console()
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
@@ -198,91 +191,106 @@ def denoise(
if state is not None: if state is not None:
latents = state.latent latents = state.latent
desc = "Denoising A/V" if enable_audio else "Denoising" desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
for i in tqdm(range(len(sigmas) - 1), desc=desc, disable=not verbose): num_steps = len(sigmas) - 1
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape with Progress(
num_tokens = f * h * w SpinnerColumn(),
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console,
disable=not verbose,
) as progress:
task = progress.add_task(desc, total=num_steps)
# Compute per-token timesteps for i in range(num_steps):
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1) sigma, sigma_next = sigmas[i], sigmas[i + 1]
if state is not None:
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask (preserve dtype)
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
# All tokens get the same timestep (use latent dtype)
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
video_modality = Modality( b, c, f, h, w = latents.shape
latent=latents_flat, num_tokens = f * h * w
timesteps=timesteps, latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
positions=positions,
context=text_embeddings,
context_mask=None,
enabled=True,
)
# Prepare audio modality if enabled # Compute per-token timesteps
audio_modality = None # For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
if enable_audio: if state is not None:
ab, ac, at, af = audio_latents.shape # Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask (preserve dtype)
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
# All tokens get the same timestep (use latent dtype)
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
audio_modality = Modality( video_modality = Modality(
latent=audio_flat, latent=latents_flat,
timesteps=mx.full((ab, at), sigma, dtype=dtype), timesteps=timesteps,
positions=audio_positions, positions=positions,
context=audio_embeddings, context=text_embeddings,
context_mask=None, context_mask=None,
enabled=True, enabled=True,
) )
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) # Prepare audio modality if enabled
mx.eval(velocity) audio_modality = None
if audio_velocity is not None: if enable_audio:
mx.eval(audio_velocity) ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) audio_modality = Modality(
denoised = to_denoised(latents, velocity, sigma) latent=audio_flat,
timesteps=mx.full((ab, at), sigma, dtype=dtype),
positions=audio_positions,
context=audio_embeddings,
context_mask=None,
enabled=True,
)
# Handle audio velocity if enabled velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
audio_denoised = None mx.eval(velocity)
if enable_audio and audio_velocity is not None: if audio_velocity is not None:
ab, ac, at, af = audio_latents.shape mx.eval(audio_velocity)
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
# Apply conditioning mask if state is provided velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
if state is not None: denoised = to_denoised(latents, velocity, sigma)
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
mx.eval(denoised) # Handle audio velocity if enabled
if audio_denoised is not None: audio_denoised = None
mx.eval(audio_denoised) if enable_audio and audio_velocity is not None:
ab, ac, at, af = audio_latents.shape
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
# Euler step (preserve dtype by converting Python floats to arrays) # Apply conditioning mask if state is provided
if sigma_next > 0: if state is not None:
sigma_next_arr = mx.array(sigma_next, dtype=dtype) denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
sigma_arr = mx.array(sigma, dtype=dtype)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
if enable_audio and audio_denoised is not None:
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else:
latents = denoised
if enable_audio and audio_denoised is not None:
audio_latents = audio_denoised
mx.eval(latents) mx.eval(denoised)
if enable_audio: if audio_denoised is not None:
mx.eval(audio_latents) mx.eval(audio_denoised)
# Euler step (preserve dtype by converting Python floats to arrays)
if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
if enable_audio and audio_denoised is not None:
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else:
latents = denoised
if enable_audio and audio_denoised is not None:
audio_latents = audio_denoised
mx.eval(latents)
if enable_audio:
mx.eval(audio_latents)
progress.advance(task)
return latents, audio_latents if enable_audio else None return latents, audio_latents if enable_audio else None
@@ -380,10 +388,10 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
subprocess.run(cmd, check=True, capture_output=True) subprocess.run(cmd, check=True, capture_output=True)
return True return True
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}") console.print(f"[red]FFmpeg error: {e.stderr.decode()}[/]")
return False return False
except FileNotFoundError: except FileNotFoundError:
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}") console.print("[red]FFmpeg not found. Please install ffmpeg.[/]")
return False return False
@@ -451,7 +459,7 @@ def generate_video(
if num_frames % 8 != 1: if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}[/]")
num_frames = adjusted_num_frames num_frames = adjusted_num_frames
is_i2v = image is not None is_i2v = image is not None
@@ -459,16 +467,18 @@ def generate_video(
if audio: if audio:
mode_str += "+Audio" mode_str += "+Audio"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") # Display header panel
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") header = f"[bold cyan]🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames[/]"
console.print(Panel(header, expand=False))
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
if is_i2v: if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
# Calculate audio frames if enabled # Calculate audio frames if enabled
audio_frames = None audio_frames = None
if audio: if audio:
audio_frames = compute_audio_frames(num_frames, fps) audio_frames = compute_audio_frames(num_frames, fps)
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
# Get model path # Get model path
model_path = get_model_path(model_repo) model_path = get_model_path(model_repo)
@@ -482,17 +492,18 @@ def generate_video(
mx.random.seed(seed) mx.random.seed(seed)
# Load text encoder # Load text encoder
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"):
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder() text_encoder = LTX2TextEncoder()
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters()) mx.eval(text_encoder.parameters())
console.print("[green]✓[/] Text encoder loaded")
# Optionally enhance the prompt # Optionally enhance the prompt
if enhance_prompt: if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}") with console.status("[magenta]✨ Enhancing prompt...[/]", spinner="dots"):
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
# Get embeddings - with audio if enabled # Get embeddings - with audio if enabled
if audio: if audio:
@@ -509,75 +520,76 @@ def generate_video(
mx.clear_cache() mx.clear_cache()
# Load transformer # Load transformer
print(f"{Colors.BLUE}🤖 Loading transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}") transformer_desc = "🤖 Loading transformer (A/V mode)..." if audio else "🤖 Loading transformer..."
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
sanitized = sanitize_transformer_weights(raw_weights) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
# Convert transformer weights to bfloat16 for memory efficiency sanitized = sanitize_transformer_weights(raw_weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} # Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
# Configure model type based on audio flag # Configure model type based on audio flag
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
config_kwargs = dict( config_kwargs = dict(
model_type=model_type, model_type=model_type,
num_attention_heads=32, num_attention_heads=32,
attention_head_dim=128, attention_head_dim=128,
in_channels=128, in_channels=128,
out_channels=128, out_channels=128,
num_layers=48, num_layers=48,
cross_attention_dim=4096, cross_attention_dim=4096,
caption_channels=3840, caption_channels=3840,
rope_type=LTXRopeType.SPLIT, rope_type=LTXRopeType.SPLIT,
double_precision_rope=True, double_precision_rope=True,
positional_embedding_theta=10000.0, positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048], positional_embedding_max_pos=[20, 2048, 2048],
use_middle_indices_grid=True, use_middle_indices_grid=True,
timestep_scale_multiplier=1000, timestep_scale_multiplier=1000,
)
if audio:
config_kwargs.update(
audio_num_attention_heads=32,
audio_attention_head_dim=64,
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_cross_attention_dim=2048,
audio_positional_embedding_max_pos=[20],
) )
config = LTXModelConfig(**config_kwargs) if audio:
config_kwargs.update(
audio_num_attention_heads=32,
audio_attention_head_dim=64,
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_cross_attention_dim=2048,
audio_positional_embedding_max_pos=[20],
)
transformer = LTXModel(config) config = LTXModelConfig(**config_kwargs)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters()) transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
console.print("[green]✓[/] Transformer loaded")
# Load VAE encoder and encode image for I2V conditioning # Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None stage1_image_latent = None
stage2_image_latent = None stage2_image_latent = None
if is_i2v: if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}") with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors')) vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters()) mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution) # Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor) stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent) mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution) # Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width, dtype=model_dtype) input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor) stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent) mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder del vae_encoder
mx.clear_cache() mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Stage 1: Generate at half resolution # Stage 1: Generate at half resolution
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)")
mx.random.seed(seed) mx.random.seed(seed)
# Position grids stay float32 for RoPE precision # Position grids stay float32 for RoPE precision
@@ -636,23 +648,24 @@ def generate_video(
) )
# Upsample latents # Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters()) mx.eval(upsampler.parameters())
vae_decoder = load_vae_decoder( vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'), str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=None # Auto-detect from model metadata timestep_conditioning=None # Auto-detect from model metadata
) )
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(latents) mx.eval(latents)
del upsampler del upsampler
mx.clear_cache() mx.clear_cache()
console.print("[green]✓[/] Latents upsampled")
# Stage 2: Refine at full resolution # Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)")
# Position grids stay float32 for RoPE precision # Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
@@ -717,7 +730,7 @@ def generate_video(
mx.clear_cache() mx.clear_cache()
# Decode to video with tiling # Decode to video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") console.print("\n[blue]🎞️ Decoding video...[/]")
# Select tiling configuration # Select tiling configuration
if tiling == "none": if tiling == "none":
@@ -735,7 +748,7 @@ def generate_video(
elif tiling == "temporal": elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only() tiling_config = TilingConfig.temporal_only()
else: else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") console.print(f"[yellow] Unknown tiling mode '{tiling}', using auto[/]")
tiling_config = TilingConfig.auto(height, width, num_frames) tiling_config = TilingConfig.auto(height, width, num_frames)
# Save outputs # Save outputs
@@ -744,13 +757,21 @@ def generate_video(
# Stream mode: write frames as they're decoded # Stream mode: write frames as they're decoded
video_writer = None video_writer = None
stream_pbar = None stream_progress = None
if stream and tiling_config is not None: if stream and tiling_config is not None:
import cv2 import cv2
fourcc = cv2.VideoWriter_fourcc(*'avc1') fourcc = cv2.VideoWriter_fourcc(*'avc1')
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame") stream_progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
console=console,
)
stream_progress.start()
stream_task = stream_progress.add_task("[cyan]Streaming frames[/]", total=num_frames)
def on_frames_ready(frames: mx.array, _start_idx: int): def on_frames_ready(frames: mx.array, _start_idx: int):
"""Callback to write frames as they're finalized.""" """Callback to write frames as they're finalized."""
@@ -763,17 +784,17 @@ def generate_video(
for frame in frames_np: for frame in frames_np:
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
stream_pbar.update(1) stream_progress.advance(stream_task)
else: else:
on_frames_ready = None on_frames_ready = None
if tiling_config is not None: if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") console.print(f"[dim] Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}[/]")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready) video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
else: else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") console.print("[dim] Tiling: disabled[/]")
video = vae_decoder(latents) video = vae_decoder(latents)
mx.eval(video) mx.eval(video)
mx.clear_cache() mx.clear_cache()
@@ -781,9 +802,9 @@ def generate_video(
# Close progressive video writer if used # Close progressive video writer if used
if video_writer is not None: if video_writer is not None:
video_writer.release() video_writer.release()
if stream_pbar is not None: if stream_progress is not None:
stream_pbar.close() stream_progress.stop()
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}") console.print(f"[green]✅ Streamed video to[/] {output_path}")
# Still need video_np for save_frames option # Still need video_np for save_frames option
video = mx.squeeze(video, axis=0) video = mx.squeeze(video, axis=0)
video = mx.transpose(video, (1, 2, 3, 0)) video = mx.transpose(video, (1, 2, 3, 0))
@@ -815,45 +836,47 @@ def generate_video(
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release() out.release()
if not audio: if not audio:
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") console.print(f"[green]✅ Saved video to[/] {output_path}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") console.print(f"[red]❌ Could not save video: {e}[/]")
# Decode and save audio if enabled # Decode and save audio if enabled
audio_np = None audio_np = None
if audio and audio_latents is not None: if audio and audio_latents is not None:
print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}") with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path) audio_decoder = load_audio_decoder(model_path)
vocoder = load_vocoder(model_path) vocoder = load_vocoder(model_path)
mx.eval(audio_decoder.parameters(), vocoder.parameters()) mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents) mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram) mx.eval(mel_spectrogram)
audio_waveform = vocoder(mel_spectrogram) audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform) mx.eval(audio_waveform)
audio_np = np.array(audio_waveform) audio_np = np.array(audio_waveform)
if audio_np.ndim == 3: if audio_np.ndim == 3:
audio_np = audio_np[0] audio_np = audio_np[0]
del audio_decoder, vocoder del audio_decoder, vocoder
mx.clear_cache() mx.clear_cache()
console.print("[green]✓[/] Audio decoded")
# Save audio # Save audio
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}") console.print(f"[green]✅ Saved audio to[/] {audio_path}")
# Mux video and audio # Mux video and audio
print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}") with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"):
temp_video_path = output_path.with_suffix('.temp.mp4') temp_video_path = output_path.with_suffix('.temp.mp4')
if mux_video_audio(temp_video_path, audio_path, output_path): success = mux_video_audio(temp_video_path, audio_path, output_path)
print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}") if success:
console.print(f"[green]✅ Saved video with audio to[/] {output_path}")
temp_video_path.unlink() temp_video_path.unlink()
else: else:
temp_video_path.rename(output_path) temp_video_path.rename(output_path)
print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}") console.print(f"[yellow]⚠️ Saved video without audio to[/] {output_path}")
del vae_decoder del vae_decoder
mx.clear_cache() mx.clear_cache()
@@ -863,11 +886,14 @@ def generate_video(
frames_dir.mkdir(exist_ok=True) frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np): for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png") Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"{Colors.GREEN}✅ Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}") console.print(f"[green]✅ Saved {len(video_np)} frames to {frames_dir}[/]")
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") console.print(Panel(
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") f"[bold green]🎉 Done![/] Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)\n"
f"[bold green]✨ Peak memory:[/] {mx.get_peak_memory() / (1024 ** 3):.2f}GB",
expand=False
))
if audio: if audio:
return video_np, audio_np return video_np, audio_np

View File

@@ -19,7 +19,8 @@ dependencies = [
"tqdm", "tqdm",
"opencv-python>=4.12.0.88", "opencv-python>=4.12.0.88",
"Pillow>=10.3.0", "Pillow>=10.3.0",
"mlx-vlm" "mlx-vlm",
"rich>=14.2.0",
] ]
license = {text="MIT"} license = {text="MIT"}
authors = [ authors = [

36
uv.lock generated
View File

@@ -635,6 +635,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
] ]
[[package]]
name = "markdown-it-py"
version = "4.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mdurl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
]
[[package]] [[package]]
name = "markupsafe" name = "markupsafe"
version = "3.0.3" version = "3.0.3"
@@ -709,6 +721,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" },
] ]
[[package]]
name = "mdurl"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]
[[package]] [[package]]
name = "mlx" name = "mlx"
version = "0.30.1" version = "0.30.1"
@@ -777,6 +798,7 @@ dependencies = [
{ name = "numpy" }, { name = "numpy" },
{ name = "opencv-python" }, { name = "opencv-python" },
{ name = "pillow" }, { name = "pillow" },
{ name = "rich" },
{ name = "safetensors" }, { name = "safetensors" },
{ name = "tqdm" }, { name = "tqdm" },
{ name = "transformers", extra = ["tokenizers"] }, { name = "transformers", extra = ["tokenizers"] },
@@ -796,6 +818,7 @@ requires-dist = [
{ name = "opencv-python", specifier = ">=4.12.0.88" }, { name = "opencv-python", specifier = ">=4.12.0.88" },
{ name = "pillow", specifier = ">=10.3.0" }, { name = "pillow", specifier = ">=10.3.0" },
{ name = "pytest", marker = "extra == 'dev'" }, { name = "pytest", marker = "extra == 'dev'" },
{ name = "rich", specifier = ">=14.2.0" },
{ name = "safetensors" }, { name = "safetensors" },
{ name = "tqdm" }, { name = "tqdm" },
{ name = "transformers", extras = ["tokenizers"] }, { name = "transformers", extras = ["tokenizers"] },
@@ -1679,6 +1702,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
] ]
[[package]]
name = "rich"
version = "14.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markdown-it-py" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
]
[[package]] [[package]]
name = "safetensors" name = "safetensors"
version = "0.7.0" version = "0.7.0"