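Enhance the video generation pipeline by integrating Rich for styled console output and progress tracking. Replace the ANSI `Colors` class and raw `print` calls with a shared Rich `Console`, replace the tqdm progress bars (denoising loop and streamed frame writing) with Rich `Progress`, and wrap the long-running steps (text encoder, transformer and VAE encoder loading, prompt enhancement, latent upsampling, audio decoding, and video/audio muxing) in `console.status` spinners for an improved user experience during video and audio processing. Update the dependencies in pyproject.toml to include `rich>=14.2.0` and refresh uv.lock accordingly.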
This commit is contained in:
@@ -7,20 +7,13 @@ from typing import Optional
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||
from rich.panel import Panel
|
||||
from rich.status import Status
|
||||
|
||||
|
||||
# ANSI color codes
|
||||
class Colors:
|
||||
CYAN = "\033[96m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
MAGENTA = "\033[95m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
# Rich console for styled output
|
||||
console = Console()
|
||||
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
@@ -198,91 +191,106 @@ def denoise(
|
||||
if state is not None:
|
||||
latents = state.latent
|
||||
|
||||
desc = "Denoising A/V" if enable_audio else "Denoising"
|
||||
for i in tqdm(range(len(sigmas) - 1), desc=desc, disable=not verbose):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
|
||||
num_steps = len(sigmas) - 1
|
||||
|
||||
b, c, f, h, w = latents.shape
|
||||
num_tokens = f * h * w
|
||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=console,
|
||||
disable=not verbose,
|
||||
) as progress:
|
||||
task = progress.add_task(desc, total=num_steps)
|
||||
|
||||
# Compute per-token timesteps
|
||||
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
|
||||
if state is not None:
|
||||
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
|
||||
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
||||
# Per-token timesteps: sigma * mask (preserve dtype)
|
||||
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
# All tokens get the same timestep (use latent dtype)
|
||||
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
||||
for i in range(num_steps):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
positions=positions,
|
||||
context=text_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
b, c, f, h, w = latents.shape
|
||||
num_tokens = f * h * w
|
||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
|
||||
# Prepare audio modality if enabled
|
||||
audio_modality = None
|
||||
if enable_audio:
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
||||
# Compute per-token timesteps
|
||||
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
|
||||
if state is not None:
|
||||
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
|
||||
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
||||
# Per-token timesteps: sigma * mask (preserve dtype)
|
||||
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
# All tokens get the same timestep (use latent dtype)
|
||||
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
||||
|
||||
audio_modality = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=mx.full((ab, at), sigma, dtype=dtype),
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings,
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
positions=positions,
|
||||
context=text_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
|
||||
mx.eval(velocity)
|
||||
if audio_velocity is not None:
|
||||
mx.eval(audio_velocity)
|
||||
# Prepare audio modality if enabled
|
||||
audio_modality = None
|
||||
if enable_audio:
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
||||
|
||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||
denoised = to_denoised(latents, velocity, sigma)
|
||||
audio_modality = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=mx.full((ab, at), sigma, dtype=dtype),
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Handle audio velocity if enabled
|
||||
audio_denoised = None
|
||||
if enable_audio and audio_velocity is not None:
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
||||
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
|
||||
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
||||
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
|
||||
mx.eval(velocity)
|
||||
if audio_velocity is not None:
|
||||
mx.eval(audio_velocity)
|
||||
|
||||
# Apply conditioning mask if state is provided
|
||||
if state is not None:
|
||||
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
|
||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||
denoised = to_denoised(latents, velocity, sigma)
|
||||
|
||||
mx.eval(denoised)
|
||||
if audio_denoised is not None:
|
||||
mx.eval(audio_denoised)
|
||||
# Handle audio velocity if enabled
|
||||
audio_denoised = None
|
||||
if enable_audio and audio_velocity is not None:
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
||||
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
|
||||
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
||||
|
||||
# Euler step (preserve dtype by converting Python floats to arrays)
|
||||
if sigma_next > 0:
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
|
||||
if enable_audio and audio_denoised is not None:
|
||||
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
|
||||
else:
|
||||
latents = denoised
|
||||
if enable_audio and audio_denoised is not None:
|
||||
audio_latents = audio_denoised
|
||||
# Apply conditioning mask if state is provided
|
||||
if state is not None:
|
||||
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
|
||||
|
||||
mx.eval(latents)
|
||||
if enable_audio:
|
||||
mx.eval(audio_latents)
|
||||
mx.eval(denoised)
|
||||
if audio_denoised is not None:
|
||||
mx.eval(audio_denoised)
|
||||
|
||||
# Euler step (preserve dtype by converting Python floats to arrays)
|
||||
if sigma_next > 0:
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
|
||||
if enable_audio and audio_denoised is not None:
|
||||
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
|
||||
else:
|
||||
latents = denoised
|
||||
if enable_audio and audio_denoised is not None:
|
||||
audio_latents = audio_denoised
|
||||
|
||||
mx.eval(latents)
|
||||
if enable_audio:
|
||||
mx.eval(audio_latents)
|
||||
|
||||
progress.advance(task)
|
||||
|
||||
return latents, audio_latents if enable_audio else None
|
||||
|
||||
@@ -380,10 +388,10 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
|
||||
console.print(f"[red]FFmpeg error: {e.stderr.decode()}[/]")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
|
||||
console.print("[red]FFmpeg not found. Please install ffmpeg.[/]")
|
||||
return False
|
||||
|
||||
|
||||
@@ -451,7 +459,7 @@ def generate_video(
|
||||
|
||||
if num_frames % 8 != 1:
|
||||
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
|
||||
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
||||
console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}[/]")
|
||||
num_frames = adjusted_num_frames
|
||||
|
||||
is_i2v = image is not None
|
||||
@@ -459,16 +467,18 @@ def generate_video(
|
||||
if audio:
|
||||
mode_str += "+Audio"
|
||||
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||
# Display header panel
|
||||
header = f"[bold cyan]🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames[/]"
|
||||
console.print(Panel(header, expand=False))
|
||||
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
|
||||
if is_i2v:
|
||||
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
|
||||
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
|
||||
|
||||
# Calculate audio frames if enabled
|
||||
audio_frames = None
|
||||
if audio:
|
||||
audio_frames = compute_audio_frames(num_frames, fps)
|
||||
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
|
||||
console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
|
||||
|
||||
# Get model path
|
||||
model_path = get_model_path(model_repo)
|
||||
@@ -482,17 +492,18 @@ def generate_video(
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Load text encoder
|
||||
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
text_encoder = LTX2TextEncoder()
|
||||
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||
mx.eval(text_encoder.parameters())
|
||||
with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"):
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
text_encoder = LTX2TextEncoder()
|
||||
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||
mx.eval(text_encoder.parameters())
|
||||
console.print("[green]✓[/] Text encoder loaded")
|
||||
|
||||
# Optionally enhance the prompt
|
||||
if enhance_prompt:
|
||||
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
|
||||
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||
with console.status("[magenta]✨ Enhancing prompt...[/]", spinner="dots"):
|
||||
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
|
||||
|
||||
# Get embeddings - with audio if enabled
|
||||
if audio:
|
||||
@@ -509,75 +520,76 @@ def generate_video(
|
||||
mx.clear_cache()
|
||||
|
||||
# Load transformer
|
||||
print(f"{Colors.BLUE}🤖 Loading transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
# Convert transformer weights to bfloat16 for memory efficiency
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
transformer_desc = "🤖 Loading transformer (A/V mode)..." if audio else "🤖 Loading transformer..."
|
||||
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
# Convert transformer weights to bfloat16 for memory efficiency
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
# Configure model type based on audio flag
|
||||
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
|
||||
# Configure model type based on audio flag
|
||||
model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
|
||||
|
||||
config_kwargs = dict(
|
||||
model_type=model_type,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
|
||||
if audio:
|
||||
config_kwargs.update(
|
||||
audio_num_attention_heads=32,
|
||||
audio_attention_head_dim=64,
|
||||
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
|
||||
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
|
||||
audio_cross_attention_dim=2048,
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
config_kwargs = dict(
|
||||
model_type=model_type,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
|
||||
config = LTXModelConfig(**config_kwargs)
|
||||
if audio:
|
||||
config_kwargs.update(
|
||||
audio_num_attention_heads=32,
|
||||
audio_attention_head_dim=64,
|
||||
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
|
||||
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
|
||||
audio_cross_attention_dim=2048,
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
)
|
||||
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
config = LTXModelConfig(**config_kwargs)
|
||||
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
console.print("[green]✓[/] Transformer loaded")
|
||||
|
||||
# Load VAE encoder and encode image for I2V conditioning
|
||||
stage1_image_latent = None
|
||||
stage2_image_latent = None
|
||||
if is_i2v:
|
||||
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
|
||||
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
mx.eval(vae_encoder.parameters())
|
||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
mx.eval(vae_encoder.parameters())
|
||||
|
||||
# Load and prepare image for stage 1 (half resolution)
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
||||
# Load and prepare image for stage 1 (half resolution)
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
|
||||
# Load and prepare image for stage 2 (full resolution)
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
||||
# Load and prepare image for stage 2 (full resolution)
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
|
||||
del vae_encoder
|
||||
mx.clear_cache()
|
||||
del vae_encoder
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||
|
||||
# Stage 1: Generate at half resolution
|
||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)")
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Position grids stay float32 for RoPE precision
|
||||
@@ -636,23 +648,24 @@ def generate_video(
|
||||
)
|
||||
|
||||
# Upsample latents
|
||||
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = load_vae_decoder(
|
||||
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||
timestep_conditioning=None # Auto-detect from model metadata
|
||||
)
|
||||
vae_decoder = load_vae_decoder(
|
||||
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||
timestep_conditioning=None # Auto-detect from model metadata
|
||||
)
|
||||
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||
mx.eval(latents)
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||
mx.eval(latents)
|
||||
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] Latents upsampled")
|
||||
|
||||
# Stage 2: Refine at full resolution
|
||||
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
||||
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)")
|
||||
# Position grids stay float32 for RoPE precision
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
@@ -717,7 +730,7 @@ def generate_video(
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode to video with tiling
|
||||
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
|
||||
console.print("\n[blue]🎞️ Decoding video...[/]")
|
||||
|
||||
# Select tiling configuration
|
||||
if tiling == "none":
|
||||
@@ -735,7 +748,7 @@ def generate_video(
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
console.print(f"[yellow] Unknown tiling mode '{tiling}', using auto[/]")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
# Save outputs
|
||||
@@ -744,13 +757,21 @@ def generate_video(
|
||||
|
||||
# Stream mode: write frames as they're decoded
|
||||
video_writer = None
|
||||
stream_pbar = None
|
||||
stream_progress = None
|
||||
|
||||
if stream and tiling_config is not None:
|
||||
import cv2
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
||||
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
|
||||
stream_progress = Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
console=console,
|
||||
)
|
||||
stream_progress.start()
|
||||
stream_task = stream_progress.add_task("[cyan]Streaming frames[/]", total=num_frames)
|
||||
|
||||
def on_frames_ready(frames: mx.array, _start_idx: int):
|
||||
"""Callback to write frames as they're finalized."""
|
||||
@@ -763,17 +784,17 @@ def generate_video(
|
||||
|
||||
for frame in frames_np:
|
||||
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
stream_pbar.update(1)
|
||||
stream_progress.advance(stream_task)
|
||||
else:
|
||||
on_frames_ready = None
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
console.print(f"[dim] Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}[/]")
|
||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
|
||||
else:
|
||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||
console.print("[dim] Tiling: disabled[/]")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
mx.clear_cache()
|
||||
@@ -781,9 +802,9 @@ def generate_video(
|
||||
# Close progressive video writer if used
|
||||
if video_writer is not None:
|
||||
video_writer.release()
|
||||
if stream_pbar is not None:
|
||||
stream_pbar.close()
|
||||
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
|
||||
if stream_progress is not None:
|
||||
stream_progress.stop()
|
||||
console.print(f"[green]✅ Streamed video to[/] {output_path}")
|
||||
# Still need video_np for save_frames option
|
||||
video = mx.squeeze(video, axis=0)
|
||||
video = mx.transpose(video, (1, 2, 3, 0))
|
||||
@@ -815,45 +836,47 @@ def generate_video(
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
if not audio:
|
||||
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||
console.print(f"[green]✅ Saved video to[/] {output_path}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||
console.print(f"[red]❌ Could not save video: {e}[/]")
|
||||
|
||||
# Decode and save audio if enabled
|
||||
audio_np = None
|
||||
if audio and audio_latents is not None:
|
||||
print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}")
|
||||
audio_decoder = load_audio_decoder(model_path)
|
||||
vocoder = load_vocoder(model_path)
|
||||
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
|
||||
audio_decoder = load_audio_decoder(model_path)
|
||||
vocoder = load_vocoder(model_path)
|
||||
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||
|
||||
mel_spectrogram = audio_decoder(audio_latents)
|
||||
mx.eval(mel_spectrogram)
|
||||
mel_spectrogram = audio_decoder(audio_latents)
|
||||
mx.eval(mel_spectrogram)
|
||||
|
||||
audio_waveform = vocoder(mel_spectrogram)
|
||||
mx.eval(audio_waveform)
|
||||
audio_waveform = vocoder(mel_spectrogram)
|
||||
mx.eval(audio_waveform)
|
||||
|
||||
audio_np = np.array(audio_waveform)
|
||||
if audio_np.ndim == 3:
|
||||
audio_np = audio_np[0]
|
||||
audio_np = np.array(audio_waveform)
|
||||
if audio_np.ndim == 3:
|
||||
audio_np = audio_np[0]
|
||||
|
||||
del audio_decoder, vocoder
|
||||
mx.clear_cache()
|
||||
del audio_decoder, vocoder
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] Audio decoded")
|
||||
|
||||
# Save audio
|
||||
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
|
||||
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
|
||||
print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}")
|
||||
console.print(f"[green]✅ Saved audio to[/] {audio_path}")
|
||||
|
||||
# Mux video and audio
|
||||
print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}")
|
||||
temp_video_path = output_path.with_suffix('.temp.mp4')
|
||||
if mux_video_audio(temp_video_path, audio_path, output_path):
|
||||
print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}")
|
||||
with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"):
|
||||
temp_video_path = output_path.with_suffix('.temp.mp4')
|
||||
success = mux_video_audio(temp_video_path, audio_path, output_path)
|
||||
if success:
|
||||
console.print(f"[green]✅ Saved video with audio to[/] {output_path}")
|
||||
temp_video_path.unlink()
|
||||
else:
|
||||
temp_video_path.rename(output_path)
|
||||
print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}")
|
||||
console.print(f"[yellow]⚠️ Saved video without audio to[/] {output_path}")
|
||||
|
||||
del vae_decoder
|
||||
mx.clear_cache()
|
||||
@@ -863,11 +886,14 @@ def generate_video(
|
||||
frames_dir.mkdir(exist_ok=True)
|
||||
for i, frame in enumerate(video_np):
|
||||
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
||||
print(f"{Colors.GREEN}✅ Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}")
|
||||
console.print(f"[green]✅ Saved {len(video_np)} frames to {frames_dir}[/]")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
||||
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
||||
console.print(Panel(
|
||||
f"[bold green]🎉 Done![/] Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)\n"
|
||||
f"[bold green]✨ Peak memory:[/] {mx.get_peak_memory() / (1024 ** 3):.2f}GB",
|
||||
expand=False
|
||||
))
|
||||
|
||||
if audio:
|
||||
return video_np, audio_np
|
||||
|
||||
@@ -19,7 +19,8 @@ dependencies = [
|
||||
"tqdm",
|
||||
"opencv-python>=4.12.0.88",
|
||||
"Pillow>=10.3.0",
|
||||
"mlx-vlm"
|
||||
"mlx-vlm",
|
||||
"rich>=14.2.0",
|
||||
]
|
||||
license = {text="MIT"}
|
||||
authors = [
|
||||
@@ -52,4 +53,4 @@ version = {attr = "mlx_video.version.__version__"}
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest",
|
||||
]
|
||||
]
|
||||
|
||||
36
uv.lock
generated
36
uv.lock
generated
@@ -635,6 +635,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "4.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mdurl" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markupsafe"
|
||||
version = "3.0.3"
|
||||
@@ -709,6 +721,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.1"
|
||||
@@ -777,6 +798,7 @@ dependencies = [
|
||||
{ name = "numpy" },
|
||||
{ name = "opencv-python" },
|
||||
{ name = "pillow" },
|
||||
{ name = "rich" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "transformers", extra = ["tokenizers"] },
|
||||
@@ -796,6 +818,7 @@ requires-dist = [
|
||||
{ name = "opencv-python", specifier = ">=4.12.0.88" },
|
||||
{ name = "pillow", specifier = ">=10.3.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'" },
|
||||
{ name = "rich", specifier = ">=14.2.0" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "transformers", extras = ["tokenizers"] },
|
||||
@@ -1679,6 +1702,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "14.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "markdown-it-py" },
|
||||
{ name = "pygments" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fb/d2/8920e102050a0de7bfabeb4c4614a49248cf8d5d7a8d01885fbb24dc767a/rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4", size = 219990, upload-time = "2025-10-09T14:16:53.064Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/25/7a/b0178788f8dc6cafce37a212c99565fa1fe7872c70c6c9c1e1a372d9d88f/rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd", size = 243393, upload-time = "2025-10-09T14:16:51.245Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safetensors"
|
||||
version = "0.7.0"
|
||||
|
||||
Reference in New Issue
Block a user