diff --git a/mlx_video/generate.py b/mlx_video/generate.py index a9dcf85..5b618b7 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -7,20 +7,13 @@ from typing import Optional import mlx.core as mx import numpy as np from PIL import Image -from tqdm import tqdm +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn +from rich.panel import Panel +from rich.status import Status - -# ANSI color codes -class Colors: - CYAN = "\033[96m" - BLUE = "\033[94m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - RED = "\033[91m" - MAGENTA = "\033[95m" - BOLD = "\033[1m" - DIM = "\033[2m" - RESET = "\033[0m" +# Rich console for styled output +console = Console() from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType @@ -198,91 +191,106 @@ def denoise( if state is not None: latents = state.latent - desc = "Denoising A/V" if enable_audio else "Denoising" - for i in tqdm(range(len(sigmas) - 1), desc=desc, disable=not verbose): - sigma, sigma_next = sigmas[i], sigmas[i + 1] + desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]" + num_steps = len(sigmas) - 1 - b, c, f, h, w = latents.shape - num_tokens = f * h * w - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + console=console, + disable=not verbose, + ) as progress: + task = progress.add_task(desc, total=num_steps) - # Compute per-token timesteps - # For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1) - if state is not None: - # Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens) - denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) - denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) - denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) - # Per-token timesteps: sigma * mask (preserve dtype) - timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat - else: - # All tokens get the same timestep (use latent dtype) - timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) + for i in range(num_steps): + sigma, sigma_next = sigmas[i], sigmas[i + 1] - video_modality = Modality( - latent=latents_flat, - timesteps=timesteps, - positions=positions, - context=text_embeddings, - context_mask=None, - enabled=True, - ) + b, c, f, h, w = latents.shape + num_tokens = f * h * w + latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) - # Prepare audio modality if enabled - audio_modality = None - if enable_audio: - ab, ac, at, af = audio_latents.shape - audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) - audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + # Compute per-token timesteps + # For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1) + if state is not None: + # Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens) + denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) + denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) + # Per-token timesteps: sigma * mask (preserve dtype) + timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat + else: + # All tokens get the same timestep (use latent dtype) + timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) - audio_modality = Modality( - latent=audio_flat, - timesteps=mx.full((ab, at), sigma, dtype=dtype), - positions=audio_positions, - context=audio_embeddings, + video_modality = Modality( + latent=latents_flat, + timesteps=timesteps, + positions=positions, + context=text_embeddings, context_mask=None, enabled=True, ) - velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) - mx.eval(velocity) - if audio_velocity is not None: - mx.eval(audio_velocity) + # Prepare audio modality if enabled + audio_modality = None + if enable_audio: + ab, ac, at, af = audio_latents.shape + audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) - velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) - denoised = to_denoised(latents, velocity, sigma) + audio_modality = Modality( + latent=audio_flat, + timesteps=mx.full((ab, at), sigma, dtype=dtype), + positions=audio_positions, + context=audio_embeddings, + context_mask=None, + enabled=True, + ) - # Handle audio velocity if enabled - audio_denoised = None - if enable_audio and audio_velocity is not None: - ab, ac, at, af = audio_latents.shape - audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) - audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) - audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) + mx.eval(velocity) + if audio_velocity is not None: + mx.eval(audio_velocity) - # Apply conditioning mask if state is provided - if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) + velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) + denoised = to_denoised(latents, velocity, sigma) - mx.eval(denoised) - if audio_denoised is not None: - mx.eval(audio_denoised) + # Handle audio velocity if enabled + audio_denoised = None + if enable_audio and audio_velocity is not None: + ab, ac, at, af = audio_latents.shape + audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) + audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) + audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) - # Euler step (preserve dtype by converting Python floats to arrays) - if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr - if enable_audio and audio_denoised is not None: - audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr - else: - latents = denoised - if enable_audio and audio_denoised is not None: - audio_latents = audio_denoised + # Apply conditioning mask if state is provided + if state is not None: + denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) - mx.eval(latents) - if enable_audio: - mx.eval(audio_latents) + mx.eval(denoised) + if audio_denoised is not None: + mx.eval(audio_denoised) + + # Euler step (preserve dtype by converting Python floats to arrays) + if sigma_next > 0: + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr + if enable_audio and audio_denoised is not None: + audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr + else: + latents = denoised + if enable_audio and audio_denoised is not None: + audio_latents = audio_denoised + + mx.eval(latents) + if enable_audio: + mx.eval(audio_latents) + + progress.advance(task) return latents, audio_latents if enable_audio else None @@ -380,10 +388,10 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): subprocess.run(cmd, check=True, capture_output=True) return True except subprocess.CalledProcessError as e: - print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}") + console.print(f"[red]FFmpeg error: {e.stderr.decode()}[/]") return False except FileNotFoundError: - print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}") + console.print("[red]FFmpeg not found. Please install ffmpeg.[/]") return False @@ -451,7 +459,7 @@ def generate_video( if num_frames % 8 != 1: adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 - print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") + console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}[/]") num_frames = adjusted_num_frames is_i2v = image is not None @@ -459,16 +467,18 @@ def generate_video( if audio: mode_str += "+Audio" - print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") - print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") + # Display header panel + header = f"[bold cyan]🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames[/]" + console.print(Panel(header, expand=False)) + console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") if is_i2v: - print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") + console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") # Calculate audio frames if enabled audio_frames = None if audio: audio_frames = compute_audio_frames(num_frames, fps) - print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") + console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]") # Get model path model_path = get_model_path(model_repo) @@ -482,17 +492,18 @@ def generate_video( mx.random.seed(seed) # Load text encoder - print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") - from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder() - text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) - mx.eval(text_encoder.parameters()) + with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"): + from mlx_video.models.ltx.text_encoder import LTX2TextEncoder + text_encoder = LTX2TextEncoder() + text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) + mx.eval(text_encoder.parameters()) + console.print("[green]✓[/] Text encoder loaded") # Optionally enhance the prompt if enhance_prompt: - print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}") - prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) - print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") + with console.status("[magenta]✨ Enhancing prompt...[/]", spinner="dots"): + prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) + console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") # Get embeddings - with audio if enabled if audio: @@ -509,75 +520,76 @@ def generate_video( mx.clear_cache() # Load transformer - print(f"{Colors.BLUE}🤖 Loading transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}") - raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) - sanitized = sanitize_transformer_weights(raw_weights) - # Convert transformer weights to bfloat16 for memory efficiency - sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} + transformer_desc = "🤖 Loading transformer (A/V mode)..." if audio else "🤖 Loading transformer..." + with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"): + raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) + sanitized = sanitize_transformer_weights(raw_weights) + # Convert transformer weights to bfloat16 for memory efficiency + sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} - # Configure model type based on audio flag - model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly + # Configure model type based on audio flag + model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly - config_kwargs = dict( - model_type=model_type, - num_attention_heads=32, - attention_head_dim=128, - in_channels=128, - out_channels=128, - num_layers=48, - cross_attention_dim=4096, - caption_channels=3840, - rope_type=LTXRopeType.SPLIT, - double_precision_rope=True, - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, - timestep_scale_multiplier=1000, - ) - - if audio: - config_kwargs.update( - audio_num_attention_heads=32, - audio_attention_head_dim=64, - audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 - audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, - audio_cross_attention_dim=2048, - audio_positional_embedding_max_pos=[20], + config_kwargs = dict( + model_type=model_type, + num_attention_heads=32, + attention_head_dim=128, + in_channels=128, + out_channels=128, + num_layers=48, + cross_attention_dim=4096, + caption_channels=3840, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=True, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + timestep_scale_multiplier=1000, ) - config = LTXModelConfig(**config_kwargs) + if audio: + config_kwargs.update( + audio_num_attention_heads=32, + audio_attention_head_dim=64, + audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 + audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, + audio_cross_attention_dim=2048, + audio_positional_embedding_max_pos=[20], + ) - transformer = LTXModel(config) - transformer.load_weights(list(sanitized.items()), strict=False) - mx.eval(transformer.parameters()) + config = LTXModelConfig(**config_kwargs) + + transformer = LTXModel(config) + transformer.load_weights(list(sanitized.items()), strict=False) + mx.eval(transformer.parameters()) + console.print("[green]✓[/] Transformer loaded") # Load VAE encoder and encode image for I2V conditioning stage1_image_latent = None stage2_image_latent = None if is_i2v: - print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}") - vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors')) - mx.eval(vae_encoder.parameters()) + with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): + vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors')) + mx.eval(vae_encoder.parameters()) - # Load and prepare image for stage 1 (half resolution) - input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) - stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) - stage1_image_latent = vae_encoder(stage1_image_tensor) - mx.eval(stage1_image_latent) - print(f" Stage 1 image latent: {stage1_image_latent.shape}") + # Load and prepare image for stage 1 (half resolution) + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) + stage1_image_latent = vae_encoder(stage1_image_tensor) + mx.eval(stage1_image_latent) - # Load and prepare image for stage 2 (full resolution) - input_image = load_image(image, height=height, width=width, dtype=model_dtype) - stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) - stage2_image_latent = vae_encoder(stage2_image_tensor) - mx.eval(stage2_image_latent) - print(f" Stage 2 image latent: {stage2_image_latent.shape}") + # Load and prepare image for stage 2 (full resolution) + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) + stage2_image_latent = vae_encoder(stage2_image_tensor) + mx.eval(stage2_image_latent) - del vae_encoder - mx.clear_cache() + del vae_encoder + mx.clear_cache() + console.print("[green]✓[/] VAE encoder loaded and image encoded") # Stage 1: Generate at half resolution - print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") + console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)") mx.random.seed(seed) # Position grids stay float32 for RoPE precision @@ -636,23 +648,24 @@ def generate_video( ) # Upsample latents - print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") - upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) - mx.eval(upsampler.parameters()) + with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): + upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) + mx.eval(upsampler.parameters()) - vae_decoder = load_vae_decoder( - str(model_path / 'ltx-2-19b-distilled.safetensors'), - timestep_conditioning=None # Auto-detect from model metadata - ) + vae_decoder = load_vae_decoder( + str(model_path / 'ltx-2-19b-distilled.safetensors'), + timestep_conditioning=None # Auto-detect from model metadata + ) - latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) - mx.eval(latents) + latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) + mx.eval(latents) - del upsampler - mx.clear_cache() + del upsampler + mx.clear_cache() + console.print("[green]✓[/] Latents upsampled") # Stage 2: Refine at full resolution - print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") + console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)") # Position grids stay float32 for RoPE precision positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) @@ -717,7 +730,7 @@ def generate_video( mx.clear_cache() # Decode to video with tiling - print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") + console.print("\n[blue]🎞️ Decoding video...[/]") # Select tiling configuration if tiling == "none": @@ -735,7 +748,7 @@ def generate_video( elif tiling == "temporal": tiling_config = TilingConfig.temporal_only() else: - print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") + console.print(f"[yellow] Unknown tiling mode '{tiling}', using auto[/]") tiling_config = TilingConfig.auto(height, width, num_frames) # Save outputs @@ -744,13 +757,21 @@ def generate_video( # Stream mode: write frames as they're decoded video_writer = None - stream_pbar = None + stream_progress = None if stream and tiling_config is not None: import cv2 fourcc = cv2.VideoWriter_fourcc(*'avc1') video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) - stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame") + stream_progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + ) + stream_progress.start() + stream_task = stream_progress.add_task("[cyan]Streaming frames[/]", total=num_frames) def on_frames_ready(frames: mx.array, _start_idx: int): """Callback to write frames as they're finalized.""" @@ -763,17 +784,17 @@ def generate_video( for frame in frames_np: video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - stream_pbar.update(1) + stream_progress.advance(stream_task) else: on_frames_ready = None if tiling_config is not None: spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" - print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") + console.print(f"[dim] Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}[/]") video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready) else: - print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") + console.print("[dim] Tiling: disabled[/]") video = vae_decoder(latents) mx.eval(video) mx.clear_cache() @@ -781,9 +802,9 @@ def generate_video( # Close progressive video writer if used if video_writer is not None: video_writer.release() - if stream_pbar is not None: - stream_pbar.close() - print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}") + if stream_progress is not None: + stream_progress.stop() + console.print(f"[green]✅ Streamed video to[/] {output_path}") # Still need video_np for save_frames option video = mx.squeeze(video, axis=0) video = mx.transpose(video, (1, 2, 3, 0)) @@ -815,45 +836,47 @@ def generate_video( out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.release() if not audio: - print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") + console.print(f"[green]✅ Saved video to[/] {output_path}") except Exception as e: - print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") + console.print(f"[red]❌ Could not save video: {e}[/]") # Decode and save audio if enabled audio_np = None if audio and audio_latents is not None: - print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}") - audio_decoder = load_audio_decoder(model_path) - vocoder = load_vocoder(model_path) - mx.eval(audio_decoder.parameters(), vocoder.parameters()) + with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"): + audio_decoder = load_audio_decoder(model_path) + vocoder = load_vocoder(model_path) + mx.eval(audio_decoder.parameters(), vocoder.parameters()) - mel_spectrogram = audio_decoder(audio_latents) - mx.eval(mel_spectrogram) + mel_spectrogram = audio_decoder(audio_latents) + mx.eval(mel_spectrogram) - audio_waveform = vocoder(mel_spectrogram) - mx.eval(audio_waveform) + audio_waveform = vocoder(mel_spectrogram) + mx.eval(audio_waveform) - audio_np = np.array(audio_waveform) - if audio_np.ndim == 3: - audio_np = audio_np[0] + audio_np = np.array(audio_waveform) + if audio_np.ndim == 3: + audio_np = audio_np[0] - del audio_decoder, vocoder - mx.clear_cache() + del audio_decoder, vocoder + mx.clear_cache() + console.print("[green]✓[/] Audio decoded") # Save audio audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) - print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}") + console.print(f"[green]✅ Saved audio to[/] {audio_path}") # Mux video and audio - print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}") - temp_video_path = output_path.with_suffix('.temp.mp4') - if mux_video_audio(temp_video_path, audio_path, output_path): - print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}") + with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"): + temp_video_path = output_path.with_suffix('.temp.mp4') + success = mux_video_audio(temp_video_path, audio_path, output_path) + if success: + console.print(f"[green]✅ Saved video with audio to[/] {output_path}") temp_video_path.unlink() else: temp_video_path.rename(output_path) - print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}") + console.print(f"[yellow]⚠️ Saved video without audio to[/] {output_path}") del vae_decoder mx.clear_cache() @@ -863,11 +886,14 @@ def generate_video( frames_dir.mkdir(exist_ok=True) for i, frame in enumerate(video_np): Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png") - print(f"{Colors.GREEN}✅ Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}") + console.print(f"[green]✅ Saved {len(video_np)} frames to {frames_dir}[/]") elapsed = time.time() - start_time - print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") - print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") + console.print(Panel( + f"[bold green]🎉 Done![/] Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)\n" + f"[bold green]✨ Peak memory:[/] {mx.get_peak_memory() / (1024 ** 3):.2f}GB", + expand=False + )) if audio: return video_np, audio_np diff --git a/pyproject.toml b/pyproject.toml index d9bf2f4..7c10195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,8 @@ dependencies = [ "tqdm", "opencv-python>=4.12.0.88", "Pillow>=10.3.0", - "mlx-vlm" + "mlx-vlm", + "rich>=14.2.0", ] license = {text="MIT"} authors = [ @@ -52,4 +53,4 @@ version = {attr = "mlx_video.version.__version__"} [project.optional-dependencies] dev = [ "pytest", -] \ No newline at end of file +] diff --git a/uv.lock b/uv.lock index ec2a5dd..65e21f1 100644 --- a/uv.lock +++ b/uv.lock @@ -635,6 +635,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -709,6 +721,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mlx" version = "0.30.1" @@ -777,6 +798,7 @@ dependencies = [ { name = "numpy" }, { name = "opencv-python" }, { name = "pillow" }, + { name = "rich" }, { name = "safetensors" }, { name = "tqdm" }, { name = "transformers", extra = ["tokenizers"] }, @@ -796,6 +818,7 @@ requires-dist = [ { name = "opencv-python", specifier = ">=4.12.0.88" }, { name = "pillow", specifier = ">=10.3.0" }, { name = "pytest", marker = "extra == 'dev'" }, + { name = "rich", specifier = ">=14.2.0" }, { name = "safetensors" }, { name = "tqdm" }, { name = "transformers", extras = ["tokenizers"] }, @@ -1679,6 +1702,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "14.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" }, +] + [[package]] name = "safetensors" version = "0.7.0"