From cae11291a9f86625a812e1684cd0dbbcaf3d8caf Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 19 Jan 2026 01:28:53 +0100 Subject: [PATCH] Remove the audio-video generation pipeline from generate_av.py and integrate audio capabilities into generate.py. This includes adding audio position grid creation, audio frame computation, and updating the denoising function to handle audio latents. Enhance the command-line interface to support audio generation options and update the model configuration accordingly. --- mlx_video/generate.py | 412 ++++++++++++++++++-- mlx_video/generate_av.py | 821 --------------------------------------- 2 files changed, 377 insertions(+), 856 deletions(-) delete mode 100644 mlx_video/generate_av.py diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 9a72fe9..a9dcf85 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1,14 +1,15 @@ + import argparse import time from pathlib import Path from typing import Optional import mlx.core as mx -import mlx.nn as nn import numpy as np from PIL import Image from tqdm import tqdm + # ANSI color codes class Colors: CYAN = "\033[96m" @@ -21,25 +22,33 @@ class Colors: DIM = "\033[2m" RESET = "\033[0m" + from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.transformer import Modality -from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder_weights -from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding +from mlx_video.convert import sanitize_transformer_weights +from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder from mlx_video.models.ltx.video_vae.tiling import TilingConfig from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning -from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state - -from mlx_video.utils import get_model_path +from mlx_video.conditioning.latent import LatentState, apply_denoise_mask # Distilled sigma schedules STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0] +# Audio constants +AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate +AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate +AUDIO_HOP_LENGTH = 160 +AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4 +AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying +AUDIO_MEL_BINS = 16 +AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 + def create_position_grid( batch_size: int, @@ -115,6 +124,43 @@ def create_position_grid( return mx.array(pixel_coords, dtype=mx.float32) +def create_audio_position_grid( + batch_size: int, + audio_frames: int, + sample_rate: int = AUDIO_LATENT_SAMPLE_RATE, + hop_length: int = AUDIO_HOP_LENGTH, + downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR, + is_causal: bool = True, +) -> mx.array: + """Create temporal position grid for audio RoPE. + + Audio positions are timestamps in seconds, shape (B, 1, T, 2). + Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly. + """ + def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray: + """Convert latent indices to seconds.""" + latent_frame = np.arange(start_idx, end_idx, dtype=np.float32) + mel_frame = latent_frame * downsample_factor + if is_causal: + mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None) + return mel_frame * hop_length / sample_rate + + start_times = get_audio_latent_time_in_sec(0, audio_frames) + end_times = get_audio_latent_time_in_sec(1, audio_frames + 1) + + positions = np.stack([start_times, end_times], axis=-1) + positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2) + positions = np.tile(positions, (batch_size, 1, 1, 1)) + + return mx.array(positions, dtype=mx.float32) + + +def compute_audio_frames(num_video_frames: int, fps: float) -> int: + """Compute number of audio latent frames given video duration.""" + duration = num_video_frames / fps + return round(duration * AUDIO_LATENTS_PER_SECOND) + + def denoise( latents: mx.array, positions: mx.array, @@ -123,27 +169,37 @@ def denoise( sigmas: list, verbose: bool = True, state: Optional[LatentState] = None, -) -> mx.array: - """Run denoising loop with optional conditioning. + # Audio parameters (optional) + audio_latents: Optional[mx.array] = None, + audio_positions: Optional[mx.array] = None, + audio_embeddings: Optional[mx.array] = None, +) -> tuple[mx.array, Optional[mx.array]]: + """Run denoising loop with optional conditioning and optional audio. Args: - latents: Noisy latent tensor (B, C, F, H, W) - positions: Position embeddings - text_embeddings: Text conditioning embeddings + latents: Noisy video latent tensor (B, C, F, H, W) + positions: Video position embeddings + text_embeddings: Video text conditioning embeddings transformer: LTX model sigmas: List of sigma values for denoising schedule verbose: Whether to show progress bar state: Optional LatentState for I2V conditioning + audio_latents: Optional audio latent tensor (B, C, T, F) for audio generation + audio_positions: Optional audio position embeddings + audio_embeddings: Optional audio text embeddings Returns: - Denoised latent tensor + Tuple of (video_latents, audio_latents) - audio_latents is None if audio disabled """ - # If state is provided, use its latent (which may have conditioning applied) dtype = latents.dtype + enable_audio = audio_latents is not None + + # If state is provided, use its latent (which may have conditioning applied) if state is not None: latents = state.latent - for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose): + desc = "Denoising A/V" if enable_audio else "Denoising" + for i in tqdm(range(len(sigmas) - 1), desc=desc, disable=not verbose): sigma, sigma_next = sigmas[i], sigmas[i + 1] b, c, f, h, w = latents.shape @@ -172,28 +228,163 @@ def denoise( enabled=True, ) - velocity, _ = transformer(video=video_modality, audio=None) + # Prepare audio modality if enabled + audio_modality = None + if enable_audio: + ab, ac, at, af = audio_latents.shape + audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + + audio_modality = Modality( + latent=audio_flat, + timesteps=mx.full((ab, at), sigma, dtype=dtype), + positions=audio_positions, + context=audio_embeddings, + context_mask=None, + enabled=True, + ) + + velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) mx.eval(velocity) + if audio_velocity is not None: + mx.eval(audio_velocity) velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) denoised = to_denoised(latents, velocity, sigma) + # Handle audio velocity if enabled + audio_denoised = None + if enable_audio and audio_velocity is not None: + ab, ac, at, af = audio_latents.shape + audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) + audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) + audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + # Apply conditioning mask if state is provided if state is not None: denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) mx.eval(denoised) + if audio_denoised is not None: + mx.eval(audio_denoised) # Euler step (preserve dtype by converting Python floats to arrays) if sigma_next > 0: sigma_next_arr = mx.array(sigma_next, dtype=dtype) sigma_arr = mx.array(sigma, dtype=dtype) latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr + if enable_audio and audio_denoised is not None: + audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr else: latents = denoised - mx.eval(latents) + if enable_audio and audio_denoised is not None: + audio_latents = audio_denoised - return latents + mx.eval(latents) + if enable_audio: + mx.eval(audio_latents) + + return latents, audio_latents if enable_audio else None + + +def load_audio_decoder(model_path: Path): + """Load audio VAE decoder.""" + from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType + from mlx_video.convert import sanitize_audio_vae_weights + + decoder = AudioDecoder( + ch=128, + out_ch=2, # stereo + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_resolutions=set(), + resolution=256, + z_channels=AUDIO_LATENT_CHANNELS, + norm_type=NormType.PIXEL, + causality_axis=CausalityAxis.HEIGHT, + mel_bins=64, + ) + + weight_file = model_path / "ltx-2-19b-distilled.safetensors" + if weight_file.exists(): + raw_weights = mx.load(str(weight_file)) + sanitized = sanitize_audio_vae_weights(raw_weights) + if sanitized: + decoder.load_weights(list(sanitized.items()), strict=False) + + if "per_channel_statistics._mean_of_means" in sanitized: + decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"] + if "per_channel_statistics._std_of_means" in sanitized: + decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"] + + return decoder + + +def load_vocoder(model_path: Path): + """Load vocoder for mel to waveform conversion.""" + from mlx_video.models.ltx.audio_vae import Vocoder + from mlx_video.convert import sanitize_vocoder_weights + + vocoder = Vocoder( + resblock_kernel_sizes=[3, 7, 11], + upsample_rates=[6, 5, 2, 2, 2], + upsample_kernel_sizes=[16, 15, 8, 4, 4], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel=1024, + stereo=True, + output_sample_rate=AUDIO_SAMPLE_RATE, + ) + + weight_file = model_path / "ltx-2-19b-distilled.safetensors" + if weight_file.exists(): + raw_weights = mx.load(str(weight_file)) + sanitized = sanitize_vocoder_weights(raw_weights) + if sanitized: + vocoder.load_weights(list(sanitized.items()), strict=False) + + return vocoder + + +def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE): + """Save audio to WAV file.""" + import wave + + if audio.ndim == 2: + audio = audio.T # (channels, samples) -> (samples, channels) + + audio = np.clip(audio, -1.0, 1.0) + audio_int16 = (audio * 32767).astype(np.int16) + + with wave.open(str(path), 'wb') as wf: + wf.setnchannels(2 if audio_int16.ndim == 2 else 1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(audio_int16.tobytes()) + + +def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): + """Combine video and audio into final output using ffmpeg.""" + import subprocess + + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-i", str(audio_path), + "-c:v", "copy", + "-c:a", "aac", + "-shortest", + str(output_path) + ] + + try: + subprocess.run(cmd, check=True, capture_output=True) + return True + except subprocess.CalledProcessError as e: + print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}") + return False + except FileNotFoundError: + print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}") + return False def generate_video( @@ -216,8 +407,11 @@ def generate_video( image_frame_idx: int = 0, tiling: str = "auto", stream: bool = False, + # Audio options + audio: bool = False, + output_audio_path: Optional[str] = None, ): - """Generate video from text prompt, optionally conditioned on an image. + """Generate video from text prompt, optionally conditioned on an image and with audio. Args: model_repo: Model repository ID @@ -245,26 +439,37 @@ def generate_video( - "conservative": 768px spatial, 96 frame temporal (faster) - "spatial": Spatial tiling only - "temporal": Temporal tiling only + stream: Stream frames to output as they're decoded (requires tiling) + audio: Enable synchronized audio generation + output_audio_path: Path to save audio file (default: same as video with .wav) """ start_time = time.time() # Validate dimensions assert height % 64 == 0, f"Height must be divisible by 64, got {height}" assert width % 64 == 0, f"Width must be divisible by 64, got {width}" - + if num_frames % 8 != 1: adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") num_frames = adjusted_num_frames - is_i2v = image is not None mode_str = "I2V" if is_i2v else "T2V" + if audio: + mode_str += "+Audio" + print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") if is_i2v: print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") + # Calculate audio frames if enabled + audio_frames = None + if audio: + audio_frames = compute_audio_frames(num_frames, fps) + print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") + # Get model path model_path = get_model_path(model_repo) text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) @@ -289,22 +494,32 @@ def generate_video( prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") - text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) + # Get embeddings - with audio if enabled + if audio: + text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True) + mx.eval(text_embeddings, audio_embeddings) + else: + text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) + audio_embeddings = None + mx.eval(text_embeddings) + model_dtype = text_embeddings.dtype # bfloat16 from text encoder - mx.eval(text_embeddings) del text_encoder mx.clear_cache() # Load transformer - print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}") + print(f"{Colors.BLUE}🤖 Loading transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}") raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) # Convert transformer weights to bfloat16 for memory efficiency sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} - config = LTXModelConfig( - model_type=LTXModelType.VideoOnly, + # Configure model type based on audio flag + model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly + + config_kwargs = dict( + model_type=model_type, num_attention_heads=32, attention_head_dim=128, in_channels=128, @@ -320,7 +535,19 @@ def generate_video( timestep_scale_multiplier=1000, ) - transformer = LTXModel(config) + if audio: + config_kwargs.update( + audio_num_attention_heads=32, + audio_attention_head_dim=64, + audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 + audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, + audio_cross_attention_dim=2048, + audio_positional_embedding_max_pos=[20], + ) + + config = LTXModelConfig(**config_kwargs) + + transformer = LTXModel(config) transformer.load_weights(list(sanitized.items()), strict=False) mx.eval(transformer.parameters()) @@ -357,6 +584,14 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) + # Create audio positions if enabled + audio_positions = None + audio_latents = None + if audio: + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) + mx.eval(audio_positions, audio_latents) + # Apply I2V conditioning if provided state1 = None if is_i2v and stage1_image_latent is not None: @@ -394,7 +629,11 @@ def generate_video( latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype) mx.eval(latents) - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) + latents, audio_latents = denoise( + latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, + verbose=verbose, state=state1, + audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, + ) # Upsample latents print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") @@ -447,6 +686,13 @@ def generate_video( ) latents = state2.latent mx.eval(latents) + + # Audio also gets noise for stage 2 if enabled + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) else: # T2V: add noise to all frames for refinement noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) @@ -455,7 +701,17 @@ def generate_video( latents = noise * noise_scale + latents * one_minus_scale mx.eval(latents) - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) + # Audio also gets noise for stage 2 if enabled + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) + + latents, audio_latents = denoise( + latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, + verbose=verbose, state=state2, + audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, + ) del transformer mx.clear_cache() @@ -496,7 +752,7 @@ def generate_video( video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame") - def on_frames_ready(frames: mx.array, start_idx: int): + def on_frames_ready(frames: mx.array, _start_idx: int): """Callback to write frames as they're finalized.""" # frames: (B, 3, num_frames, H, W) frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W) @@ -542,19 +798,66 @@ def generate_video( video = (video * 255).astype(mx.uint8) video_np = np.array(video) - # Save video normally + # For audio mode, save to temp file first + if audio: + temp_video_path = output_path.with_suffix('.temp.mp4') + save_path = temp_video_path + else: + save_path = output_path + + # Save video try: import cv2 h, w = video_np.shape[1], video_np.shape[2] fourcc = cv2.VideoWriter_fourcc(*'avc1') - out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h)) + out = cv2.VideoWriter(str(save_path), fourcc, fps, (w, h)) for frame in video_np: out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.release() - print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") + if not audio: + print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") except Exception as e: print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") + # Decode and save audio if enabled + audio_np = None + if audio and audio_latents is not None: + print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}") + audio_decoder = load_audio_decoder(model_path) + vocoder = load_vocoder(model_path) + mx.eval(audio_decoder.parameters(), vocoder.parameters()) + + mel_spectrogram = audio_decoder(audio_latents) + mx.eval(mel_spectrogram) + + audio_waveform = vocoder(mel_spectrogram) + mx.eval(audio_waveform) + + audio_np = np.array(audio_waveform) + if audio_np.ndim == 3: + audio_np = audio_np[0] + + del audio_decoder, vocoder + mx.clear_cache() + + # Save audio + audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') + save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) + print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}") + + # Mux video and audio + print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}") + temp_video_path = output_path.with_suffix('.temp.mp4') + if mux_video_audio(temp_video_path, audio_path, output_path): + print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}") + temp_video_path.unlink() + else: + temp_video_path.rename(output_path) + print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}") + + del vae_decoder + mx.clear_cache() + if save_frames: frames_dir = output_path.parent / f"{output_path.stem}_frames" frames_dir.mkdir(exist_ok=True) @@ -566,12 +869,14 @@ def generate_video( print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") + if audio: + return video_np, audio_np return video_np def main(): parser = argparse.ArgumentParser( - description="Generate videos with MLX LTX-2 (T2V and I2V)", + description="Generate videos with MLX LTX-2 (T2V, I2V, and Audio)", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: @@ -583,6 +888,11 @@ Examples: # Image-to-Video (I2V) python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8 + + # With Audio (T2V+Audio or I2V+Audio) + python -m mlx_video.generate --prompt "Ocean waves crashing" --audio + python -m mlx_video.generate --prompt "A jazz band playing" --audio --enhance-prompt + python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --audio """ ) @@ -623,7 +933,7 @@ Examples: help="Frames per second for output video (default: 24)" ) parser.add_argument( - "--output-path", + "--output-path", "-o", type=str, default="output.mp4", help="Output video path (default: output.mp4)" @@ -699,10 +1009,42 @@ Examples: action="store_true", help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner." ) + # Audio options + parser.add_argument( + "--audio", "-a", + action="store_true", + help="Enable synchronized audio generation" + ) + parser.add_argument( + "--output-audio", + type=str, + default=None, + help="Output audio path (default: same as video with .wav)" + ) args = parser.parse_args() generate_video( - **vars(args) + model_repo=args.model_repo, + text_encoder_repo=args.text_encoder_repo, + prompt=args.prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + seed=args.seed, + fps=args.fps, + output_path=args.output_path, + save_frames=args.save_frames, + verbose=args.verbose, + enhance_prompt=args.enhance_prompt, + max_tokens=args.max_tokens, + temperature=args.temperature, + image=args.image, + image_strength=args.image_strength, + image_frame_idx=args.image_frame_idx, + tiling=args.tiling, + stream=args.stream, + audio=args.audio, + output_audio_path=args.output_audio, ) diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py deleted file mode 100644 index 56d182a..0000000 --- a/mlx_video/generate_av.py +++ /dev/null @@ -1,821 +0,0 @@ -"""Audio-Video generation pipeline for LTX-2.""" - -import argparse -import time -from pathlib import Path -from typing import Optional - -import mlx.core as mx -import numpy as np -from tqdm import tqdm - - -# ANSI color codes -class Colors: - CYAN = "\033[96m" - BLUE = "\033[94m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - RED = "\033[91m" - MAGENTA = "\033[95m" - BOLD = "\033[1m" - DIM = "\033[2m" - RESET = "\033[0m" - - -from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType -from mlx_video.models.ltx.ltx import LTXModel -from mlx_video.models.ltx.transformer import Modality -from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights -from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding -from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder -from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder -from mlx_video.models.ltx.video_vae.tiling import TilingConfig -from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents -from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning -from mlx_video.conditioning.latent import LatentState, apply_denoise_mask - - -# Distilled sigma schedules -STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] -STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0] - -# Audio constants -AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate -AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate -AUDIO_HOP_LENGTH = 160 -AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4 -AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying -AUDIO_MEL_BINS = 16 -AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 - - -def create_video_position_grid( - batch_size: int, - num_frames: int, - height: int, - width: int, - temporal_scale: int = 8, - spatial_scale: int = 32, - fps: float = 24.0, - causal_fix: bool = True, -) -> mx.array: - """Create position grid for video RoPE in pixel space.""" - patch_size_t, patch_size_h, patch_size_w = 1, 1, 1 - - t_coords = np.arange(0, num_frames, patch_size_t) - h_coords = np.arange(0, height, patch_size_h) - w_coords = np.arange(0, width, patch_size_w) - - t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') - patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) - - patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1) - patch_ends = patch_starts + patch_size_delta - - latent_coords = np.stack([patch_starts, patch_ends], axis=-1) - num_patches = num_frames * height * width - latent_coords = latent_coords.reshape(3, num_patches, 2) - latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) - - scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1) - pixel_coords = (latent_coords * scale_factors).astype(np.float32) - - if causal_fix: - pixel_coords[:, 0, :, :] = np.clip( - pixel_coords[:, 0, :, :] + 1 - temporal_scale, - a_min=0, - a_max=None - ) - - pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps - - return mx.array(pixel_coords, dtype=mx.float32) - - -def create_audio_position_grid( - batch_size: int, - audio_frames: int, - sample_rate: int = AUDIO_LATENT_SAMPLE_RATE, - hop_length: int = AUDIO_HOP_LENGTH, - downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR, - is_causal: bool = True, -) -> mx.array: - """Create temporal position grid for audio RoPE. - - Audio positions are timestamps in seconds, shape (B, 1, T, 2). - Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly. - """ - def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray: - """Convert latent indices to seconds (matching PyTorch's _get_audio_latent_time_in_sec).""" - latent_frame = np.arange(start_idx, end_idx, dtype=np.float32) - mel_frame = latent_frame * downsample_factor - if is_causal: - # Frame offset for causal alignment (PyTorch uses +1 - downsample_factor) - mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None) - return mel_frame * hop_length / sample_rate - - # Start times: latent indices 0 to audio_frames - start_times = get_audio_latent_time_in_sec(0, audio_frames) - - # End times: latent indices 1 to audio_frames+1 (shifted by 1) - end_times = get_audio_latent_time_in_sec(1, audio_frames + 1) - - # Shape: (B, 1, T, 2) - positions = np.stack([start_times, end_times], axis=-1) - positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2) - positions = np.tile(positions, (batch_size, 1, 1, 1)) - - return mx.array(positions, dtype=mx.float32) - - -def compute_audio_frames(num_video_frames: int, fps: float) -> int: - """Compute number of audio latent frames given video duration.""" - duration = num_video_frames / fps - return round(duration * AUDIO_LATENTS_PER_SECOND) - - -def denoise_av( - video_latents: mx.array, - audio_latents: mx.array, - video_positions: mx.array, - audio_positions: mx.array, - video_embeddings: mx.array, - audio_embeddings: mx.array, - transformer: LTXModel, - sigmas: list, - verbose: bool = True, - video_state: Optional[LatentState] = None, -) -> tuple[mx.array, mx.array]: - """Run denoising loop for audio-video generation with optional I2V conditioning. - - Args: - video_latents: Video latent tensor (B, C, F, H, W) - audio_latents: Audio latent tensor (B, C, T, F) - video_positions: Video position embeddings - audio_positions: Audio position embeddings - video_embeddings: Video text embeddings - audio_embeddings: Audio text embeddings - transformer: LTX model - sigmas: List of sigma values - verbose: Whether to show progress bar - video_state: Optional LatentState for I2V conditioning - - Returns: - Tuple of (video_latents, audio_latents) - """ - dtype = video_latents.dtype - # If video state is provided, use its latent - if video_state is not None: - video_latents = video_state.latent - - for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose): - sigma, sigma_next = sigmas[i], sigmas[i + 1] - - # Flatten video latents - b, c, f, h, w = video_latents.shape - num_video_tokens = f * h * w - video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) - - # Flatten audio latents: (B, C, T, F) -> (B, T, C*F) - ab, ac, at, af = audio_latents.shape - audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) - audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) - - # Compute per-token timesteps for video - # For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1) - if video_state is not None: - # Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens) - denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1)) - denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) - denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) - # Per-token timesteps: sigma * mask - video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat - else: - # All tokens get the same timestep - video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) - - video_modality = Modality( - latent=video_flat, - timesteps=video_timesteps, - positions=video_positions, - context=video_embeddings, - context_mask=None, - enabled=True, - ) - - audio_modality = Modality( - latent=audio_flat, - timesteps=mx.full((ab, at), sigma, dtype=dtype), - positions=audio_positions, - context=audio_embeddings, - context_mask=None, - enabled=True, - ) - - video_velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) - mx.eval(video_velocity, audio_velocity) - - # Reshape velocities back - video_velocity = mx.reshape(mx.transpose(video_velocity, (0, 2, 1)), (b, c, f, h, w)) - audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) - audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) - - # Compute denoised - video_denoised = to_denoised(video_latents, video_velocity, sigma) - audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) - - # Apply conditioning mask for video if state is provided - if video_state is not None: - video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask) - - mx.eval(video_denoised, audio_denoised) - - # Euler step - use dtype-preserving arrays to avoid float32 promotion - if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr - audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr - else: - video_latents = video_denoised - audio_latents = audio_denoised - mx.eval(video_latents, audio_latents) - - return video_latents, audio_latents - - -def load_audio_decoder(model_path: Path): - """Load audio VAE decoder.""" - from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType - - decoder = AudioDecoder( - ch=128, - out_ch=2, # stereo - ch_mult=(1, 2, 4), - num_res_blocks=2, - attn_resolutions=set(), # PyTorch uses empty set (no attention in audio decoder) - resolution=256, - z_channels=AUDIO_LATENT_CHANNELS, - norm_type=NormType.PIXEL, - causality_axis=CausalityAxis.HEIGHT, - mel_bins=64, # Output mel bins - ) - - # Load weights from main model file - weight_file = model_path / "ltx-2-19b-distilled.safetensors" - if weight_file.exists(): - raw_weights = mx.load(str(weight_file)) - sanitized = sanitize_audio_vae_weights(raw_weights) - if sanitized: - decoder.load_weights(list(sanitized.items()), strict=False) - - # Manually load per-channel statistics (they're plain mx.array, not tracked by load_weights) - if "per_channel_statistics._mean_of_means" in sanitized: - decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"] - if "per_channel_statistics._std_of_means" in sanitized: - decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"] - - return decoder - - -def load_vocoder(model_path: Path): - """Load vocoder for mel to waveform conversion.""" - from mlx_video.models.ltx.audio_vae import Vocoder - - vocoder = Vocoder( - resblock_kernel_sizes=[3, 7, 11], - upsample_rates=[6, 5, 2, 2, 2], - upsample_kernel_sizes=[16, 15, 8, 4, 4], - resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], - upsample_initial_channel=1024, - stereo=True, - output_sample_rate=AUDIO_SAMPLE_RATE, - ) - - # Load weights - weight_file = model_path / "ltx-2-19b-distilled.safetensors" - if weight_file.exists(): - raw_weights = mx.load(str(weight_file)) - sanitized = sanitize_vocoder_weights(raw_weights) - if sanitized: - vocoder.load_weights(list(sanitized.items()), strict=False) - - return vocoder - - -def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE): - """Save audio to WAV file.""" - import wave - - # Ensure audio is in correct format (channels, samples) or (samples,) - if audio.ndim == 2: - # (channels, samples) -> (samples, channels) - audio = audio.T - - # Normalize and convert to int16 - audio = np.clip(audio, -1.0, 1.0) - audio_int16 = (audio * 32767).astype(np.int16) - - with wave.open(str(path), 'wb') as wf: - wf.setnchannels(2 if audio_int16.ndim == 2 else 1) - wf.setsampwidth(2) # 16-bit - wf.setframerate(sample_rate) - wf.writeframes(audio_int16.tobytes()) - - -def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): - """Combine video and audio into final output using ffmpeg.""" - import subprocess - - cmd = [ - "ffmpeg", "-y", - "-i", str(video_path), - "-i", str(audio_path), - "-c:v", "copy", - "-c:a", "aac", - "-shortest", - str(output_path) - ] - - try: - subprocess.run(cmd, check=True, capture_output=True) - return True - except subprocess.CalledProcessError as e: - print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}") - return False - except FileNotFoundError: - print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}") - return False - - -def generate_video_with_audio( - model_repo: str, - text_encoder_repo: Optional[str], - prompt: str, - height: int = 512, - width: int = 512, - num_frames: int = 33, - seed: int = 42, - fps: int = 24, - output_path: str = "output_av.mp4", - output_audio_path: Optional[str] = None, - verbose: bool = True, - enhance_prompt: bool = False, - max_tokens: int = 512, - temperature: float = 0.7, - image: Optional[str] = None, - image_strength: float = 1.0, - image_frame_idx: int = 0, - tiling: str = "auto", -): - """Generate video with synchronized audio from text prompt, optionally conditioned on an image. - - Args: - model_repo: Model repository ID - text_encoder_repo: Text encoder repository ID - prompt: Text description of the video to generate - height: Output video height (must be divisible by 64) - width: Output video width (must be divisible by 64) - num_frames: Number of frames - seed: Random seed - fps: Frames per second - output_path: Output video path - output_audio_path: Output audio path - verbose: Whether to print progress - enhance_prompt: Whether to enhance prompt using Gemma - max_tokens: Max tokens for prompt enhancement - temperature: Temperature for prompt enhancement - image: Path to conditioning image for I2V - image_strength: Conditioning strength (1.0 = full denoise) - image_frame_idx: Frame index to condition (0 = first frame) - tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal) - """ - start_time = time.time() - - # Validate dimensions - assert height % 64 == 0, f"Height must be divisible by 64, got {height}" - assert width % 64 == 0, f"Width must be divisible by 64, got {width}" - - if num_frames % 8 != 1: - adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 - print(f"{Colors.YELLOW}⚠️ Adjusted frames to {adjusted_num_frames}{Colors.RESET}") - num_frames = adjusted_num_frames - - # Calculate audio frames - audio_frames = compute_audio_frames(num_frames, fps) - - is_i2v = image is not None - mode_str = "I2V+Audio" if is_i2v else "T2V+Audio" - print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}") - print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") - print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") - if is_i2v: - print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") - - model_path = get_model_path(model_repo) - text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) - - # Calculate latent dimensions - stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 - stage2_h, stage2_w = height // 32, width // 32 - latent_frames = 1 + (num_frames - 1) // 8 - - mx.random.seed(seed) - - # Load text encoder with audio embeddings - print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") - from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder() - text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) - mx.eval(text_encoder.parameters()) - - # Optionally enhance prompt - if enhance_prompt: - print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}") - prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) - print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") - - # Get both video and audio embeddings - video_embeddings, audio_embeddings = text_encoder(prompt) - model_dtype = video_embeddings.dtype # bfloat16 from text encoder - mx.eval(video_embeddings, audio_embeddings) - - del text_encoder - mx.clear_cache() - - # Load transformer with AudioVideo config - print(f"{Colors.BLUE}🤖 Loading transformer (A/V mode)...{Colors.RESET}") - raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) - sanitized = sanitize_transformer_weights(raw_weights) - - # Convert transformer weights to bfloat16 for memory efficiency - sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} - - config = LTXModelConfig( - model_type=LTXModelType.AudioVideo, - num_attention_heads=32, - attention_head_dim=128, - in_channels=128, - out_channels=128, - num_layers=48, - cross_attention_dim=4096, - caption_channels=3840, - # Audio config - audio_num_attention_heads=32, - audio_attention_head_dim=64, - audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 - audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, - audio_cross_attention_dim=2048, - rope_type=LTXRopeType.SPLIT, - double_precision_rope=True, - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - audio_positional_embedding_max_pos=[20], - use_middle_indices_grid=True, - timestep_scale_multiplier=1000, - ) - - transformer = LTXModel(config) - transformer.load_weights(list(sanitized.items()), strict=False) - mx.eval(transformer.parameters()) - - # Load VAE encoder and encode image for I2V conditioning - stage1_image_latent = None - stage2_image_latent = None - if is_i2v: - print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}") - vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors')) - mx.eval(vae_encoder.parameters()) - - # Load and prepare image for stage 1 (half resolution) - input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) - stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) - stage1_image_latent = vae_encoder(stage1_image_tensor) - mx.eval(stage1_image_latent) - - # Load and prepare image for stage 2 (full resolution) - input_image = load_image(image, height=height, width=width, dtype=model_dtype) - stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) - stage2_image_latent = vae_encoder(stage2_image_tensor) - mx.eval(stage2_image_latent) - - del vae_encoder - mx.clear_cache() - - # Initialize latents - print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") - mx.random.seed(seed) - - # Create position grids - MUST stay float32 for RoPE precision - # bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations - video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32 - audio_positions = create_audio_position_grid(1, audio_frames) # float32 - mx.eval(video_positions, audio_positions) - - # Apply I2V conditioning for stage 1 if provided - video_state1 = None - video_latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) - if is_i2v and stage1_image_latent is not None: - # PyTorch flow: create zeros -> apply conditioning -> apply noiser - video_state1 = LatentState( - latent=mx.zeros(video_latent_shape, dtype=model_dtype), - clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), - ) - conditioning = VideoConditionByLatentIndex( - latent=stage1_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, - ) - video_state1 = apply_conditioning(video_state1, [conditioning]) - - # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) - noise = mx.random.normal(video_latent_shape).astype(model_dtype) - noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0 - scaled_mask = video_state1.denoise_mask * noise_scale - video_state1 = LatentState( - latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), - clean_latent=video_state1.clean_latent, - denoise_mask=video_state1.denoise_mask, - ) - video_latents = video_state1.latent - mx.eval(video_latents) - else: - # T2V: just use random noise - video_latents = mx.random.normal(video_latent_shape).astype(model_dtype) - mx.eval(video_latents) - - # Audio always uses pure noise (no I2V for audio) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) - mx.eval(audio_latents) - - # Stage 1 denoising - video_latents, audio_latents = denoise_av( - video_latents, audio_latents, - video_positions, audio_positions, - video_embeddings, audio_embeddings, - transformer, STAGE_1_SIGMAS, verbose=verbose, - video_state=video_state1 - ) - - # Upsample video latents - print(f"{Colors.MAGENTA}🔍 Upsampling video latents 2x...{Colors.RESET}") - upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) - mx.eval(upsampler.parameters()) - - vae_decoder = load_vae_decoder( - str(model_path / 'ltx-2-19b-distilled.safetensors'), - timestep_conditioning=None # Auto-detect from model metadata - ) - - video_latents = upsample_latents(video_latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) - mx.eval(video_latents) - - del upsampler - mx.clear_cache() - - # Stage 2: Refine at full resolution - print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") - # Position grids stay float32 for RoPE precision - video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32 - mx.eval(video_positions) - - # Apply I2V conditioning for stage 2 if provided - video_state2 = None - if is_i2v and stage2_image_latent is not None: - # PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser - video_state2 = LatentState( - latent=video_latents, # Start with upscaled latent - clean_latent=mx.zeros_like(video_latents), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), - ) - conditioning = VideoConditionByLatentIndex( - latent=stage2_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, - ) - video_state2 = apply_conditioning(video_state2, [conditioning]) - - # Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise - video_noise = mx.random.normal(video_latents.shape).astype(model_dtype) - noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) - scaled_mask = video_state2.denoise_mask * noise_scale - video_state2 = LatentState( - latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), - clean_latent=video_state2.clean_latent, - denoise_mask=video_state2.denoise_mask, - ) - video_latents = video_state2.latent - mx.eval(video_latents) - - # Audio still gets noise (no I2V for audio) - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(audio_latents) - else: - # T2V: add noise to all frames for refinement - noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) - one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale - video_noise = mx.random.normal(video_latents.shape).astype(model_dtype) - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - video_latents = video_noise * noise_scale + video_latents * one_minus_scale - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(video_latents, audio_latents) - - video_latents, audio_latents = denoise_av( - video_latents, audio_latents, - video_positions, audio_positions, - video_embeddings, audio_embeddings, - transformer, STAGE_2_SIGMAS, verbose=verbose, - video_state=video_state2 - ) - - del transformer - mx.clear_cache() - - # Decode video with tiling - print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") - - # Select tiling configuration - if tiling == "none": - tiling_config = None - elif tiling == "auto": - tiling_config = TilingConfig.auto(height, width, num_frames) - elif tiling == "default": - tiling_config = TilingConfig.default() - elif tiling == "aggressive": - tiling_config = TilingConfig.aggressive() - elif tiling == "conservative": - tiling_config = TilingConfig.conservative() - elif tiling == "spatial": - tiling_config = TilingConfig.spatial_only() - elif tiling == "temporal": - tiling_config = TilingConfig.temporal_only() - else: - print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") - tiling_config = TilingConfig.auto(height, width, num_frames) - - if tiling_config is not None: - spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" - temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" - print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") - video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose) - else: - print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") - video = vae_decoder(video_latents) - mx.eval(video) - - # Convert video to uint8 frames - video = mx.squeeze(video, axis=0) - video = mx.transpose(video, (1, 2, 3, 0)) - video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) - video = (video * 255).astype(mx.uint8) - video_np = np.array(video) - - # Decode audio - print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}") - audio_decoder = load_audio_decoder(model_path) - vocoder = load_vocoder(model_path) - mx.eval(audio_decoder.parameters(), vocoder.parameters()) - - mel_spectrogram = audio_decoder(audio_latents) - mx.eval(mel_spectrogram) - - # Audio decoder output is already in vocoder format (B, C, T, F) - audio_waveform = vocoder(mel_spectrogram) - mx.eval(audio_waveform) - - audio_np = np.array(audio_waveform) - if audio_np.ndim == 3: - audio_np = audio_np[0] # Remove batch dim - - del audio_decoder, vocoder, vae_decoder - mx.clear_cache() - - # Save outputs - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # Save video (temporary without audio) - temp_video_path = output_path.with_suffix('.temp.mp4') - - try: - import cv2 - h, w = video_np.shape[1], video_np.shape[2] - fourcc = cv2.VideoWriter_fourcc(*'avc1') - out = cv2.VideoWriter(str(temp_video_path), fourcc, fps, (w, h)) - for frame in video_np: - out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - out.release() - print(f"{Colors.GREEN}✅ Video encoded{Colors.RESET}") - except Exception as e: - print(f"{Colors.RED}❌ Video encoding failed: {e}{Colors.RESET}") - return None, None - - # Save audio - audio_path = output_path.with_suffix('.wav') if output_audio_path is None else Path(output_audio_path) - save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) - print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}") - - # Mux video and audio - print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}") - if mux_video_audio(temp_video_path, audio_path, output_path): - print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}") - temp_video_path.unlink() # Remove temp file - else: - # Fallback: keep video without audio - temp_video_path.rename(output_path) - print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}") - - elapsed = time.time() - start_time - print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s{Colors.RESET}") - print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") - - return video_np, audio_np - - -def main(): - parser = argparse.ArgumentParser( - description="Generate videos with synchronized audio using MLX LTX-2 (T2V and I2V)", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Text-to-Video with Audio (T2V+Audio) - python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach" - python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt - python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav - - # Image-to-Video with Audio (I2V+Audio) - python -m mlx_video.generate_av --prompt "A person dancing" --image photo.jpg - python -m mlx_video.generate_av --prompt "Waves crashing" --image beach.png --image-strength 0.8 - """ - ) - - parser.add_argument("--prompt", "-p", type=str, required=True, - help="Text description of the video/audio to generate") - parser.add_argument("--height", "-H", type=int, default=512, - help="Output video height (default: 512)") - parser.add_argument("--width", "-W", type=int, default=512, - help="Output video width (default: 512)") - parser.add_argument("--num-frames", "-n", type=int, default=65, - help="Number of frames (default: 65)") - parser.add_argument("--seed", "-s", type=int, default=42, - help="Random seed (default: 42)") - parser.add_argument("--fps", type=int, default=24, - help="Frames per second (default: 24)") - parser.add_argument("--output-path", type=str, default="output_av.mp4", - help="Output video path (default: output_av.mp4)") - parser.add_argument("--output-audio", type=str, default=None, - help="Output audio path (default: same as video with .wav)") - parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2", - help="Model repository (default: Lightricks/LTX-2)") - parser.add_argument("--text-encoder-repo", type=str, default=None, - help="Text encoder repository") - parser.add_argument("--verbose", action="store_true", - help="Verbose output") - parser.add_argument("--enhance-prompt", action="store_true", - help="Enhance prompt using Gemma") - parser.add_argument("--max-tokens", type=int, default=512, - help="Max tokens for prompt enhancement") - parser.add_argument("--temperature", type=float, default=0.7, - help="Temperature for prompt enhancement") - parser.add_argument("--image", "-i", type=str, default=None, - help="Path to conditioning image for I2V (Image-to-Video) generation") - parser.add_argument("--image-strength", type=float, default=1.0, - help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)") - parser.add_argument("--image-frame-idx", type=int, default=0, - help="Frame index to condition for I2V (0 = first frame, default: 0)") - parser.add_argument("--tiling", type=str, default="auto", - choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], - help="Tiling mode for VAE decoding (default: auto). " - "auto=based on size, none=disabled, default=512px/64f, " - "aggressive=256px/32f (lowest memory), conservative=768px/96f") - - args = parser.parse_args() - - generate_video_with_audio( - model_repo=args.model_repo, - text_encoder_repo=args.text_encoder_repo, - prompt=args.prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - seed=args.seed, - fps=args.fps, - output_path=args.output_path, - output_audio_path=args.output_audio, - verbose=args.verbose, - enhance_prompt=args.enhance_prompt, - max_tokens=args.max_tokens, - temperature=args.temperature, - image=args.image, - image_strength=args.image_strength, - image_frame_idx=args.image_frame_idx, - tiling=args.tiling, - ) - - -if __name__ == "__main__": - main()