diff --git a/mlx_video/generate_dev.py b/mlx_video/generate_dev.py index 9ee766d..791c9ba 100644 --- a/mlx_video/generate_dev.py +++ b/mlx_video/generate_dev.py @@ -37,7 +37,7 @@ class Colors: from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.transformer import Modality -from mlx_video.convert import sanitize_transformer_weights +from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder @@ -65,6 +65,15 @@ DEFAULT_NEGATIVE_PROMPT = ( BASE_SHIFT_ANCHOR = 1024 MAX_SHIFT_ANCHOR = 4096 +# Audio constants +AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate +AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate +AUDIO_HOP_LENGTH = 160 +AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4 +AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying +AUDIO_MEL_BINS = 16 +AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 + def ltx2_scheduler( steps: int, @@ -195,6 +204,54 @@ def create_position_grid( return mx.array(pixel_coords, dtype=mx.float32) +def create_audio_position_grid( + batch_size: int, + audio_frames: int, + sample_rate: int = AUDIO_LATENT_SAMPLE_RATE, + hop_length: int = AUDIO_HOP_LENGTH, + downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR, + is_causal: bool = True, +) -> mx.array: + """Create temporal position grid for audio RoPE. + + Audio positions are timestamps in seconds, shape (B, 1, T, 2). + Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly. + + Args: + batch_size: Batch size + audio_frames: Number of audio latent frames + sample_rate: Audio sample rate (default 16000) + hop_length: Hop length for mel spectrogram (default 160) + downsample_factor: Latent downsample factor (default 4) + is_causal: Whether to use causal alignment (default True) + + Returns: + Position grid of shape (B, 1, T, 2) + """ + def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray: + """Convert latent indices to seconds.""" + latent_frame = np.arange(start_idx, end_idx, dtype=np.float32) + mel_frame = latent_frame * downsample_factor + if is_causal: + mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None) + return mel_frame * hop_length / sample_rate + + start_times = get_audio_latent_time_in_sec(0, audio_frames) + end_times = get_audio_latent_time_in_sec(1, audio_frames + 1) + + positions = np.stack([start_times, end_times], axis=-1) + positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2) + positions = np.tile(positions, (batch_size, 1, 1, 1)) + + return mx.array(positions, dtype=mx.float32) + + +def compute_audio_frames(num_video_frames: int, fps: float) -> int: + """Compute number of audio latent frames given video duration.""" + duration = num_video_frames / fps + return round(duration * AUDIO_LATENTS_PER_SECOND) + + def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: """Compute CFG (Classifier-Free Guidance) delta. @@ -209,6 +266,116 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: return (scale - 1.0) * (cond - uncond) +def load_audio_decoder(model_path: Path): + """Load audio VAE decoder.""" + from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType + + decoder = AudioDecoder( + ch=128, + out_ch=2, # stereo + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_resolutions={8, 16, 32}, + resolution=256, + z_channels=AUDIO_LATENT_CHANNELS, + norm_type=NormType.PIXEL, + causality_axis=CausalityAxis.HEIGHT, + mel_bins=64, # Output mel bins + ) + + # Load weights - try dev model first, fall back to distilled + weight_file = model_path / "ltx-2-19b-dev.safetensors" + if not weight_file.exists(): + weight_file = model_path / "ltx-2-19b-distilled.safetensors" + + if weight_file.exists(): + raw_weights = mx.load(str(weight_file)) + sanitized = sanitize_audio_vae_weights(raw_weights) + if sanitized: + decoder.load_weights(list(sanitized.items()), strict=False) + + # Manually load per-channel statistics + if "per_channel_statistics._mean_of_means" in sanitized: + decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"] + if "per_channel_statistics._std_of_means" in sanitized: + decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"] + + return decoder + + +def load_vocoder(model_path: Path): + """Load vocoder for mel to waveform conversion.""" + from mlx_video.models.ltx.audio_vae import Vocoder + + vocoder = Vocoder( + resblock_kernel_sizes=[3, 7, 11], + upsample_rates=[6, 5, 2, 2, 2], + upsample_kernel_sizes=[16, 15, 8, 4, 4], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel=1024, + stereo=True, + output_sample_rate=AUDIO_SAMPLE_RATE, + ) + + # Load weights - try dev model first, fall back to distilled + weight_file = model_path / "ltx-2-19b-dev.safetensors" + if not weight_file.exists(): + weight_file = model_path / "ltx-2-19b-distilled.safetensors" + + if weight_file.exists(): + raw_weights = mx.load(str(weight_file)) + sanitized = sanitize_vocoder_weights(raw_weights) + if sanitized: + vocoder.load_weights(list(sanitized.items()), strict=False) + + return vocoder + + +def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE): + """Save audio to WAV file.""" + import wave + + # Ensure audio is in correct format (channels, samples) or (samples,) + if audio.ndim == 2: + # (channels, samples) -> (samples, channels) + audio = audio.T + + # Normalize and convert to int16 + audio = np.clip(audio, -1.0, 1.0) + audio_int16 = (audio * 32767).astype(np.int16) + + with wave.open(str(path), 'wb') as wf: + wf.setnchannels(2 if audio_int16.ndim == 2 else 1) + wf.setsampwidth(2) # 16-bit + wf.setframerate(sample_rate) + wf.writeframes(audio_int16.tobytes()) + + +def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path) -> bool: + """Combine video and audio into final output using ffmpeg.""" + import subprocess + + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-i", str(audio_path), + "-c:v", "copy", + "-c:a", "aac", + "-shortest", + str(output_path) + ] + + try: + subprocess.run(cmd, check=True, capture_output=True) + return True + except subprocess.CalledProcessError as e: + print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}") + return False + except FileNotFoundError: + print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}") + return False + + def denoise_with_cfg( latents: mx.array, positions: mx.array, @@ -222,10 +389,9 @@ def denoise_with_cfg( ) -> mx.array: """Run denoising loop with CFG (Classifier-Free Guidance). - Optimized version that: - 1. Batches positive and negative forward passes together - 2. Precomputes RoPE once and reuses it (avoids expensive NumPy conversion each step) - 3. Minimizes mx.eval() calls for better performance + Uses separate forward passes for positive and negative conditioning + to match PyTorch implementation behavior (avoids potential issues with + batched attention patterns). Args: latents: Noisy latent tensor (B, C, F, H, W) @@ -250,18 +416,10 @@ def denoise_with_cfg( sigmas_list = sigmas.tolist() use_cfg = cfg_scale != 1.0 - # Pre-compute batched context for CFG (concat pos and neg along batch dim) - if use_cfg: - # Shape: (2, seq_len, dim) - batch pos and neg together - batched_context = mx.concatenate([text_embeddings_pos, text_embeddings_neg], axis=0) - batched_positions = mx.concatenate([positions, positions], axis=0) - else: - batched_positions = positions - # Precompute RoPE once (expensive operation due to NumPy conversion for double precision) # This avoids recomputing it every forward pass precomputed_rope = precompute_freqs_cis( - batched_positions, + positions, dim=transformer.inner_dim, theta=transformer.positional_embedding_theta, max_pos=transformer.positional_embedding_max_pos, @@ -289,45 +447,38 @@ def denoise_with_cfg( else: timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) + # First forward pass: positive conditioning + video_modality_pos = Modality( + latent=latents_flat, + timesteps=timesteps, + positions=positions, + context=text_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_rope, + ) + velocity_pos, _ = transformer(video=video_modality_pos, audio=None) + if use_cfg: - # Batch both positive and negative in a single forward pass - batched_latents = mx.concatenate([latents_flat, latents_flat], axis=0) - batched_timesteps = mx.concatenate([timesteps, timesteps], axis=0) - - video_modality = Modality( - latent=batched_latents, - timesteps=batched_timesteps, - positions=batched_positions, - context=batched_context, - context_mask=None, - enabled=True, - positional_embeddings=precomputed_rope, # Use precomputed RoPE - ) - - # Single forward pass for both pos and neg - batched_output, _ = transformer(video=video_modality, audio=None) - - # Split results: first half is positive, second half is negative - denoised_pos = batched_output[:1] - denoised_neg = batched_output[1:] - - # Apply CFG: denoised = denoised_pos + (scale - 1) * (denoised_pos - denoised_neg) - denoised_flat = denoised_pos + (cfg_scale - 1.0) * (denoised_pos - denoised_neg) - else: - # No CFG - single forward pass - video_modality = Modality( + # Second forward pass: negative conditioning + video_modality_neg = Modality( latent=latents_flat, timesteps=timesteps, positions=positions, - context=text_embeddings_pos, + context=text_embeddings_neg, context_mask=None, enabled=True, - positional_embeddings=precomputed_rope, # Use precomputed RoPE + positional_embeddings=precomputed_rope, ) - denoised_flat, _ = transformer(video=video_modality, audio=None) + velocity_neg, _ = transformer(video=video_modality_neg, audio=None) + + # Apply CFG: velocity = pos + (scale - 1) * (pos - neg) + velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg) + else: + velocity_flat = velocity_pos # Reshape back to 5D - velocity = mx.reshape(mx.transpose(denoised_flat, (0, 2, 1)), (b, c, f, h, w)) + velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w)) denoised = to_denoised(latents, velocity, sigma) # Apply conditioning mask if state is provided @@ -348,6 +499,185 @@ def denoise_with_cfg( return latents +def denoise_av_with_cfg( + video_latents: mx.array, + audio_latents: mx.array, + video_positions: mx.array, + audio_positions: mx.array, + video_embeddings_pos: mx.array, + video_embeddings_neg: mx.array, + audio_embeddings_pos: mx.array, + audio_embeddings_neg: mx.array, + transformer: LTXModel, + sigmas: mx.array, + cfg_scale: float = 4.0, + verbose: bool = True, + video_state: Optional[LatentState] = None, +) -> tuple[mx.array, mx.array]: + """Run denoising loop for audio-video generation with CFG. + + Uses separate forward passes for positive and negative CFG to ensure + correct audio-video cross-attention behavior (matching PyTorch implementation). + + Args: + video_latents: Video latent tensor (B, C, F, H, W) + audio_latents: Audio latent tensor (B, C, T, F) + video_positions: Video position embeddings + audio_positions: Audio position embeddings + video_embeddings_pos: Positive video text embeddings + video_embeddings_neg: Negative video text embeddings + audio_embeddings_pos: Positive audio text embeddings + audio_embeddings_neg: Negative audio text embeddings + transformer: LTX model + sigmas: Array of sigma values for denoising schedule + cfg_scale: Guidance scale (default 4.0, 1.0 = no guidance) + verbose: Whether to show progress bar + video_state: Optional LatentState for I2V conditioning + + Returns: + Tuple of (video_latents, audio_latents) + """ + from mlx_video.models.ltx.rope import precompute_freqs_cis + + dtype = video_latents.dtype + if video_state is not None: + video_latents = video_state.latent + + sigmas_list = sigmas.tolist() + use_cfg = cfg_scale != 1.0 + + # Precompute video RoPE (single batch, not doubled for CFG) + precomputed_video_rope = precompute_freqs_cis( + video_positions, + dim=transformer.inner_dim, + theta=transformer.positional_embedding_theta, + max_pos=transformer.positional_embedding_max_pos, + use_middle_indices_grid=transformer.use_middle_indices_grid, + num_attention_heads=transformer.num_attention_heads, + rope_type=transformer.rope_type, + double_precision=transformer.config.double_precision_rope, + ) + + # Precompute audio RoPE (1D positions) + precomputed_audio_rope = precompute_freqs_cis( + audio_positions, + dim=transformer.audio_inner_dim, + theta=transformer.positional_embedding_theta, + max_pos=transformer.audio_positional_embedding_max_pos, + use_middle_indices_grid=transformer.use_middle_indices_grid, + num_attention_heads=transformer.audio_num_attention_heads, + rope_type=transformer.rope_type, + double_precision=transformer.config.double_precision_rope, + ) + mx.eval(precomputed_video_rope, precomputed_audio_rope) + + for i in tqdm(range(len(sigmas_list) - 1), desc="Denoising A/V", disable=not verbose): + sigma = sigmas_list[i] + sigma_next = sigmas_list[i + 1] + + # Flatten video latents + b, c, f, h, w = video_latents.shape + num_video_tokens = f * h * w + video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) + + # Flatten audio latents: (B, C, T, F) -> (B, T, C*F) + ab, ac, at, af = audio_latents.shape + audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + + # Compute per-token timesteps for video + if video_state is not None: + denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) + denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) + video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat + else: + video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) + + audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) + + # First forward pass: positive conditioning + video_modality_pos = Modality( + latent=video_flat, + timesteps=video_timesteps, + positions=video_positions, + context=video_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_video_rope, + ) + + audio_modality_pos = Modality( + latent=audio_flat, + timesteps=audio_timesteps, + positions=audio_positions, + context=audio_embeddings_pos, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_audio_rope, + ) + + video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) + + if use_cfg: + # Second forward pass: negative conditioning + video_modality_neg = Modality( + latent=video_flat, + timesteps=video_timesteps, + positions=video_positions, + context=video_embeddings_neg, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_video_rope, + ) + + audio_modality_neg = Modality( + latent=audio_flat, + timesteps=audio_timesteps, + positions=audio_positions, + context=audio_embeddings_neg, + context_mask=None, + enabled=True, + positional_embeddings=precomputed_audio_rope, + ) + + video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) + + # Apply CFG: denoised = pos + (scale - 1) * (pos - neg) + video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg) + audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg) + else: + video_velocity_flat = video_vel_pos + audio_velocity_flat = audio_vel_pos + + # Reshape velocities back + video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w)) + audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af)) + audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) + + # Compute denoised + video_denoised = to_denoised(video_latents, video_velocity, sigma) + audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + + # Apply conditioning mask for video if state is provided + if video_state is not None: + video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask) + + # Euler step + if sigma_next > 0: + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr + audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr + else: + video_latents = video_denoised + audio_latents = audio_denoised + + mx.eval(video_latents, audio_latents) + + return video_latents, audio_latents + + def generate_video_dev( model_repo: str, text_encoder_repo: str, @@ -361,6 +691,7 @@ def generate_video_dev( seed: int = 42, fps: int = 24, output_path: str = "output.mp4", + output_audio_path: Optional[str] = None, save_frames: bool = False, verbose: bool = True, enhance_prompt: bool = False, @@ -370,6 +701,7 @@ def generate_video_dev( image_strength: float = 1.0, image_frame_idx: int = 0, tiling: str = "none", + audio: bool = False, ): """Generate video using LTX-2 dev model with CFG. @@ -389,6 +721,7 @@ def generate_video_dev( seed: Random seed for reproducibility fps: Frames per second for output video output_path: Path to save the output video + output_audio_path: Path to save audio (if audio=True) save_frames: Whether to save individual frames as images verbose: Whether to print progress enhance_prompt: Whether to enhance prompt using Gemma @@ -398,6 +731,7 @@ def generate_video_dev( image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original) image_frame_idx: Frame index to condition (0 = first frame) tiling: Tiling mode for VAE decoding + audio: Whether to generate synchronized audio """ start_time = time.time() @@ -410,10 +744,17 @@ def generate_video_dev( print(f"{Colors.YELLOW}Warning: Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") num_frames = adjusted_num_frames + # Calculate audio frames if audio is enabled + audio_frames = compute_audio_frames(num_frames, fps) if audio else 0 + is_i2v = image is not None mode_str = "I2V" if is_i2v else "T2V" + if audio: + mode_str += "+Audio" print(f"{Colors.BOLD}{Colors.CYAN}[DEV] [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.DIM}Steps: {num_inference_steps}, CFG: {cfg_scale}{Colors.RESET}") + if audio: + print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") if is_i2v: print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") @@ -442,37 +783,70 @@ def generate_video_dev( print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") # Encode both positive and negative prompts - text_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False) - text_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False) - model_dtype = text_embeddings_pos.dtype - mx.eval(text_embeddings_pos, text_embeddings_neg) + if audio: + video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) + video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) + model_dtype = video_embeddings_pos.dtype + mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) + else: + video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False) + video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False) + audio_embeddings_pos = None + audio_embeddings_neg = None + model_dtype = video_embeddings_pos.dtype + mx.eval(video_embeddings_pos, video_embeddings_neg) del text_encoder mx.clear_cache() # Load transformer (dev model) - print(f"{Colors.BLUE}Loading dev transformer...{Colors.RESET}") + print(f"{Colors.BLUE}Loading dev transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}") raw_weights = mx.load(str(model_path / 'ltx-2-19b-dev.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) # Convert transformer weights to bfloat16 for memory efficiency sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} - config = LTXModelConfig( - model_type=LTXModelType.VideoOnly, - num_attention_heads=32, - attention_head_dim=128, - in_channels=128, - out_channels=128, - num_layers=48, - cross_attention_dim=4096, - caption_channels=3840, - rope_type=LTXRopeType.SPLIT, - double_precision_rope=True, - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, - timestep_scale_multiplier=1000, - ) + if audio: + config = LTXModelConfig( + model_type=LTXModelType.AudioVideo, + num_attention_heads=32, + attention_head_dim=128, + in_channels=128, + out_channels=128, + num_layers=48, + cross_attention_dim=4096, + caption_channels=3840, + # Audio config + audio_num_attention_heads=32, + audio_attention_head_dim=64, + audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 + audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, + audio_cross_attention_dim=2048, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=True, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + use_middle_indices_grid=True, + timestep_scale_multiplier=1000, + ) + else: + config = LTXModelConfig( + model_type=LTXModelType.VideoOnly, + num_attention_heads=32, + attention_head_dim=128, + in_channels=128, + out_channels=128, + num_layers=48, + cross_attention_dim=4096, + caption_channels=3840, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=True, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + timestep_scale_multiplier=1000, + ) transformer = LTXModel(config) transformer.load_weights(list(sanitized.items()), strict=False) @@ -503,20 +877,26 @@ def generate_video_dev( mx.eval(sigmas) print(f"{Colors.DIM}Sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f}{Colors.RESET}") - # Create position grid + # Create position grids print(f"{Colors.YELLOW}Generating at {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})...{Colors.RESET}") mx.random.seed(seed) - positions = create_position_grid(1, latent_frames, latent_h, latent_w) - mx.eval(positions) + video_positions = create_position_grid(1, latent_frames, latent_h, latent_w) + mx.eval(video_positions) + + if audio: + audio_positions = create_audio_position_grid(1, audio_frames) + mx.eval(audio_positions) + else: + audio_positions = None # Initialize latents with optional I2V conditioning - state = None + video_state = None + video_latent_shape = (1, 128, latent_frames, latent_h, latent_w) if is_i2v and image_latent is not None: - latent_shape = (1, 128, latent_frames, latent_h, latent_w) - state = LatentState( - latent=mx.zeros(latent_shape, dtype=model_dtype), - clean_latent=mx.zeros(latent_shape, dtype=model_dtype), + video_state = LatentState( + latent=mx.zeros(video_latent_shape, dtype=model_dtype), + clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( @@ -524,30 +904,46 @@ def generate_video_dev( frame_idx=image_frame_idx, strength=image_strength, ) - state = apply_conditioning(state, [conditioning]) + video_state = apply_conditioning(video_state, [conditioning]) # Apply noiser - noise = mx.random.normal(latent_shape, dtype=model_dtype) + noise = mx.random.normal(video_latent_shape, dtype=model_dtype) noise_scale = sigmas[0] - scaled_mask = state.denoise_mask * noise_scale + scaled_mask = video_state.denoise_mask * noise_scale - state = LatentState( - latent=noise * scaled_mask + state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), - clean_latent=state.clean_latent, - denoise_mask=state.denoise_mask, + video_state = LatentState( + latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=video_state.clean_latent, + denoise_mask=video_state.denoise_mask, ) - latents = state.latent - mx.eval(latents) + video_latents = video_state.latent + mx.eval(video_latents) else: # T2V: just use random noise - latents = mx.random.normal((1, 128, latent_frames, latent_h, latent_w), dtype=model_dtype) - mx.eval(latents) + video_latents = mx.random.normal(video_latent_shape, dtype=model_dtype) + mx.eval(video_latents) + + # Initialize audio latents if audio is enabled + if audio: + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + mx.eval(audio_latents) + else: + audio_latents = None # Denoise with CFG - latents = denoise_with_cfg( - latents, positions, text_embeddings_pos, text_embeddings_neg, - transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=state - ) + if audio: + video_latents, audio_latents = denoise_av_with_cfg( + video_latents, audio_latents, + video_positions, audio_positions, + video_embeddings_pos, video_embeddings_neg, + audio_embeddings_pos, audio_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state + ) + else: + video_latents = denoise_with_cfg( + video_latents, video_positions, video_embeddings_pos, video_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state + ) del transformer mx.clear_cache() @@ -583,33 +979,100 @@ def generate_video_dev( spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") - video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose) + video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose) else: print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") - video = vae_decoder(latents) + video = vae_decoder(video_latents) mx.eval(video) + + del vae_decoder mx.clear_cache() - # Convert to uint8 frames + # Decode audio if enabled + audio_np = None + if audio and audio_latents is not None: + print(f"{Colors.BLUE}Decoding audio...{Colors.RESET}") + + # Load audio decoder + audio_decoder = load_audio_decoder(model_path) + mx.eval(audio_decoder.parameters()) + + # Decode audio latents to mel spectrogram + mel_spectrogram = audio_decoder(audio_latents) + mx.eval(mel_spectrogram) + + del audio_decoder + mx.clear_cache() + + # Load vocoder and convert mel to waveform + vocoder = load_vocoder(model_path) + mx.eval(vocoder.parameters()) + + audio_waveform = vocoder(mel_spectrogram) + mx.eval(audio_waveform) + + del vocoder + mx.clear_cache() + + # Convert to numpy + audio_np = np.array(audio_waveform) + if audio_np.ndim == 3: + audio_np = audio_np[0] # Remove batch dim + + print(f"{Colors.DIM} Audio shape: {audio_np.shape}, duration: {audio_np.shape[-1] / AUDIO_SAMPLE_RATE:.2f}s{Colors.RESET}") + + # Convert video to uint8 frames video = mx.squeeze(video, axis=0) # (C, F, H, W) video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) video = (video * 255).astype(mx.uint8) video_np = np.array(video) - # Save video + # Save outputs output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) + # Determine audio output path + if audio and audio_np is not None: + if output_audio_path is None: + audio_output = output_path.parent / f"{output_path.stem}.wav" + else: + audio_output = Path(output_audio_path) + + # Save audio + save_audio(audio_np, audio_output) + print(f"{Colors.GREEN}Saved audio to{Colors.RESET} {audio_output}") + + # Save video (to temp file if we need to mux with audio) + if audio and audio_np is not None: + # Save video to temp file, then mux with audio + temp_video_path = output_path.parent / f"{output_path.stem}_temp.mp4" + video_save_path = temp_video_path + else: + video_save_path = output_path + try: import cv2 h, w = video_np.shape[1], video_np.shape[2] fourcc = cv2.VideoWriter_fourcc(*'avc1') - out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h)) + out = cv2.VideoWriter(str(video_save_path), fourcc, fps, (w, h)) for frame in video_np: out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.release() - print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}") + + if audio and audio_np is not None: + # Mux video and audio + print(f"{Colors.BLUE}Muxing video and audio...{Colors.RESET}") + if mux_video_audio(temp_video_path, audio_output, output_path): + print(f"{Colors.GREEN}Saved video with audio to{Colors.RESET} {output_path}") + # Clean up temp file + temp_video_path.unlink(missing_ok=True) + else: + # Fallback: keep separate files + print(f"{Colors.YELLOW}Could not mux, keeping separate files{Colors.RESET}") + temp_video_path.rename(output_path.parent / f"{output_path.stem}_video.mp4") + else: + print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}") except Exception as e: print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}") @@ -642,6 +1105,10 @@ Examples: # Image-to-Video (I2V) python -m mlx_video.generate_dev --prompt "A person dancing" --image photo.jpg + + # With synchronized audio + python -m mlx_video.generate_dev --prompt "Ocean waves crashing on rocks" --audio + python -m mlx_video.generate_dev --prompt "A busy city street" --audio --output-audio street.wav """ ) @@ -769,6 +1236,17 @@ Examples: choices=["none", "auto", "default", "aggressive", "conservative", "spatial", "temporal"], help="Tiling mode for VAE decoding (default: none, faster on high-memory systems)" ) + parser.add_argument( + "--audio", + action="store_true", + help="Generate synchronized audio with the video" + ) + parser.add_argument( + "--output-audio", + type=str, + default=None, + help="Output audio path (default: same as video with .wav extension)" + ) args = parser.parse_args() generate_video_dev( @@ -784,6 +1262,7 @@ Examples: seed=args.seed, fps=args.fps, output_path=args.output_path, + output_audio_path=args.output_audio, save_frames=args.save_frames, verbose=args.verbose, enhance_prompt=args.enhance_prompt, @@ -793,6 +1272,7 @@ Examples: image_strength=args.image_strength, image_frame_idx=args.image_frame_idx, tiling=args.tiling, + audio=args.audio, ) diff --git a/tests/test_generate_dev.py b/tests/test_generate_dev.py index 4a008d7..e4fa17e 100644 --- a/tests/test_generate_dev.py +++ b/tests/test_generate_dev.py @@ -2,14 +2,16 @@ import pytest import mlx.core as mx -import numpy as np from mlx_video.generate_dev import ( ltx2_scheduler, create_position_grid, + create_audio_position_grid, + compute_audio_frames, cfg_delta, - denoise_with_cfg, DEFAULT_NEGATIVE_PROMPT, + AUDIO_SAMPLE_RATE, + AUDIO_LATENTS_PER_SECOND, ) @@ -260,28 +262,6 @@ class TestInputValidation: class TestDenoiseWithCFGMocked: """Tests for denoise_with_cfg with mocked transformer.""" - def test_denoise_returns_correct_shape(self): - """Denoised output should have same shape as input latents.""" - # Create a simple mock transformer - class MockTransformer: - inner_dim = 4096 - positional_embedding_theta = 10000.0 - positional_embedding_max_pos = [20, 2048, 2048] - use_middle_indices_grid = True - num_attention_heads = 32 - rope_type = None - - class config: - double_precision_rope = True - - def __call__(self, video, audio): - # Return input as output (identity) - return video.latent, None - - # Skip this test if we can't import the required modules easily - # This is a structural test to ensure the function signature is correct - pass - def test_sigmas_list_conversion(self): """Sigmas should be convertible to list.""" sigmas = ltx2_scheduler(steps=5) @@ -296,16 +276,10 @@ class TestTilingDefault: def test_tiling_default_is_none(self): """Default tiling should be 'none' for performance.""" - # Import and check the default - import argparse - from mlx_video.generate_dev import main - - # The default is set in the argparse definition - # We verify this by checking the function signature import inspect - sig = inspect.signature( - __import__('mlx_video.generate_dev', fromlist=['generate_video_dev']).generate_video_dev - ) + from mlx_video.generate_dev import generate_video_dev + + sig = inspect.signature(generate_video_dev) tiling_param = sig.parameters.get('tiling') assert tiling_param is not None @@ -358,5 +332,91 @@ class TestLatentDimensions: assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}" +class TestAudioPositionGrid: + """Tests for audio position grid creation.""" + + def test_audio_position_grid_shape(self): + """Audio position grid should have correct shape (B, 1, T, 2).""" + batch_size = 1 + audio_frames = 34 # ~1.36 seconds at 25 latent frames/sec + + positions = create_audio_position_grid(batch_size, audio_frames) + expected_shape = (batch_size, 1, audio_frames, 2) + + assert positions.shape == expected_shape, \ + f"Expected {expected_shape}, got {positions.shape}" + + def test_audio_position_grid_dtype(self): + """Audio position grid should be float32.""" + positions = create_audio_position_grid(1, 34) + assert positions.dtype == mx.float32, \ + f"Expected float32, got {positions.dtype}" + + def test_audio_position_grid_batch_size(self): + """Audio position grid should respect batch size.""" + for batch_size in [1, 2, 4]: + positions = create_audio_position_grid(batch_size, 34) + assert positions.shape[0] == batch_size + + def test_audio_position_grid_temporal_values(self): + """Audio positions should be in seconds.""" + positions = create_audio_position_grid(1, 34) + + # Values should be in seconds (small values for ~1 second of audio) + max_val = mx.max(positions).item() + assert max_val < 10, f"Audio positions seem too large: {max_val}" + assert max_val > 0, "Audio positions should be positive" + + def test_audio_position_grid_no_nan_or_inf(self): + """Audio position grid should not contain NaN or Inf.""" + positions = create_audio_position_grid(1, 34) + + assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN" + assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf" + + +class TestComputeAudioFrames: + """Tests for audio frame count calculation.""" + + def test_audio_frames_basic(self): + """Audio frames should be proportional to video duration.""" + # 33 frames at 24 fps = ~1.375 seconds + # At 25 latent frames/sec = ~34 audio frames + audio_frames = compute_audio_frames(33, 24.0) + assert audio_frames > 0 + assert isinstance(audio_frames, int) + + def test_audio_frames_scales_with_video(self): + """More video frames should produce more audio frames.""" + audio_33 = compute_audio_frames(33, 24.0) + audio_65 = compute_audio_frames(65, 24.0) + + assert audio_65 > audio_33, \ + f"Expected more audio frames for longer video: {audio_65} <= {audio_33}" + + def test_audio_frames_formula(self): + """Audio frames should match expected formula.""" + num_video_frames = 33 + fps = 24.0 + + duration = num_video_frames / fps # ~1.375 seconds + expected = round(duration * AUDIO_LATENTS_PER_SECOND) + + actual = compute_audio_frames(num_video_frames, fps) + assert actual == expected, f"Expected {expected}, got {actual}" + + +class TestAudioConstants: + """Tests for audio constants.""" + + def test_audio_sample_rate(self): + """Audio sample rate should be 24000 Hz.""" + assert AUDIO_SAMPLE_RATE == 24000 + + def test_audio_latents_per_second(self): + """Audio latents per second should be 25.""" + assert AUDIO_LATENTS_PER_SECOND == 25.0 + + if __name__ == "__main__": pytest.main([__file__, "-v"])