diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 5b618b7..7ba7cc9 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1,6 +1,12 @@ +"""Unified video and audio-video generation pipeline for LTX-2. + +Supports both distilled (two-stage with upsampling) and dev (single-stage with CFG) pipelines. +""" import argparse +import math import time +from enum import Enum from pathlib import Path from typing import Optional @@ -10,7 +16,6 @@ from PIL import Image from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn from rich.panel import Panel -from rich.status import Status # Rich console for styled output console = Console() @@ -29,10 +34,20 @@ from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioni from mlx_video.conditioning.latent import LatentState, apply_denoise_mask -# Distilled sigma schedules +class PipelineType(Enum): + """Pipeline type selector.""" + DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG + DEV = "dev" # Single-stage, dynamic sigmas, CFG + + +# Distilled model sigma schedules STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0] +# Dev model scheduling constants +BASE_SHIFT_ANCHOR = 1024 +MAX_SHIFT_ANCHOR = 4096 + # Audio constants AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate @@ -42,6 +57,89 @@ AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying AUDIO_MEL_BINS = 16 AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 +# Default negative prompt for CFG (dev pipeline) +DEFAULT_NEGATIVE_PROMPT = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) + + +def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: + """Compute CFG delta for classifier-free guidance. + + Args: + cond: Conditional prediction + uncond: Unconditional prediction + scale: CFG guidance scale + + Returns: + Delta to add to unconditional for CFG: (scale - 1) * (cond - uncond) + """ + return (scale - 1.0) * (cond - uncond) + + +def ltx2_scheduler( + steps: int, + num_tokens: Optional[int] = None, + max_shift: float = 2.05, + base_shift: float = 0.95, + stretch: bool = True, + terminal: float = 0.1, +) -> mx.array: + """LTX-2 scheduler for sigma generation (dev model). + + Generates a sigma schedule with token-count-dependent shifting and optional + stretching to a terminal value. + + Args: + steps: Number of inference steps + num_tokens: Number of latent tokens (F*H*W). If None, uses MAX_SHIFT_ANCHOR + max_shift: Maximum shift factor + base_shift: Base shift factor + stretch: Whether to stretch sigmas to terminal value + terminal: Terminal sigma value for stretching + + Returns: + Array of sigma values of shape (steps + 1,) + """ + tokens = num_tokens if num_tokens is not None else MAX_SHIFT_ANCHOR + sigmas = np.linspace(1.0, 0.0, steps + 1) + + # Compute shift based on token count + x1 = BASE_SHIFT_ANCHOR + x2 = MAX_SHIFT_ANCHOR + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + sigma_shift = tokens * mm + b + + # Apply shift transformation + power = 1 + sigmas = np.where( + sigmas != 0, + math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), + 0, + ) + + # Stretch sigmas to terminal value + if stretch: + non_zero_mask = sigmas != 0 + non_zero_sigmas = sigmas[non_zero_mask] + one_minus_z = 1.0 - non_zero_sigmas + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched = 1.0 - (one_minus_z / scale_factor) + sigmas[non_zero_mask] = stretched + + return mx.array(sigmas, dtype=mx.float32) + def create_position_grid( batch_size: int, @@ -69,51 +167,35 @@ def create_position_grid( Position grid of shape (B, 3, num_patches, 2) in pixel space where dim 2 is [start, end) bounds for each patch """ - # Patch size is (1, 1, 1) for LTX-2 - no spatial patching patch_size_t, patch_size_h, patch_size_w = 1, 1, 1 - # Generate grid coordinates for each dimension (frame, height, width) t_coords = np.arange(0, num_frames, patch_size_t) h_coords = np.arange(0, height, patch_size_h) w_coords = np.arange(0, width, patch_size_w) - # Create meshgrid with indexing='ij' for (frame, height, width) order t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') - - # Stack to get shape (3, grid_t, grid_h, grid_w) patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) - # Calculate end coordinates (start + patch_size) patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1) patch_ends = patch_starts + patch_size_delta - # Stack start and end: shape (3, grid_t, grid_h, grid_w, 2) latent_coords = np.stack([patch_starts, patch_ends], axis=-1) - - # Flatten spatial/temporal dims: (3, num_patches, 2) num_patches = num_frames * height * width latent_coords = latent_coords.reshape(3, num_patches, 2) - - # Broadcast to batch: (batch, 3, num_patches, 2) latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) - # Convert latent coords to pixel coords by scaling with VAE factors scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1) pixel_coords = (latent_coords * scale_factors).astype(np.float32) - # Apply causal fix for first frame temporal axis if causal_fix: - # VAE temporal stride for first frame is 1 instead of temporal_scale pixel_coords[:, 0, :, :] = np.clip( pixel_coords[:, 0, :, :] + 1 - temporal_scale, a_min=0, a_max=None ) - # Convert temporal to time in seconds by dividing by fps pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps - # Always return float32 for RoPE precision - bfloat16 causes quality degradation return mx.array(pixel_coords, dtype=mx.float32) @@ -125,13 +207,8 @@ def create_audio_position_grid( downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR, is_causal: bool = True, ) -> mx.array: - """Create temporal position grid for audio RoPE. - - Audio positions are timestamps in seconds, shape (B, 1, T, 2). - Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly. - """ + """Create temporal position grid for audio RoPE.""" def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray: - """Convert latent indices to seconds.""" latent_frame = np.arange(start_idx, end_idx, dtype=np.float32) mel_frame = latent_frame * downsample_factor if is_causal: @@ -142,7 +219,7 @@ def create_audio_position_grid( end_times = get_audio_latent_time_in_sec(1, audio_frames + 1) positions = np.stack([start_times, end_times], axis=-1) - positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2) + positions = positions[np.newaxis, np.newaxis, :, :] positions = np.tile(positions, (batch_size, 1, 1, 1)) return mx.array(positions, dtype=mx.float32) @@ -154,7 +231,11 @@ def compute_audio_frames(num_video_frames: int, fps: float) -> int: return round(duration * AUDIO_LATENTS_PER_SECOND) -def denoise( +# ============================================================================= +# Distilled Pipeline Denoising (no CFG, fixed sigmas) +# ============================================================================= + +def denoise_distilled( latents: mx.array, positions: mx.array, text_embeddings: mx.array, @@ -162,32 +243,14 @@ def denoise( sigmas: list, verbose: bool = True, state: Optional[LatentState] = None, - # Audio parameters (optional) audio_latents: Optional[mx.array] = None, audio_positions: Optional[mx.array] = None, audio_embeddings: Optional[mx.array] = None, ) -> tuple[mx.array, Optional[mx.array]]: - """Run denoising loop with optional conditioning and optional audio. - - Args: - latents: Noisy video latent tensor (B, C, F, H, W) - positions: Video position embeddings - text_embeddings: Video text conditioning embeddings - transformer: LTX model - sigmas: List of sigma values for denoising schedule - verbose: Whether to show progress bar - state: Optional LatentState for I2V conditioning - audio_latents: Optional audio latent tensor (B, C, T, F) for audio generation - audio_positions: Optional audio position embeddings - audio_embeddings: Optional audio text embeddings - - Returns: - Tuple of (video_latents, audio_latents) - audio_latents is None if audio disabled - """ + """Run denoising loop for distilled pipeline (no CFG).""" dtype = latents.dtype enable_audio = audio_latents is not None - # If state is provided, use its latent (which may have conditioning applied) if state is not None: latents = state.latent @@ -212,17 +275,12 @@ def denoise( num_tokens = f * h * w latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) - # Compute per-token timesteps - # For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1) if state is not None: - # Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens) denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) - # Per-token timesteps: sigma * mask (preserve dtype) timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat else: - # All tokens get the same timestep (use latent dtype) timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) video_modality = Modality( @@ -234,11 +292,10 @@ def denoise( enabled=True, ) - # Prepare audio modality if enabled audio_modality = None if enable_audio: ab, ac, at, af = audio_latents.shape - audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) + audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) audio_modality = Modality( @@ -258,15 +315,13 @@ def denoise( velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) denoised = to_denoised(latents, velocity, sigma) - # Handle audio velocity if enabled audio_denoised = None if enable_audio and audio_velocity is not None: ab, ac, at, af = audio_latents.shape audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) - audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) + audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) - # Apply conditioning mask if state is provided if state is not None: denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) @@ -274,7 +329,6 @@ def denoise( if audio_denoised is not None: mx.eval(audio_denoised) - # Euler step (preserve dtype by converting Python floats to arrays) if sigma_next > 0: sigma_next_arr = mx.array(sigma_next, dtype=dtype) sigma_arr = mx.array(sigma, dtype=dtype) @@ -295,14 +349,282 @@ def denoise( return latents, audio_latents if enable_audio else None -def load_audio_decoder(model_path: Path): +# ============================================================================= +# Dev Pipeline Denoising (with CFG, dynamic sigmas) +# ============================================================================= + +def denoise_dev( + latents: mx.array, + positions: mx.array, + text_embeddings_pos: mx.array, + text_embeddings_neg: mx.array, + transformer: LTXModel, + sigmas: mx.array, + cfg_scale: float = 4.0, + verbose: bool = True, + state: Optional[LatentState] = None, +) -> mx.array: + """Run denoising loop for dev pipeline with CFG.""" + from mlx_video.models.ltx.rope import precompute_freqs_cis + + dtype = latents.dtype + if state is not None: + latents = state.latent + + sigmas_list = sigmas.tolist() + use_cfg = cfg_scale != 1.0 + num_steps = len(sigmas_list) - 1 + + # Precompute RoPE once + precomputed_rope = precompute_freqs_cis( + positions, + dim=transformer.inner_dim, + theta=transformer.positional_embedding_theta, + max_pos=transformer.positional_embedding_max_pos, + use_middle_indices_grid=transformer.use_middle_indices_grid, + num_attention_heads=transformer.num_attention_heads, + rope_type=transformer.rope_type, + double_precision=transformer.config.double_precision_rope, + ) + mx.eval(precomputed_rope) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + console=console, + disable=not verbose, + ) as progress: + task = progress.add_task("[cyan]Denoising (CFG)[/]", total=num_steps) + + for i in range(num_steps): + sigma = sigmas_list[i] + sigma_next = sigmas_list[i + 1] + + b, c, f, h, w = latents.shape + num_tokens = f * h * w + latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + + if state is not None: + denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) + denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) + timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat + else: + timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) + + # Positive conditioning pass + video_modality_pos = Modality( + latent=latents_flat, + timesteps=timesteps, + positions=positions, + context=text_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_rope, + ) + velocity_pos, _ = transformer(video=video_modality_pos, audio=None) + + if use_cfg: + # Negative conditioning pass + video_modality_neg = Modality( + latent=latents_flat, + timesteps=timesteps, + positions=positions, + context=text_embeddings_neg, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_rope, + ) + velocity_neg, _ = transformer(video=video_modality_neg, audio=None) + + # Apply CFG + velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg) + else: + velocity_flat = velocity_pos + + velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w)) + denoised = to_denoised(latents, velocity, sigma) + + if state is not None: + denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) + + if sigma_next > 0: + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr + else: + latents = denoised + + mx.eval(latents) + progress.advance(task) + + return latents + + +def denoise_dev_av( + video_latents: mx.array, + audio_latents: mx.array, + video_positions: mx.array, + audio_positions: mx.array, + video_embeddings_pos: mx.array, + video_embeddings_neg: mx.array, + audio_embeddings_pos: mx.array, + audio_embeddings_neg: mx.array, + transformer: LTXModel, + sigmas: mx.array, + cfg_scale: float = 4.0, + verbose: bool = True, + video_state: Optional[LatentState] = None, +) -> tuple[mx.array, mx.array]: + """Run denoising loop for dev pipeline with CFG and audio.""" + from mlx_video.models.ltx.rope import precompute_freqs_cis + + dtype = video_latents.dtype + if video_state is not None: + video_latents = video_state.latent + + sigmas_list = sigmas.tolist() + use_cfg = cfg_scale != 1.0 + num_steps = len(sigmas_list) - 1 + + # Precompute video RoPE + precomputed_video_rope = precompute_freqs_cis( + video_positions, + dim=transformer.inner_dim, + theta=transformer.positional_embedding_theta, + max_pos=transformer.positional_embedding_max_pos, + use_middle_indices_grid=transformer.use_middle_indices_grid, + num_attention_heads=transformer.num_attention_heads, + rope_type=transformer.rope_type, + double_precision=transformer.config.double_precision_rope, + ) + + # Precompute audio RoPE + precomputed_audio_rope = precompute_freqs_cis( + audio_positions, + dim=transformer.audio_inner_dim, + theta=transformer.positional_embedding_theta, + max_pos=transformer.audio_positional_embedding_max_pos, + use_middle_indices_grid=transformer.use_middle_indices_grid, + num_attention_heads=transformer.audio_num_attention_heads, + rope_type=transformer.rope_type, + double_precision=transformer.config.double_precision_rope, + ) + mx.eval(precomputed_video_rope, precomputed_audio_rope) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + console=console, + disable=not verbose, + ) as progress: + task = progress.add_task("[cyan]Denoising A/V (CFG)[/]", total=num_steps) + + for i in range(num_steps): + sigma = sigmas_list[i] + sigma_next = sigmas_list[i + 1] + + # Flatten video latents + b, c, f, h, w = video_latents.shape + num_video_tokens = f * h * w + video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) + + # Flatten audio latents + ab, ac, at, af = audio_latents.shape + audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + + # Compute timesteps + if video_state is not None: + denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) + denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) + video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat + else: + video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) + + audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) + + # Positive conditioning pass + video_modality_pos = Modality( + latent=video_flat, timesteps=video_timesteps, positions=video_positions, + context=video_embeddings_pos, context_mask=None, enabled=True, + positional_embeddings=precomputed_video_rope, + ) + audio_modality_pos = Modality( + latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, + context=audio_embeddings_pos, context_mask=None, enabled=True, + positional_embeddings=precomputed_audio_rope, + ) + video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) + + if use_cfg: + # Negative conditioning pass + video_modality_neg = Modality( + latent=video_flat, timesteps=video_timesteps, positions=video_positions, + context=video_embeddings_neg, context_mask=None, enabled=True, + positional_embeddings=precomputed_video_rope, + ) + audio_modality_neg = Modality( + latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, + context=audio_embeddings_neg, context_mask=None, enabled=True, + positional_embeddings=precomputed_audio_rope, + ) + video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) + + # Apply CFG + video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg) + audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg) + else: + video_velocity_flat = video_vel_pos + audio_velocity_flat = audio_vel_pos + + # Reshape velocities + video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w)) + audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af)) + audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) + + # Compute denoised + video_denoised = to_denoised(video_latents, video_velocity, sigma) + audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + + if video_state is not None: + video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask) + + # Euler step + if sigma_next > 0: + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr + audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr + else: + video_latents = video_denoised + audio_latents = audio_denoised + + mx.eval(video_latents, audio_latents) + progress.advance(task) + + return video_latents, audio_latents + + +# ============================================================================= +# Audio Loading and Processing +# ============================================================================= + +def load_audio_decoder(model_path: Path, pipeline: PipelineType): """Load audio VAE decoder.""" from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType from mlx_video.convert import sanitize_audio_vae_weights decoder = AudioDecoder( ch=128, - out_ch=2, # stereo + out_ch=2, ch_mult=(1, 2, 4), num_res_blocks=2, attn_resolutions=set(), @@ -313,13 +635,12 @@ def load_audio_decoder(model_path: Path): mel_bins=64, ) - weight_file = model_path / "ltx-2-19b-distilled.safetensors" + weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors") if weight_file.exists(): raw_weights = mx.load(str(weight_file)) sanitized = sanitize_audio_vae_weights(raw_weights) if sanitized: decoder.load_weights(list(sanitized.items()), strict=False) - if "per_channel_statistics._mean_of_means" in sanitized: decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"] if "per_channel_statistics._std_of_means" in sanitized: @@ -328,7 +649,7 @@ def load_audio_decoder(model_path: Path): return decoder -def load_vocoder(model_path: Path): +def load_vocoder(model_path: Path, pipeline: PipelineType): """Load vocoder for mel to waveform conversion.""" from mlx_video.models.ltx.audio_vae import Vocoder from mlx_video.convert import sanitize_vocoder_weights @@ -343,7 +664,7 @@ def load_vocoder(model_path: Path): output_sample_rate=AUDIO_SAMPLE_RATE, ) - weight_file = model_path / "ltx-2-19b-distilled.safetensors" + weight_file = model_path / ("ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors") if weight_file.exists(): raw_weights = mx.load(str(weight_file)) sanitized = sanitize_vocoder_weights(raw_weights) @@ -358,7 +679,7 @@ def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RA import wave if audio.ndim == 2: - audio = audio.T # (channels, samples) -> (samples, channels) + audio = audio.T audio = np.clip(audio, -1.0, 1.0) audio_int16 = (audio * 32767).astype(np.int16) @@ -395,13 +716,21 @@ def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): return False +# ============================================================================= +# Unified Generate Function +# ============================================================================= + def generate_video( model_repo: str, text_encoder_repo: str, prompt: str, + pipeline: PipelineType = PipelineType.DISTILLED, + negative_prompt: str = DEFAULT_NEGATIVE_PROMPT, height: int = 512, width: int = 512, num_frames: int = 33, + num_inference_steps: int = 40, + cfg_scale: float = 4.0, seed: int = 42, fps: int = 24, output_path: str = "output.mp4", @@ -415,19 +744,26 @@ def generate_video( image_frame_idx: int = 0, tiling: str = "auto", stream: bool = False, - # Audio options audio: bool = False, output_audio_path: Optional[str] = None, ): - """Generate video from text prompt, optionally conditioned on an image and with audio. + """Generate video using LTX-2 models. + + Supports two pipelines: + - DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG + - DEV: Single-stage generation with dynamic sigmas and CFG Args: model_repo: Model repository ID text_encoder_repo: Text encoder repository ID prompt: Text description of the video to generate - height: Output video height (must be divisible by 64) - width: Output video width (must be divisible by 64) - num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97) + pipeline: Pipeline type (DISTILLED or DEV) + negative_prompt: Negative prompt for CFG (dev pipeline only) + height: Output video height (must be divisible by 32/64) + width: Output video width (must be divisible by 32/64) + num_frames: Number of frames (must be 1 + 8*k) + num_inference_steps: Number of denoising steps (dev pipeline only) + cfg_scale: Guidance scale for CFG (dev pipeline only) seed: Random seed for reproducibility fps: Frames per second for output video output_path: Path to save the output video @@ -436,30 +772,24 @@ def generate_video( enhance_prompt: Whether to enhance prompt using Gemma max_tokens: Max tokens for prompt enhancement temperature: Temperature for prompt enhancement - image: Path to conditioning image for I2V (Image-to-Video) - image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original) - image_frame_idx: Frame index to condition (0 = first frame) - tiling: Tiling mode for VAE decoding. Options: - - "auto": Automatically determine based on video size (default) - - "none": Disable tiling - - "default": 512px spatial, 64 frame temporal - - "aggressive": 256px spatial, 32 frame temporal (lowest memory) - - "conservative": 768px spatial, 96 frame temporal (faster) - - "spatial": Spatial tiling only - - "temporal": Temporal tiling only - stream: Stream frames to output as they're decoded (requires tiling) + image: Path to conditioning image for I2V + image_strength: Conditioning strength for I2V + image_frame_idx: Frame index to condition for I2V + tiling: Tiling mode for VAE decoding + stream: Stream frames to output as they're decoded audio: Enable synchronized audio generation - output_audio_path: Path to save audio file (default: same as video with .wav) + output_audio_path: Path to save audio file """ start_time = time.time() # Validate dimensions - assert height % 64 == 0, f"Height must be divisible by 64, got {height}" - assert width % 64 == 0, f"Width must be divisible by 64, got {width}" + divisor = 64 if pipeline == PipelineType.DISTILLED else 32 + assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}" + assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}" if num_frames % 8 != 1: adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 - console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}[/]") + console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using: {adjusted_num_frames}[/]") num_frames = adjusted_num_frames is_i2v = image is not None @@ -467,14 +797,17 @@ def generate_video( if audio: mode_str += "+Audio" - # Display header panel - header = f"[bold cyan]🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames[/]" + pipeline_name = "DEV" if pipeline == PipelineType.DEV else "DISTILLED" + header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height} • {num_frames} frames[/]" console.print(Panel(header, expand=False)) console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") + + if pipeline == PipelineType.DEV: + console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}[/]") + if is_i2v: console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]") - # Calculate audio frames if enabled audio_frames = None if audio: audio_frames = compute_audio_frames(num_frames, fps) @@ -484,9 +817,15 @@ def generate_video( model_path = get_model_path(model_repo) text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) + # Model weight file + weight_file = "ltx-2-19b-dev.safetensors" if pipeline == PipelineType.DEV else "ltx-2-19b-distilled.safetensors" + # Calculate latent dimensions - stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 - stage2_h, stage2_w = height // 32, width // 32 + if pipeline == PipelineType.DISTILLED: + stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 + stage2_h, stage2_w = height // 32, width // 32 + else: + latent_h, latent_w = height // 32, width // 32 latent_frames = 1 + (num_frames - 1) // 8 mx.random.seed(seed) @@ -501,33 +840,45 @@ def generate_video( # Optionally enhance the prompt if enhance_prompt: - with console.status("[magenta]✨ Enhancing prompt...[/]", spinner="dots"): - prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) + console.print("[bold magenta]✨ Enhancing prompt[/]") + prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") - # Get embeddings - with audio if enabled - if audio: - text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True) - mx.eval(text_embeddings, audio_embeddings) + # Encode prompts + if pipeline == PipelineType.DEV: + # Dev pipeline needs positive and negative embeddings + if audio: + video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) + video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) + model_dtype = video_embeddings_pos.dtype + mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) + else: + video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False) + video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False) + audio_embeddings_pos = audio_embeddings_neg = None + model_dtype = video_embeddings_pos.dtype + mx.eval(video_embeddings_pos, video_embeddings_neg) else: - text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) - audio_embeddings = None - mx.eval(text_embeddings) - - model_dtype = text_embeddings.dtype # bfloat16 from text encoder + # Distilled pipeline - single embedding + if audio: + text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True) + mx.eval(text_embeddings, audio_embeddings) + else: + text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) + audio_embeddings = None + mx.eval(text_embeddings) + model_dtype = text_embeddings.dtype del text_encoder mx.clear_cache() # Load transformer - transformer_desc = "🤖 Loading transformer (A/V mode)..." if audio else "🤖 Loading transformer..." + transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..." with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"): - raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) + raw_weights = mx.load(str(model_path / weight_file)) sanitized = sanitize_transformer_weights(raw_weights) - # Convert transformer weights to bfloat16 for memory efficiency sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} - # Configure model type based on audio flag model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly config_kwargs = dict( @@ -551,185 +902,249 @@ def generate_video( config_kwargs.update( audio_num_attention_heads=32, audio_attention_head_dim=64, - audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 + audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, audio_cross_attention_dim=2048, audio_positional_embedding_max_pos=[20], ) config = LTXModelConfig(**config_kwargs) - transformer = LTXModel(config) transformer.load_weights(list(sanitized.items()), strict=False) mx.eval(transformer.parameters()) console.print("[green]✓[/] Transformer loaded") - # Load VAE encoder and encode image for I2V conditioning - stage1_image_latent = None - stage2_image_latent = None - if is_i2v: - with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): - vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors')) - mx.eval(vae_encoder.parameters()) + # ========================================================================== + # Pipeline-specific generation logic + # ========================================================================== - # Load and prepare image for stage 1 (half resolution) - input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) - stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) - stage1_image_latent = vae_encoder(stage1_image_tensor) - mx.eval(stage1_image_latent) + if pipeline == PipelineType.DISTILLED: + # ====================================================================== + # DISTILLED PIPELINE: Two-stage with upsampling + # ====================================================================== - # Load and prepare image for stage 2 (full resolution) - input_image = load_image(image, height=height, width=width, dtype=model_dtype) - stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) - stage2_image_latent = vae_encoder(stage2_image_tensor) - mx.eval(stage2_image_latent) + # Load VAE encoder for I2V + stage1_image_latent = None + stage2_image_latent = None + if is_i2v: + with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): + vae_encoder = load_vae_encoder(str(model_path / weight_file)) + mx.eval(vae_encoder.parameters()) - del vae_encoder + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) + stage1_image_latent = vae_encoder(stage1_image_tensor) + mx.eval(stage1_image_latent) + + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) + stage2_image_latent = vae_encoder(stage2_image_tensor) + mx.eval(stage2_image_latent) + + del vae_encoder + mx.clear_cache() + console.print("[green]✓[/] VAE encoder loaded and image encoded") + + # Stage 1 + console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)") + mx.random.seed(seed) + + positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) + mx.eval(positions) + + audio_positions = None + audio_latents = None + if audio: + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) + mx.eval(audio_positions, audio_latents) + + # Apply I2V conditioning + state1 = None + if is_i2v and stage1_image_latent is not None: + latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) + state1 = LatentState( + latent=mx.zeros(latent_shape, dtype=model_dtype), + clean_latent=mx.zeros(latent_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state1 = apply_conditioning(state1, [conditioning]) + + noise = mx.random.normal(latent_shape, dtype=model_dtype) + noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) + scaled_mask = state1.denoise_mask * noise_scale + state1 = LatentState( + latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state1.clean_latent, + denoise_mask=state1.denoise_mask, + ) + latents = state1.latent + mx.eval(latents) + else: + latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype) + mx.eval(latents) + + latents, audio_latents = denoise_distilled( + latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, + verbose=verbose, state=state1, + audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, + ) + + # Upsample latents + with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): + upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) + mx.eval(upsampler.parameters()) + + vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None) + + latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) + mx.eval(latents) + + del upsampler mx.clear_cache() - console.print("[green]✓[/] VAE encoder loaded and image encoded") + console.print("[green]✓[/] Latents upsampled") - # Stage 1: Generate at half resolution - console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)") - mx.random.seed(seed) + # Stage 2 + console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)") + positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) + mx.eval(positions) - # Position grids stay float32 for RoPE precision - positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) - mx.eval(positions) + state2 = None + if is_i2v and stage2_image_latent is not None: + state2 = LatentState( + latent=latents, + clean_latent=mx.zeros_like(latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state2 = apply_conditioning(state2, [conditioning]) - # Create audio positions if enabled - audio_positions = None - audio_latents = None - if audio: - audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) - mx.eval(audio_positions, audio_latents) + noise = mx.random.normal(latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + scaled_mask = state2.denoise_mask * noise_scale + state2 = LatentState( + latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state2.clean_latent, + denoise_mask=state2.denoise_mask, + ) + latents = state2.latent + mx.eval(latents) - # Apply I2V conditioning if provided - state1 = None - if is_i2v and stage1_image_latent is not None: - # PyTorch flow: create zeros -> apply conditioning -> apply noiser - # Create initial state with zeros - latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) - state1 = LatentState( - latent=mx.zeros(latent_shape, dtype=model_dtype), - clean_latent=mx.zeros(latent_shape, dtype=model_dtype), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), - ) - conditioning = VideoConditionByLatentIndex( - latent=stage1_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) + else: + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) + noise = mx.random.normal(latents.shape).astype(model_dtype) + latents = noise * noise_scale + latents * one_minus_scale + mx.eval(latents) + + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) + + latents, audio_latents = denoise_distilled( + latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, + verbose=verbose, state=state2, + audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, ) - state1 = apply_conditioning(state1, [conditioning]) - - # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) - # For Stage 1, noise_scale = 1.0 (first sigma) - noise = mx.random.normal(latent_shape, dtype=model_dtype) - noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0 - scaled_mask = state1.denoise_mask * noise_scale - - state1 = LatentState( - latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), - clean_latent=state1.clean_latent, - denoise_mask=state1.denoise_mask, - ) - latents = state1.latent - mx.eval(latents) else: - # T2V: just use random noise - latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype) - mx.eval(latents) + # ====================================================================== + # DEV PIPELINE: Single-stage with CFG + # ====================================================================== - latents, audio_latents = denoise( - latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, - verbose=verbose, state=state1, - audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, - ) + # Load VAE encoder for I2V + image_latent = None + if is_i2v: + with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): + vae_encoder = load_vae_encoder(str(model_path / weight_file)) + mx.eval(vae_encoder.parameters()) - # Upsample latents - with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): - upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) - mx.eval(upsampler.parameters()) + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) + image_latent = vae_encoder(image_tensor) + mx.eval(image_latent) - vae_decoder = load_vae_decoder( - str(model_path / 'ltx-2-19b-distilled.safetensors'), - timestep_conditioning=None # Auto-detect from model metadata - ) + del vae_encoder + mx.clear_cache() + console.print("[green]✓[/] VAE encoder loaded and image encoded") - latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) - mx.eval(latents) + # Generate sigma schedule + num_tokens = latent_frames * latent_h * latent_w + sigmas = ltx2_scheduler(steps=num_inference_steps, num_tokens=num_tokens) + mx.eval(sigmas) + console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") - del upsampler - mx.clear_cache() - console.print("[green]✓[/] Latents upsampled") + console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})") + mx.random.seed(seed) - # Stage 2: Refine at full resolution - console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)") - # Position grids stay float32 for RoPE precision - positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) - mx.eval(positions) + video_positions = create_position_grid(1, latent_frames, latent_h, latent_w) + mx.eval(video_positions) - # Apply I2V conditioning for stage 2 if provided - state2 = None - if is_i2v and stage2_image_latent is not None: - # PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser - state2 = LatentState( - latent=latents, # Start with upscaled latent - clean_latent=mx.zeros_like(latents), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), - ) - conditioning = VideoConditionByLatentIndex( - latent=stage2_image_latent, - frame_idx=image_frame_idx, - strength=image_strength, - ) - state2 = apply_conditioning(state2, [conditioning]) + audio_positions = None + audio_latents = None + if audio: + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + mx.eval(audio_positions, audio_latents) - # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) - # For Stage 2, noise_scale = stage_2_sigmas[0] - # Conditioned frames (mask=0) keep image latent, unconditioned get partial noise - noise = mx.random.normal(latents.shape).astype(model_dtype) - noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) - scaled_mask = state2.denoise_mask * noise_scale - state2 = LatentState( - latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), - clean_latent=state2.clean_latent, - denoise_mask=state2.denoise_mask, - ) - latents = state2.latent - mx.eval(latents) + # Initialize latents with optional I2V conditioning + video_state = None + video_latent_shape = (1, 128, latent_frames, latent_h, latent_w) + if is_i2v and image_latent is not None: + video_state = LatentState( + latent=mx.zeros(video_latent_shape, dtype=model_dtype), + clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=image_latent, frame_idx=image_frame_idx, strength=image_strength) + video_state = apply_conditioning(video_state, [conditioning]) - # Audio also gets noise for stage 2 if enabled - if audio and audio_latents is not None: - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(audio_latents) - else: - # T2V: add noise to all frames for refinement - noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) - one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) - noise = mx.random.normal(latents.shape).astype(model_dtype) - latents = noise * noise_scale + latents * one_minus_scale - mx.eval(latents) + noise = mx.random.normal(video_latent_shape, dtype=model_dtype) + noise_scale = sigmas[0] + scaled_mask = video_state.denoise_mask * noise_scale + video_state = LatentState( + latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=video_state.clean_latent, + denoise_mask=video_state.denoise_mask, + ) + latents = video_state.latent + mx.eval(latents) + else: + latents = mx.random.normal(video_latent_shape, dtype=model_dtype) + mx.eval(latents) - # Audio also gets noise for stage 2 if enabled - if audio and audio_latents is not None: - audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) - audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale - mx.eval(audio_latents) + # Denoise with CFG + if audio: + latents, audio_latents = denoise_dev_av( + latents, audio_latents, + video_positions, audio_positions, + video_embeddings_pos, video_embeddings_neg, + audio_embeddings_pos, audio_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state + ) + else: + latents = denoise_dev( + latents, video_positions, video_embeddings_pos, video_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state + ) - latents, audio_latents = denoise( - latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, - verbose=verbose, state=state2, - audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, - ) + # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) + vae_decoder = load_vae_decoder(str(model_path / weight_file), timestep_conditioning=None) del transformer mx.clear_cache() - # Decode to video with tiling + # ========================================================================== + # Decode and save outputs (common to both pipelines) + # ========================================================================== + console.print("\n[blue]🎞️ Decoding video...[/]") # Select tiling configuration @@ -751,11 +1166,10 @@ def generate_video( console.print(f"[yellow] Unknown tiling mode '{tiling}', using auto[/]") tiling_config = TilingConfig.auto(height, width, num_frames) - # Save outputs output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) - # Stream mode: write frames as they're decoded + # Stream mode video_writer = None stream_progress = None @@ -774,10 +1188,8 @@ def generate_video( stream_task = stream_progress.add_task("[cyan]Streaming frames[/]", total=num_frames) def on_frames_ready(frames: mx.array, _start_idx: int): - """Callback to write frames as they're finalized.""" - # frames: (B, 3, num_frames, H, W) - frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W) - frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3) + frames = mx.squeeze(frames, axis=0) + frames = mx.transpose(frames, (1, 2, 3, 0)) frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0) frames = (frames * 255).astype(mx.uint8) frames_np = np.array(frames) @@ -799,34 +1211,30 @@ def generate_video( mx.eval(video) mx.clear_cache() - # Close progressive video writer if used + # Close stream writer if video_writer is not None: video_writer.release() if stream_progress is not None: stream_progress.stop() console.print(f"[green]✅ Streamed video to[/] {output_path}") - # Still need video_np for save_frames option video = mx.squeeze(video, axis=0) video = mx.transpose(video, (1, 2, 3, 0)) video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) video = (video * 255).astype(mx.uint8) video_np = np.array(video) else: - # Convert to uint8 frames - video = mx.squeeze(video, axis=0) # (C, F, H, W) - video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) + video = mx.squeeze(video, axis=0) + video = mx.transpose(video, (1, 2, 3, 0)) video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) video = (video * 255).astype(mx.uint8) video_np = np.array(video) - # For audio mode, save to temp file first if audio: temp_video_path = output_path.with_suffix('.temp.mp4') save_path = temp_video_path else: save_path = output_path - # Save video try: import cv2 h, w = video_np.shape[1], video_np.shape[2] @@ -844,8 +1252,8 @@ def generate_video( audio_np = None if audio and audio_latents is not None: with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"): - audio_decoder = load_audio_decoder(model_path) - vocoder = load_vocoder(model_path) + audio_decoder = load_audio_decoder(model_path, pipeline) + vocoder = load_vocoder(model_path, pipeline) mx.eval(audio_decoder.parameters(), vocoder.parameters()) mel_spectrogram = audio_decoder(audio_latents) @@ -862,12 +1270,10 @@ def generate_video( mx.clear_cache() console.print("[green]✓[/] Audio decoded") - # Save audio audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) console.print(f"[green]✅ Saved audio to[/] {audio_path}") - # Mux video and audio with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"): temp_video_path = output_path.with_suffix('.temp.mp4') success = mux_video_audio(temp_video_path, audio_path, output_path) @@ -902,160 +1308,73 @@ def generate_video( def main(): parser = argparse.ArgumentParser( - description="Generate videos with MLX LTX-2 (T2V, I2V, and Audio)", + description="Generate videos with MLX LTX-2 (Distilled or Dev pipeline)", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - # Text-to-Video (T2V) + # Distilled pipeline (two-stage, fast, no CFG) python -m mlx_video.generate --prompt "A cat walking on grass" - python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768 - python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4 + python -m mlx_video.generate --prompt "Ocean waves" --pipeline distilled - # Image-to-Video (I2V) + # Dev pipeline (single-stage, CFG, higher quality) + python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 4.0 + python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 50 + + # Image-to-Video (works with both pipelines) python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg - python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8 + python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --pipeline dev - # With Audio (T2V+Audio or I2V+Audio) + # With Audio (works with both pipelines) python -m mlx_video.generate --prompt "Ocean waves crashing" --audio - python -m mlx_video.generate --prompt "A jazz band playing" --audio --enhance-prompt - python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --audio + python -m mlx_video.generate --prompt "A jazz band playing" --audio --pipeline dev """ ) - parser.add_argument( - "--prompt", "-p", - type=str, - required=True, - help="Text description of the video to generate" - ) - parser.add_argument( - "--height", "-H", - type=int, - default=512, - help="Output video height (default: 512, must be divisible by 32)" - ) - parser.add_argument( - "--width", "-W", - type=int, - default=512, - help="Output video width (default: 512, must be divisible by 32)" - ) - parser.add_argument( - "--num-frames", "-n", - type=int, - default=100, - help="Number of frames (default: 100)" - ) - parser.add_argument( - "--seed", "-s", - type=int, - default=42, - help="Random seed for reproducibility (default: 42)" - ) - parser.add_argument( - "--fps", - type=int, - default=24, - help="Frames per second for output video (default: 24)" - ) - parser.add_argument( - "--output-path", "-o", - type=str, - default="output.mp4", - help="Output video path (default: output.mp4)" - ) - parser.add_argument( - "--save-frames", - action="store_true", - help="Save individual frames as images" - ) - parser.add_argument( - "--model-repo", - type=str, - default="Lightricks/LTX-2", - help="Model repository to use (default: Lightricks/LTX-2)" - ) - parser.add_argument( - "--text-encoder-repo", - type=str, - default=None, - help="Text encoder repository to use (default: None)" - ) - parser.add_argument( - "--verbose", - action="store_true", - help="Verbose output" - ) - parser.add_argument( - "--enhance-prompt", - action="store_true", - help="Enhance the prompt using Gemma before generation" - ) - parser.add_argument( - "--max-tokens", - type=int, - default=512, - help="Maximum number of tokens to generate (default: 512)" - ) - parser.add_argument( - "--temperature", - type=float, - default=0.7, - help="Temperature for prompt enhancement (default: 0.7)" - ) - parser.add_argument( - "--image", "-i", - type=str, - default=None, - help="Path to conditioning image for I2V (Image-to-Video) generation" - ) - parser.add_argument( - "--image-strength", - type=float, - default=1.0, - help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)" - ) - parser.add_argument( - "--image-frame-idx", - type=int, - default=0, - help="Frame index to condition for I2V (0 = first frame, default: 0)" - ) - parser.add_argument( - "--tiling", - type=str, - default="auto", - choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], - help="Tiling mode for VAE decoding (default: auto). " - "auto=based on video size, none=disabled, default=512px/64f, " - "aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only" - ) - parser.add_argument( - "--stream", - action="store_true", - help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner." - ) - # Audio options - parser.add_argument( - "--audio", "-a", - action="store_true", - help="Enable synchronized audio generation" - ) - parser.add_argument( - "--output-audio", - type=str, - default=None, - help="Output audio path (default: same as video with .wav)" - ) + parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate") + parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev"], + help="Pipeline type: distilled (two-stage, fast) or dev (single-stage, CFG)") + parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT, + help="Negative prompt for CFG (dev pipeline only)") + parser.add_argument("--height", "-H", type=int, default=512, help="Output video height") + parser.add_argument("--width", "-W", type=int, default=512, help="Output video width") + parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames") + parser.add_argument("--steps", type=int, default=40, help="Number of inference steps (dev pipeline only)") + parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG guidance scale (dev pipeline only)") + parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed") + parser.add_argument("--fps", type=int, default=24, help="Frames per second") + parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path") + parser.add_argument("--save-frames", action="store_true", help="Save individual frames as images") + parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2", help="Model repository") + parser.add_argument("--text-encoder-repo", type=str, default=None, help="Text encoder repository") + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument("--enhance-prompt", action="store_true", help="Enhance the prompt using Gemma") + parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for prompt enhancement") + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for prompt enhancement") + parser.add_argument("--image", "-i", type=str, default=None, help="Path to conditioning image for I2V") + parser.add_argument("--image-strength", type=float, default=1.0, help="Conditioning strength for I2V") + parser.add_argument("--image-frame-idx", type=int, default=0, help="Frame index to condition for I2V") + parser.add_argument("--tiling", type=str, default="auto", + choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], + help="Tiling mode for VAE decoding") + parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") + parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") + parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") + args = parser.parse_args() + pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED + generate_video( model_repo=args.model_repo, text_encoder_repo=args.text_encoder_repo, prompt=args.prompt, + pipeline=pipeline, + negative_prompt=args.negative_prompt, height=args.height, width=args.width, num_frames=args.num_frames, + num_inference_steps=args.steps, + cfg_scale=args.cfg_scale, seed=args.seed, fps=args.fps, output_path=args.output_path, diff --git a/mlx_video/generate_dev.py b/mlx_video/generate_dev.py deleted file mode 100644 index 1d2a041..0000000 --- a/mlx_video/generate_dev.py +++ /dev/null @@ -1,1281 +0,0 @@ -""" -Copyright (c) 2026, Prince Canuma and contributors (https://github.com/Blaizzy/mlx-video) - -LTX-2 Dev Model Generation Pipeline - -This module provides a single-stage video generation pipeline using the LTX-2 19B dev model. -Unlike the distilled model which uses fixed sigma schedules, the dev model uses: -- Dynamic sigma scheduling via LTX2Scheduler -- Classifier-Free Guidance (CFG) for better prompt adherence -- More inference steps (default 40) -""" - -import argparse -import math -import time -from pathlib import Path -from typing import Optional - -import mlx.core as mx -import numpy as np -from PIL import Image -from tqdm import tqdm - -# ANSI color codes -class Colors: - CYAN = "\033[96m" - BLUE = "\033[94m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - RED = "\033[91m" - MAGENTA = "\033[95m" - BOLD = "\033[1m" - DIM = "\033[2m" - RESET = "\033[0m" - - -from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType -from mlx_video.models.ltx.ltx import LTXModel -from mlx_video.models.ltx.transformer import Modality -from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights -from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding -from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder -from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder -from mlx_video.models.ltx.video_vae.tiling import TilingConfig -from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning -from mlx_video.conditioning.latent import LatentState, apply_denoise_mask -from mlx_video.utils import get_model_path - - -# Default values matching PyTorch implementation -DEFAULT_NEGATIVE_PROMPT = ( - "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " - "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " - "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " - "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " - "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " - "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " - "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " - "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " - "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " - "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " - "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." -) - -BASE_SHIFT_ANCHOR = 1024 -MAX_SHIFT_ANCHOR = 4096 - -# Audio constants -AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate -AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate -AUDIO_HOP_LENGTH = 160 -AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4 -AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying -AUDIO_MEL_BINS = 16 -AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 - - -def ltx2_scheduler( - steps: int, - num_tokens: Optional[int] = None, - max_shift: float = 2.05, - base_shift: float = 0.95, - stretch: bool = True, - terminal: float = 0.1, -) -> mx.array: - """ - LTX-2 scheduler for sigma generation. - - Generates a sigma schedule with token-count-dependent shifting and optional - stretching to a terminal value. - - Args: - steps: Number of inference steps - num_tokens: Number of latent tokens (F*H*W). If None, uses MAX_SHIFT_ANCHOR - max_shift: Maximum shift factor - base_shift: Base shift factor - stretch: Whether to stretch sigmas to terminal value - terminal: Terminal sigma value for stretching - - Returns: - Array of sigma values of shape (steps + 1,) - """ - tokens = num_tokens if num_tokens is not None else MAX_SHIFT_ANCHOR - sigmas = np.linspace(1.0, 0.0, steps + 1) - - # Compute shift based on token count - x1 = BASE_SHIFT_ANCHOR - x2 = MAX_SHIFT_ANCHOR - mm = (max_shift - base_shift) / (x2 - x1) - b = base_shift - mm * x1 - sigma_shift = tokens * mm + b - - # Apply shift transformation - power = 1 - sigmas = np.where( - sigmas != 0, - math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), - 0, - ) - - # Stretch sigmas to terminal value - if stretch: - non_zero_mask = sigmas != 0 - non_zero_sigmas = sigmas[non_zero_mask] - one_minus_z = 1.0 - non_zero_sigmas - scale_factor = one_minus_z[-1] / (1.0 - terminal) - stretched = 1.0 - (one_minus_z / scale_factor) - sigmas[non_zero_mask] = stretched - - return mx.array(sigmas, dtype=mx.float32) - - -def create_position_grid( - batch_size: int, - num_frames: int, - height: int, - width: int, - temporal_scale: int = 8, - spatial_scale: int = 32, - fps: float = 24.0, - causal_fix: bool = True, -) -> mx.array: - """Create position grid for RoPE in pixel space. - - Args: - batch_size: Batch size - num_frames: Number of frames (latent) - height: Height (latent) - width: Width (latent) - temporal_scale: VAE temporal scale factor (default 8) - spatial_scale: VAE spatial scale factor (default 32) - fps: Frames per second (default 24.0) - causal_fix: Apply causal fix for first frame (default True) - - Returns: - Position grid of shape (B, 3, num_patches, 2) in pixel space - where dim 2 is [start, end) bounds for each patch - """ - # Patch size is (1, 1, 1) for LTX-2 - no spatial patching - patch_size_t, patch_size_h, patch_size_w = 1, 1, 1 - - # Generate grid coordinates for each dimension (frame, height, width) - t_coords = np.arange(0, num_frames, patch_size_t) - h_coords = np.arange(0, height, patch_size_h) - w_coords = np.arange(0, width, patch_size_w) - - # Create meshgrid with indexing='ij' for (frame, height, width) order - t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') - - # Stack to get shape (3, grid_t, grid_h, grid_w) - patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) - - # Calculate end coordinates (start + patch_size) - patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1) - patch_ends = patch_starts + patch_size_delta - - # Stack start and end: shape (3, grid_t, grid_h, grid_w, 2) - latent_coords = np.stack([patch_starts, patch_ends], axis=-1) - - # Flatten spatial/temporal dims: (3, num_patches, 2) - num_patches = num_frames * height * width - latent_coords = latent_coords.reshape(3, num_patches, 2) - - # Broadcast to batch: (batch, 3, num_patches, 2) - latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) - - # Convert latent coords to pixel coords by scaling with VAE factors - scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1) - pixel_coords = (latent_coords * scale_factors).astype(np.float32) - - # Apply causal fix for first frame temporal axis - if causal_fix: - # VAE temporal stride for first frame is 1 instead of temporal_scale - pixel_coords[:, 0, :, :] = np.clip( - pixel_coords[:, 0, :, :] + 1 - temporal_scale, - a_min=0, - a_max=None - ) - - # Convert temporal to time in seconds by dividing by fps - pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps - - # Always return float32 for RoPE precision - bfloat16 causes quality degradation - return mx.array(pixel_coords, dtype=mx.float32) - - -def create_audio_position_grid( - batch_size: int, - audio_frames: int, - sample_rate: int = AUDIO_LATENT_SAMPLE_RATE, - hop_length: int = AUDIO_HOP_LENGTH, - downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR, - is_causal: bool = True, -) -> mx.array: - """Create temporal position grid for audio RoPE. - - Audio positions are timestamps in seconds, shape (B, 1, T, 2). - Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly. - - Args: - batch_size: Batch size - audio_frames: Number of audio latent frames - sample_rate: Audio sample rate (default 16000) - hop_length: Hop length for mel spectrogram (default 160) - downsample_factor: Latent downsample factor (default 4) - is_causal: Whether to use causal alignment (default True) - - Returns: - Position grid of shape (B, 1, T, 2) - """ - def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray: - """Convert latent indices to seconds.""" - latent_frame = np.arange(start_idx, end_idx, dtype=np.float32) - mel_frame = latent_frame * downsample_factor - if is_causal: - mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None) - return mel_frame * hop_length / sample_rate - - start_times = get_audio_latent_time_in_sec(0, audio_frames) - end_times = get_audio_latent_time_in_sec(1, audio_frames + 1) - - positions = np.stack([start_times, end_times], axis=-1) - positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2) - positions = np.tile(positions, (batch_size, 1, 1, 1)) - - return mx.array(positions, dtype=mx.float32) - - -def compute_audio_frames(num_video_frames: int, fps: float) -> int: - """Compute number of audio latent frames given video duration.""" - duration = num_video_frames / fps - return round(duration * AUDIO_LATENTS_PER_SECOND) - - -def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: - """Compute CFG (Classifier-Free Guidance) delta. - - Args: - cond: Conditioned prediction - uncond: Unconditioned prediction - scale: Guidance scale (1.0 = no guidance) - - Returns: - CFG delta to add to conditioned prediction - """ - return (scale - 1.0) * (cond - uncond) - - -def load_audio_decoder(model_path: Path): - """Load audio VAE decoder.""" - from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType - - decoder = AudioDecoder( - ch=128, - out_ch=2, # stereo - ch_mult=(1, 2, 4), - num_res_blocks=2, - attn_resolutions=set(), # PyTorch uses empty set (no attention in audio decoder) - resolution=256, - z_channels=AUDIO_LATENT_CHANNELS, - norm_type=NormType.PIXEL, - causality_axis=CausalityAxis.HEIGHT, - mel_bins=64, # Output mel bins - ) - - # Load weights - try dev model first, fall back to distilled - weight_file = model_path / "ltx-2-19b-dev.safetensors" - if not weight_file.exists(): - weight_file = model_path / "ltx-2-19b-distilled.safetensors" - - if weight_file.exists(): - print(f"Loading audio VAE decoder from {weight_file}...") - raw_weights = mx.load(str(weight_file)) - sanitized = sanitize_audio_vae_weights(raw_weights) - if sanitized: - decoder.load_weights(list(sanitized.items()), strict=False) - - # Manually load per-channel statistics - if "per_channel_statistics._mean_of_means" in sanitized: - decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"] - if "per_channel_statistics._std_of_means" in sanitized: - decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"] - - return decoder - - -def load_vocoder(model_path: Path): - """Load vocoder for mel to waveform conversion.""" - from mlx_video.models.ltx.audio_vae import Vocoder - - vocoder = Vocoder( - resblock_kernel_sizes=[3, 7, 11], - upsample_rates=[6, 5, 2, 2, 2], - upsample_kernel_sizes=[16, 15, 8, 4, 4], - resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], - upsample_initial_channel=1024, - stereo=True, - output_sample_rate=AUDIO_SAMPLE_RATE, - ) - - # Load weights - try dev model first, fall back to distilled - weight_file = model_path / "ltx-2-19b-dev.safetensors" - if not weight_file.exists(): - weight_file = model_path / "ltx-2-19b-distilled.safetensors" - - if weight_file.exists(): - raw_weights = mx.load(str(weight_file)) - sanitized = sanitize_vocoder_weights(raw_weights) - if sanitized: - vocoder.load_weights(list(sanitized.items()), strict=False) - - return vocoder - - -def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE): - """Save audio to WAV file.""" - import wave - - # Ensure audio is in correct format (channels, samples) or (samples,) - if audio.ndim == 2: - # (channels, samples) -> (samples, channels) - audio = audio.T - - # Normalize and convert to int16 - audio = np.clip(audio, -1.0, 1.0) - audio_int16 = (audio * 32767).astype(np.int16) - - with wave.open(str(path), 'wb') as wf: - wf.setnchannels(2 if audio_int16.ndim == 2 else 1) - wf.setsampwidth(2) # 16-bit - wf.setframerate(sample_rate) - wf.writeframes(audio_int16.tobytes()) - - -def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path) -> bool: - """Combine video and audio into final output using ffmpeg.""" - import subprocess - - cmd = [ - "ffmpeg", "-y", - "-i", str(video_path), - "-i", str(audio_path), - "-c:v", "copy", - "-c:a", "aac", - "-shortest", - str(output_path) - ] - - try: - subprocess.run(cmd, check=True, capture_output=True) - return True - except subprocess.CalledProcessError as e: - print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}") - return False - except FileNotFoundError: - print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}") - return False - - -def denoise_with_cfg( - latents: mx.array, - positions: mx.array, - text_embeddings_pos: mx.array, - text_embeddings_neg: mx.array, - transformer: LTXModel, - sigmas: mx.array, - cfg_scale: float = 4.0, - verbose: bool = True, - state: Optional[LatentState] = None, -) -> mx.array: - """Run denoising loop with CFG (Classifier-Free Guidance). - - Uses separate forward passes for positive and negative conditioning - to match PyTorch implementation behavior (avoids potential issues with - batched attention patterns). - - Args: - latents: Noisy latent tensor (B, C, F, H, W) - positions: Position embeddings - text_embeddings_pos: Positive (prompt) text conditioning embeddings - text_embeddings_neg: Negative prompt text conditioning embeddings - transformer: LTX model - sigmas: Array of sigma values for denoising schedule - cfg_scale: Guidance scale (default 4.0, 1.0 = no guidance) - verbose: Whether to show progress bar - state: Optional LatentState for I2V conditioning - - Returns: - Denoised latent tensor - """ - from mlx_video.models.ltx.rope import precompute_freqs_cis - - dtype = latents.dtype - if state is not None: - latents = state.latent - - sigmas_list = sigmas.tolist() - use_cfg = cfg_scale != 1.0 - - # Precompute RoPE once (expensive operation due to NumPy conversion for double precision) - # This avoids recomputing it every forward pass - precomputed_rope = precompute_freqs_cis( - positions, - dim=transformer.inner_dim, - theta=transformer.positional_embedding_theta, - max_pos=transformer.positional_embedding_max_pos, - use_middle_indices_grid=transformer.use_middle_indices_grid, - num_attention_heads=transformer.num_attention_heads, - rope_type=transformer.rope_type, - double_precision=transformer.config.double_precision_rope, - ) - mx.eval(precomputed_rope) - - for i in tqdm(range(len(sigmas_list) - 1), desc="Denoising", disable=not verbose): - sigma = sigmas_list[i] - sigma_next = sigmas_list[i + 1] - - b, c, f, h, w = latents.shape - num_tokens = f * h * w - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) - - # Compute per-token timesteps - if state is not None: - denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) - denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) - denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) - timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat - else: - timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) - - # First forward pass: positive conditioning - video_modality_pos = Modality( - latent=latents_flat, - timesteps=timesteps, - positions=positions, - context=text_embeddings_pos, - context_mask=None, - enabled=True, - positional_embeddings=precomputed_rope, - ) - velocity_pos, _ = transformer(video=video_modality_pos, audio=None) - - if use_cfg: - # Second forward pass: negative conditioning - video_modality_neg = Modality( - latent=latents_flat, - timesteps=timesteps, - positions=positions, - context=text_embeddings_neg, - context_mask=None, - enabled=True, - positional_embeddings=precomputed_rope, - ) - velocity_neg, _ = transformer(video=video_modality_neg, audio=None) - - # Apply CFG: velocity = pos + (scale - 1) * (pos - neg) - velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg) - else: - velocity_flat = velocity_pos - - # Reshape back to 5D - velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w)) - denoised = to_denoised(latents, velocity, sigma) - - # Apply conditioning mask if state is provided - if state is not None: - denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) - - # Euler step - if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr - else: - latents = denoised - - # Single eval at end of step (lazy evaluation handles the rest) - mx.eval(latents) - - return latents - - -def denoise_av_with_cfg( - video_latents: mx.array, - audio_latents: mx.array, - video_positions: mx.array, - audio_positions: mx.array, - video_embeddings_pos: mx.array, - video_embeddings_neg: mx.array, - audio_embeddings_pos: mx.array, - audio_embeddings_neg: mx.array, - transformer: LTXModel, - sigmas: mx.array, - cfg_scale: float = 4.0, - verbose: bool = True, - video_state: Optional[LatentState] = None, -) -> tuple[mx.array, mx.array]: - """Run denoising loop for audio-video generation with CFG. - - Uses separate forward passes for positive and negative CFG to ensure - correct audio-video cross-attention behavior (matching PyTorch implementation). - - Args: - video_latents: Video latent tensor (B, C, F, H, W) - audio_latents: Audio latent tensor (B, C, T, F) - video_positions: Video position embeddings - audio_positions: Audio position embeddings - video_embeddings_pos: Positive video text embeddings - video_embeddings_neg: Negative video text embeddings - audio_embeddings_pos: Positive audio text embeddings - audio_embeddings_neg: Negative audio text embeddings - transformer: LTX model - sigmas: Array of sigma values for denoising schedule - cfg_scale: Guidance scale (default 4.0, 1.0 = no guidance) - verbose: Whether to show progress bar - video_state: Optional LatentState for I2V conditioning - - Returns: - Tuple of (video_latents, audio_latents) - """ - from mlx_video.models.ltx.rope import precompute_freqs_cis - - dtype = video_latents.dtype - if video_state is not None: - video_latents = video_state.latent - - sigmas_list = sigmas.tolist() - use_cfg = cfg_scale != 1.0 - - # Precompute video RoPE (single batch, not doubled for CFG) - precomputed_video_rope = precompute_freqs_cis( - video_positions, - dim=transformer.inner_dim, - theta=transformer.positional_embedding_theta, - max_pos=transformer.positional_embedding_max_pos, - use_middle_indices_grid=transformer.use_middle_indices_grid, - num_attention_heads=transformer.num_attention_heads, - rope_type=transformer.rope_type, - double_precision=transformer.config.double_precision_rope, - ) - - # Precompute audio RoPE (1D positions) - precomputed_audio_rope = precompute_freqs_cis( - audio_positions, - dim=transformer.audio_inner_dim, - theta=transformer.positional_embedding_theta, - max_pos=transformer.audio_positional_embedding_max_pos, - use_middle_indices_grid=transformer.use_middle_indices_grid, - num_attention_heads=transformer.audio_num_attention_heads, - rope_type=transformer.rope_type, - double_precision=transformer.config.double_precision_rope, - ) - mx.eval(precomputed_video_rope, precomputed_audio_rope) - - for i in tqdm(range(len(sigmas_list) - 1), desc="Denoising A/V", disable=not verbose): - sigma = sigmas_list[i] - sigma_next = sigmas_list[i + 1] - - # Flatten video latents - b, c, f, h, w = video_latents.shape - num_video_tokens = f * h * w - video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) - - # Flatten audio latents: (B, C, T, F) -> (B, T, C*F) - ab, ac, at, af = audio_latents.shape - audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) - audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) - - # Compute per-token timesteps for video - if video_state is not None: - denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1)) - denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) - denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) - video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat - else: - video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) - - audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) - - # First forward pass: positive conditioning - video_modality_pos = Modality( - latent=video_flat, - timesteps=video_timesteps, - positions=video_positions, - context=video_embeddings_pos, - context_mask=None, - enabled=True, - positional_embeddings=precomputed_video_rope, - ) - - audio_modality_pos = Modality( - latent=audio_flat, - timesteps=audio_timesteps, - positions=audio_positions, - context=audio_embeddings_pos, - context_mask=None, - enabled=True, - positional_embeddings=precomputed_audio_rope, - ) - - video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) - - if use_cfg: - # Second forward pass: negative conditioning - video_modality_neg = Modality( - latent=video_flat, - timesteps=video_timesteps, - positions=video_positions, - context=video_embeddings_neg, - context_mask=None, - enabled=True, - positional_embeddings=precomputed_video_rope, - ) - - audio_modality_neg = Modality( - latent=audio_flat, - timesteps=audio_timesteps, - positions=audio_positions, - context=audio_embeddings_neg, - context_mask=None, - enabled=True, - positional_embeddings=precomputed_audio_rope, - ) - - video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) - - # Apply CFG: denoised = pos + (scale - 1) * (pos - neg) - video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg) - audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg) - else: - video_velocity_flat = video_vel_pos - audio_velocity_flat = audio_vel_pos - - # Reshape velocities back - video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w)) - audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af)) - audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) - - # Compute denoised - video_denoised = to_denoised(video_latents, video_velocity, sigma) - audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) - - # Apply conditioning mask for video if state is provided - if video_state is not None: - video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask) - - # Euler step - if sigma_next > 0: - sigma_next_arr = mx.array(sigma_next, dtype=dtype) - sigma_arr = mx.array(sigma, dtype=dtype) - video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr - audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr - else: - video_latents = video_denoised - audio_latents = audio_denoised - - mx.eval(video_latents, audio_latents) - - return video_latents, audio_latents - - -def generate_video_dev( - model_repo: str, - text_encoder_repo: str, - prompt: str, - negative_prompt: str = DEFAULT_NEGATIVE_PROMPT, - height: int = 512, - width: int = 768, - num_frames: int = 33, - num_inference_steps: int = 40, - cfg_scale: float = 4.0, - seed: int = 42, - fps: int = 24, - output_path: str = "output.mp4", - output_audio_path: Optional[str] = None, - save_frames: bool = False, - verbose: bool = True, - enhance_prompt: bool = False, - max_tokens: int = 512, - temperature: float = 0.7, - image: Optional[str] = None, - image_strength: float = 1.0, - image_frame_idx: int = 0, - tiling: str = "none", - audio: bool = False, -): - """Generate video using LTX-2 dev model with CFG. - - This is a single-stage pipeline that uses the full dev model with - Classifier-Free Guidance for better prompt adherence. - - Args: - model_repo: Model repository ID - text_encoder_repo: Text encoder repository ID - prompt: Text description of the video to generate - negative_prompt: Negative prompt for CFG - height: Output video height (must be divisible by 32) - width: Output video width (must be divisible by 32) - num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97) - num_inference_steps: Number of denoising steps (default 40) - cfg_scale: Guidance scale for CFG (default 4.0) - seed: Random seed for reproducibility - fps: Frames per second for output video - output_path: Path to save the output video - output_audio_path: Path to save audio (if audio=True) - save_frames: Whether to save individual frames as images - verbose: Whether to print progress - enhance_prompt: Whether to enhance prompt using Gemma - max_tokens: Max tokens for prompt enhancement - temperature: Temperature for prompt enhancement - image: Path to conditioning image for I2V (Image-to-Video) - image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original) - image_frame_idx: Frame index to condition (0 = first frame) - tiling: Tiling mode for VAE decoding - audio: Whether to generate synchronized audio - """ - start_time = time.time() - - # Validate dimensions - assert height % 32 == 0, f"Height must be divisible by 32, got {height}" - assert width % 32 == 0, f"Width must be divisible by 32, got {width}" - - if num_frames % 8 != 1: - adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 - print(f"{Colors.YELLOW}Warning: Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") - num_frames = adjusted_num_frames - - # Calculate audio frames if audio is enabled - audio_frames = compute_audio_frames(num_frames, fps) if audio else 0 - - is_i2v = image is not None - mode_str = "I2V" if is_i2v else "T2V" - if audio: - mode_str += "+Audio" - print(f"{Colors.BOLD}{Colors.CYAN}[DEV] [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") - print(f"{Colors.DIM}Steps: {num_inference_steps}, CFG: {cfg_scale}{Colors.RESET}") - if audio: - print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") - print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") - if is_i2v: - print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") - - # Get model path - model_path = get_model_path(model_repo) - text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) - - # Calculate latent dimensions (single-stage, no upsampling) - latent_h, latent_w = height // 32, width // 32 - latent_frames = 1 + (num_frames - 1) // 8 - - mx.random.seed(seed) - - # Load text encoder - print(f"{Colors.BLUE}Loading text encoder...{Colors.RESET}") - from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder() - text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) - mx.eval(text_encoder.parameters()) - - # Optionally enhance the prompt - if enhance_prompt: - print(f"{Colors.MAGENTA}Enhancing prompt...{Colors.RESET}") - prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) - print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") - - # Encode both positive and negative prompts - if audio: - video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) - video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) - model_dtype = video_embeddings_pos.dtype - mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) - else: - video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False) - video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False) - audio_embeddings_pos = None - audio_embeddings_neg = None - model_dtype = video_embeddings_pos.dtype - mx.eval(video_embeddings_pos, video_embeddings_neg) - - del text_encoder - mx.clear_cache() - - # Load transformer (dev model) - print(f"{Colors.BLUE}Loading dev transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}") - raw_weights = mx.load(str(model_path / 'ltx-2-19b-dev.safetensors')) - sanitized = sanitize_transformer_weights(raw_weights) - # Convert transformer weights to bfloat16 for memory efficiency - sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} - - if audio: - config = LTXModelConfig( - model_type=LTXModelType.AudioVideo, - num_attention_heads=32, - attention_head_dim=128, - in_channels=128, - out_channels=128, - num_layers=48, - cross_attention_dim=4096, - caption_channels=3840, - # Audio config - audio_num_attention_heads=32, - audio_attention_head_dim=64, - audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 - audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, - audio_cross_attention_dim=2048, - rope_type=LTXRopeType.SPLIT, - double_precision_rope=True, - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - audio_positional_embedding_max_pos=[20], - use_middle_indices_grid=True, - timestep_scale_multiplier=1000, - ) - else: - config = LTXModelConfig( - model_type=LTXModelType.VideoOnly, - num_attention_heads=32, - attention_head_dim=128, - in_channels=128, - out_channels=128, - num_layers=48, - cross_attention_dim=4096, - caption_channels=3840, - rope_type=LTXRopeType.SPLIT, - double_precision_rope=True, - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, - timestep_scale_multiplier=1000, - ) - - transformer = LTXModel(config) - transformer.load_weights(list(sanitized.items()), strict=False) - mx.eval(transformer.parameters()) - - # Load VAE encoder for I2V - image_latent = None - if is_i2v: - print(f"{Colors.BLUE}Loading VAE encoder and encoding image...{Colors.RESET}") - vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-dev.safetensors')) - mx.eval(vae_encoder.parameters()) - - input_image = load_image(image, height=height, width=width, dtype=model_dtype) - image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) - image_latent = vae_encoder(image_tensor) - mx.eval(image_latent) - print(f" Image latent: {image_latent.shape}") - - del vae_encoder - mx.clear_cache() - - # Generate sigma schedule - num_tokens = latent_frames * latent_h * latent_w - sigmas = ltx2_scheduler( - steps=num_inference_steps, - num_tokens=num_tokens, - ) - mx.eval(sigmas) - print(f"{Colors.DIM}Sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f}{Colors.RESET}") - - # Create position grids - print(f"{Colors.YELLOW}Generating at {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})...{Colors.RESET}") - mx.random.seed(seed) - - video_positions = create_position_grid(1, latent_frames, latent_h, latent_w) - mx.eval(video_positions) - - if audio: - audio_positions = create_audio_position_grid(1, audio_frames) - mx.eval(audio_positions) - else: - audio_positions = None - - # Initialize latents with optional I2V conditioning - video_state = None - video_latent_shape = (1, 128, latent_frames, latent_h, latent_w) - if is_i2v and image_latent is not None: - video_state = LatentState( - latent=mx.zeros(video_latent_shape, dtype=model_dtype), - clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), - ) - conditioning = VideoConditionByLatentIndex( - latent=image_latent, - frame_idx=image_frame_idx, - strength=image_strength, - ) - video_state = apply_conditioning(video_state, [conditioning]) - - # Apply noiser - noise = mx.random.normal(video_latent_shape, dtype=model_dtype) - noise_scale = sigmas[0] - scaled_mask = video_state.denoise_mask * noise_scale - - video_state = LatentState( - latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), - clean_latent=video_state.clean_latent, - denoise_mask=video_state.denoise_mask, - ) - video_latents = video_state.latent - mx.eval(video_latents) - else: - # T2V: just use random noise - video_latents = mx.random.normal(video_latent_shape, dtype=model_dtype) - mx.eval(video_latents) - - # Initialize audio latents if audio is enabled - if audio: - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) - mx.eval(audio_latents) - else: - audio_latents = None - - # Denoise with CFG - if audio: - video_latents, audio_latents = denoise_av_with_cfg( - video_latents, audio_latents, - video_positions, audio_positions, - video_embeddings_pos, video_embeddings_neg, - audio_embeddings_pos, audio_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state - ) - else: - video_latents = denoise_with_cfg( - video_latents, video_positions, video_embeddings_pos, video_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state - ) - - del transformer - mx.clear_cache() - - # Decode to video - print(f"{Colors.BLUE}Decoding video...{Colors.RESET}") - vae_decoder = load_vae_decoder( - str(model_path / 'ltx-2-19b-dev.safetensors'), - timestep_conditioning=None - ) - mx.eval(vae_decoder.parameters()) - - # Select tiling configuration - if tiling == "none": - tiling_config = None - elif tiling == "auto": - tiling_config = TilingConfig.auto(height, width, num_frames) - elif tiling == "default": - tiling_config = TilingConfig.default() - elif tiling == "aggressive": - tiling_config = TilingConfig.aggressive() - elif tiling == "conservative": - tiling_config = TilingConfig.conservative() - elif tiling == "spatial": - tiling_config = TilingConfig.spatial_only() - elif tiling == "temporal": - tiling_config = TilingConfig.temporal_only() - else: - print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") - tiling_config = TilingConfig.auto(height, width, num_frames) - - if tiling_config is not None: - spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" - temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" - print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") - video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose) - else: - print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") - video = vae_decoder(video_latents) - mx.eval(video) - - del vae_decoder - mx.clear_cache() - - # Decode audio if enabled - audio_np = None - if audio and audio_latents is not None: - print(f"{Colors.BLUE}Decoding audio...{Colors.RESET}") - - # Load audio decoder - audio_decoder = load_audio_decoder(model_path) - mx.eval(audio_decoder.parameters()) - - # Decode audio latents to mel spectrogram - mel_spectrogram = audio_decoder(audio_latents) - mx.eval(mel_spectrogram) - - del audio_decoder - mx.clear_cache() - - # Load vocoder and convert mel to waveform - vocoder = load_vocoder(model_path) - mx.eval(vocoder.parameters()) - - audio_waveform = vocoder(mel_spectrogram) - mx.eval(audio_waveform) - - del vocoder - mx.clear_cache() - - # Convert to numpy - audio_np = np.array(audio_waveform) - if audio_np.ndim == 3: - audio_np = audio_np[0] # Remove batch dim - - print(f"{Colors.DIM} Audio shape: {audio_np.shape}, duration: {audio_np.shape[-1] / AUDIO_SAMPLE_RATE:.2f}s{Colors.RESET}") - - # Convert video to uint8 frames - video = mx.squeeze(video, axis=0) # (C, F, H, W) - video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) - video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) - video = (video * 255).astype(mx.uint8) - video_np = np.array(video) - - # Save outputs - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # Determine audio output path - if audio and audio_np is not None: - if output_audio_path is None: - audio_output = output_path.parent / f"{output_path.stem}.wav" - else: - audio_output = Path(output_audio_path) - - # Save audio - save_audio(audio_np, audio_output) - print(f"{Colors.GREEN}Saved audio to{Colors.RESET} {audio_output}") - - # Save video (to temp file if we need to mux with audio) - if audio and audio_np is not None: - # Save video to temp file, then mux with audio - temp_video_path = output_path.parent / f"{output_path.stem}_temp.mp4" - video_save_path = temp_video_path - else: - video_save_path = output_path - - try: - import cv2 - h, w = video_np.shape[1], video_np.shape[2] - fourcc = cv2.VideoWriter_fourcc(*'avc1') - out = cv2.VideoWriter(str(video_save_path), fourcc, fps, (w, h)) - for frame in video_np: - out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - out.release() - - if audio and audio_np is not None: - # Mux video and audio - print(f"{Colors.BLUE}Muxing video and audio...{Colors.RESET}") - if mux_video_audio(temp_video_path, audio_output, output_path): - print(f"{Colors.GREEN}Saved video with audio to{Colors.RESET} {output_path}") - # Clean up temp file - temp_video_path.unlink(missing_ok=True) - else: - # Fallback: keep separate files - print(f"{Colors.YELLOW}Could not mux, keeping separate files{Colors.RESET}") - temp_video_path.rename(output_path.parent / f"{output_path.stem}_video.mp4") - else: - print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}") - except Exception as e: - print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}") - - if save_frames: - frames_dir = output_path.parent / f"{output_path.stem}_frames" - frames_dir.mkdir(exist_ok=True) - for i, frame in enumerate(video_np): - Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png") - print(f"{Colors.GREEN}Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}") - - elapsed = time.time() - start_time - print(f"{Colors.BOLD}{Colors.GREEN}Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") - print(f"{Colors.BOLD}{Colors.GREEN}Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") - - return video_np - - -def main(): - parser = argparse.ArgumentParser( - description="Generate videos with MLX LTX-2 Dev Model (with CFG)", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Text-to-Video (T2V) with dev model - python -m mlx_video.generate_dev --prompt "A cat walking on grass" - python -m mlx_video.generate_dev --prompt "Ocean waves at sunset" --cfg-scale 6.0 --steps 50 - - # With custom negative prompt - python -m mlx_video.generate_dev --prompt "..." --negative-prompt "blurry, low quality" - - # Image-to-Video (I2V) - python -m mlx_video.generate_dev --prompt "A person dancing" --image photo.jpg - - # With synchronized audio - python -m mlx_video.generate_dev --prompt "Ocean waves crashing on rocks" --audio - python -m mlx_video.generate_dev --prompt "A busy city street" --audio --output-audio street.wav - """ - ) - - parser.add_argument( - "--prompt", "-p", - type=str, - required=True, - help="Text description of the video to generate" - ) - parser.add_argument( - "--negative-prompt", - type=str, - default=DEFAULT_NEGATIVE_PROMPT, - help="Negative prompt for CFG guidance" - ) - parser.add_argument( - "--height", "-H", - type=int, - default=512, - help="Output video height (default: 512, must be divisible by 32)" - ) - parser.add_argument( - "--width", "-W", - type=int, - default=768, - help="Output video width (default: 768, must be divisible by 32)" - ) - parser.add_argument( - "--num-frames", "-n", - type=int, - default=33, - help="Number of frames (default: 33)" - ) - parser.add_argument( - "--steps", - type=int, - default=40, - help="Number of inference steps (default: 40)" - ) - parser.add_argument( - "--cfg-scale", - type=float, - default=4.0, - help="CFG guidance scale (default: 4.0, 1.0 = no guidance)" - ) - parser.add_argument( - "--seed", "-s", - type=int, - default=42, - help="Random seed for reproducibility (default: 42)" - ) - parser.add_argument( - "--fps", - type=int, - default=24, - help="Frames per second for output video (default: 24)" - ) - parser.add_argument( - "--output-path", - type=str, - default="output_dev.mp4", - help="Output video path (default: output_dev.mp4)" - ) - parser.add_argument( - "--save-frames", - action="store_true", - help="Save individual frames as images" - ) - parser.add_argument( - "--model-repo", - type=str, - default="Lightricks/LTX-2", - help="Model repository to use (default: Lightricks/LTX-2)" - ) - parser.add_argument( - "--text-encoder-repo", - type=str, - default=None, - help="Text encoder repository to use (default: None)" - ) - parser.add_argument( - "--verbose", - action="store_true", - help="Verbose output" - ) - parser.add_argument( - "--enhance-prompt", - action="store_true", - help="Enhance the prompt using Gemma before generation" - ) - parser.add_argument( - "--max-tokens", - type=int, - default=512, - help="Maximum number of tokens to generate (default: 512)" - ) - parser.add_argument( - "--temperature", - type=float, - default=0.7, - help="Temperature for prompt enhancement (default: 0.7)" - ) - parser.add_argument( - "--image", "-i", - type=str, - default=None, - help="Path to conditioning image for I2V (Image-to-Video) generation" - ) - parser.add_argument( - "--image-strength", - type=float, - default=1.0, - help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)" - ) - parser.add_argument( - "--image-frame-idx", - type=int, - default=0, - help="Frame index to condition for I2V (0 = first frame, default: 0)" - ) - parser.add_argument( - "--tiling", - type=str, - default="none", - choices=["none", "auto", "default", "aggressive", "conservative", "spatial", "temporal"], - help="Tiling mode for VAE decoding (default: none, faster on high-memory systems)" - ) - parser.add_argument( - "--audio", - action="store_true", - help="Generate synchronized audio with the video" - ) - parser.add_argument( - "--output-audio", - type=str, - default=None, - help="Output audio path (default: same as video with .wav extension)" - ) - args = parser.parse_args() - - generate_video_dev( - model_repo=args.model_repo, - text_encoder_repo=args.text_encoder_repo, - prompt=args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - num_inference_steps=args.steps, - cfg_scale=args.cfg_scale, - seed=args.seed, - fps=args.fps, - output_path=args.output_path, - output_audio_path=args.output_audio, - save_frames=args.save_frames, - verbose=args.verbose, - enhance_prompt=args.enhance_prompt, - max_tokens=args.max_tokens, - temperature=args.temperature, - image=args.image, - image_strength=args.image_strength, - image_frame_idx=args.image_frame_idx, - tiling=args.tiling, - audio=args.audio, - ) - - -if __name__ == "__main__": - main()