diff --git a/README.md b/README.md index fdbddf9..8d86c69 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ Supported models: ## Features - Text-to-video (T2V) and Image-to-video (I2V) generation +- Audio-to-video (A2V) conditioning — generate video from input audio - Four pipeline modes: Distilled, Dev, Dev Two-Stage, and Dev Two-Stage HQ - Synchronized audio-video generation (experimental) - LoRA support (including HuggingFace repos) @@ -85,7 +86,27 @@ uv run mlx_video.generate --prompt "A person dancing" --image photo.jpg uv run mlx_video.generate --pipeline dev --prompt "Waves crashing" --image beach.png --cfg-scale 3.5 ``` -### Audio-Video (experimental) +### Audio-to-Video (A2V) + +Generate video conditioned on an input audio file. The audio is encoded to latent space and frozen during denoising — the transformer's cross-attention reads the audio signal to guide video generation. + +```bash +# A2V - generate video from audio +uv run mlx_video.generate --audio-file music.wav --prompt "A band playing music" + +# A2V with dev pipeline +uv run mlx_video.generate --pipeline dev --audio-file ocean.wav --prompt "Ocean waves" + +# A2V + I2V (audio + image conditioning) +uv run mlx_video.generate --audio-file rain.wav --image forest.jpg --prompt "Rain in forest" + +# A2V with custom start time +uv run mlx_video.generate --audio-file song.mp3 --audio-start-time 30.0 --prompt "Concert" +``` + +### Audio-Video Generation (experimental) + +Generate synchronized audio alongside video from scratch: ```bash uv run mlx_video.generate --prompt "Ocean waves crashing" --audio @@ -150,6 +171,8 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom | `--image`, `-i` | None | Conditioning image for I2V | | `--image-strength` | 1.0 | Conditioning strength for I2V | | `--audio`, `-a` | false | Enable synchronized audio generation | +| `--audio-file` | None | Path to audio file for A2V conditioning | +| `--audio-start-time` | 0.0 | Start time in 
def convert_audio_encoder(
    model_path: Union[str, Path],
    source_repo: str = "Lightricks/LTX-2",
) -> Path:
    """Convert and cache audio encoder weights from the original HF checkpoint.

    Downloads ``audio_vae/diffusion_pytorch_model.safetensors`` from
    ``source_repo``, passes the loaded tensors through
    ``AudioEncoder.sanitize`` and writes the result, together with the encoder
    ``config.json``, to ``<model_path>/audio_vae_encoder`` so that
    ``AudioEncoder.from_pretrained()`` can load it.

    The call returns immediately when ``model.safetensors`` already exists in
    that directory. Because that file is the "already converted" marker, it is
    written last: an interrupted run can never leave a cache that looks
    complete but is missing ``config.json``.

    Args:
        model_path: Local model directory (output location).
        source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.

    Returns:
        Path to the audio_vae_encoder directory.
    """
    model_path = Path(model_path)
    encoder_dir = model_path / "audio_vae_encoder"
    weights_file = encoder_dir / "model.safetensors"

    # Cached from a previous run: nothing to do.
    if weights_file.exists():
        return encoder_dir

    # Download original audio VAE weights
    from huggingface_hub import hf_hub_download

    vae_path = hf_hub_download(
        source_repo,
        "audio_vae/diffusion_pytorch_model.safetensors",
    )
    raw_weights = mx.load(vae_path)

    from mlx_video.models.ltx.audio_vae import AudioEncoder
    from mlx_video.models.ltx.config import AudioEncoderModelConfig

    # Build the encoder config from the decoder config when it is available
    # locally; otherwise fall back to a minimal set of explicit values.
    decoder_config_path = model_path / "audio_vae" / "config.json"
    if decoder_config_path.exists():
        with open(decoder_config_path, encoding="utf-8") as f:
            dec_cfg = json.load(f)
        enc_config = {
            "ch": dec_cfg.get("ch", 128),
            # The encoder consumes what the decoder produces.
            "in_channels": dec_cfg.get("out_ch", 2),
            "ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
            "num_res_blocks": dec_cfg.get("num_res_blocks", 2),
            "attn_resolutions": dec_cfg.get("attn_resolutions", []),
            "resolution": dec_cfg.get("resolution", 256),
            "z_channels": dec_cfg.get("z_channels", 8),
            "double_z": True,
            "n_fft": 1024,
            "norm_type": dec_cfg.get("norm_type", "pixel"),
            "causality_axis": dec_cfg.get("causality_axis", "height"),
            "dropout": dec_cfg.get("dropout", 0.0),
            "mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
            "sample_rate": dec_cfg.get("sample_rate", 16000),
            "mel_hop_length": dec_cfg.get("mel_hop_length", 160),
            "is_causal": dec_cfg.get("is_causal", True),
            # `or 64` also covers an explicit null/0 in the decoder config.
            "mel_bins": dec_cfg.get("mel_bins", 64) or 64,
            "resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
            "attn_type": dec_cfg.get("attn_type", "vanilla"),
        }
    else:
        enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}

    # Sanitize weights
    config = AudioEncoderModelConfig.from_dict(enc_config)
    encoder = AudioEncoder(config)
    sanitized = encoder.sanitize(raw_weights)

    # Save: config first, weights last (the weights file is the cache marker).
    encoder_dir.mkdir(parents=True, exist_ok=True)
    with open(encoder_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(enc_config, f, indent=2)
    mx.save_safetensors(str(weights_file), sanitized)

    print(f"Audio encoder weights saved to {encoder_dir}")
    return encoder_dir
c, -1)), (0, 2, 1)) timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1) @@ -539,7 +540,7 @@ def denoise_distilled( denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w)) audio_denoised = None - if enable_audio and audio_velocity is not None: + if enable_audio and audio_velocity is not None and not audio_frozen: ab, ac, at, af = audio_latents.shape audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) @@ -552,15 +553,15 @@ def denoise_distilled( if audio_denoised is not None: mx.eval(audio_denoised) - # Euler step in float32 (latents stay in float32) + # Euler step in float32 if sigma_next > 0: sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32 - if enable_audio and audio_denoised is not None: + if enable_audio and audio_denoised is not None and not audio_frozen: audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32 else: latents = denoised - if enable_audio and audio_denoised is not None: + if enable_audio and audio_denoised is not None and not audio_frozen: audio_latents = audio_denoised mx.eval(latents) @@ -785,6 +786,7 @@ def denoise_dev_av( stg_video_blocks: Optional[list] = None, stg_audio_blocks: Optional[list] = None, modality_scale: float = 1.0, + audio_frozen: bool = False, ) -> tuple[mx.array, mx.array]: """Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio. 
@@ -879,11 +881,12 @@ def denoise_dev_av( else: video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) - audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) + # A2V: frozen audio uses timesteps=0 (tells model audio is clean) + audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype) # Positive conditioning pass sigma_array = mx.full((b,), sigma, dtype=dtype) - audio_sigma_array = mx.full((ab,), sigma, dtype=dtype) + audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype) video_modality_pos = Modality( latent=video_flat, timesteps=video_timesteps, positions=video_positions, context=video_embeddings_pos, context_mask=None, enabled=True, @@ -1001,11 +1004,13 @@ def denoise_dev_av( video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32 video_latents = video_latents + video_velocity_f32 * dt_f32 - audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32 - audio_latents = audio_latents + audio_velocity_f32 * dt_f32 + if not audio_frozen: + audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32 + audio_latents = audio_latents + audio_velocity_f32 * dt_f32 else: video_latents = video_denoised_f32 - audio_latents = audio_denoised_f32 + if not audio_frozen: + audio_latents = audio_denoised_f32 mx.eval(video_latents, audio_latents) progress.advance(task) @@ -1037,6 +1042,7 @@ def denoise_res2s_av( noise_seed: int = 42, bongmath: bool = True, bongmath_max_iter: int = 100, + audio_frozen: bool = False, ) -> tuple[mx.array, mx.array]: """Run res_2s second-order denoising loop with CFG/STG/modality guidance. 
@@ -1125,10 +1131,10 @@ def denoise_res2s_av( video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat else: video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) - audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) + audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype) sigma_array = mx.full((b,), sigma, dtype=dtype) - audio_sigma_array = mx.full((ab,), sigma, dtype=dtype) + audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype) # Pass 1: Positive conditioning video_modality_pos = Modality( @@ -1270,18 +1276,23 @@ def denoise_res2s_av( # Compute midpoint eps_1_video = denoised_video_1 - x_anchor_video - eps_1_audio = denoised_audio_1 - x_anchor_audio - x_mid_video = x_anchor_video + h * a21 * eps_1_video - x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio + + if not audio_frozen: + eps_1_audio = denoised_audio_1 - x_anchor_audio + x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio + else: + eps_1_audio = None + x_mid_audio = audio_latents # frozen: pass through unchanged # SDE noise injection at substep substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3) substep_noise_v = get_new_noise(video_latents.shape, key1) - substep_noise_a = get_new_noise(audio_latents.shape, key2) x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v) - x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a) + if not audio_frozen: + substep_noise_a = get_new_noise(audio_latents.shape, key2) + x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a) mx.eval(x_mid_video, x_mid_audio) # ============================================================ @@ -1291,9 +1302,13 @@ def denoise_res2s_av( for _ in range(bongmath_max_iter): x_anchor_video = x_mid_video - h * a21 * eps_1_video eps_1_video = denoised_video_1 - x_anchor_video - 
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio - eps_1_audio = denoised_audio_1 - x_anchor_audio - mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio) + if not audio_frozen: + x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio + eps_1_audio = denoised_audio_1 - x_anchor_audio + if audio_frozen: + mx.eval(x_anchor_video, eps_1_video) + else: + mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio) # ============================================================ # Stage 2: Evaluate denoiser at midpoint sigma @@ -1306,21 +1321,21 @@ def denoise_res2s_av( # Final combination with RK coefficients # ============================================================ eps_2_video = denoised_video_2 - x_anchor_video - eps_2_audio = denoised_audio_2 - x_anchor_audio - x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video) - x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio) # SDE noise injection at step level step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3) step_noise_v = get_new_noise(video_latents.shape, key1) - step_noise_a = get_new_noise(audio_latents.shape, key2) - x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v) - x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a) video_latents = x_next_video.astype(mx.float32) - audio_latents = x_next_audio.astype(mx.float32) + if not audio_frozen: + eps_2_audio = denoised_audio_2 - x_anchor_audio + x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio) + step_noise_a = get_new_noise(audio_latents.shape, key2) + x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a) + audio_latents = x_next_audio.astype(mx.float32) + mx.eval(video_latents, audio_latents) progress.advance(task) @@ -1330,7 +1345,8 @@ def denoise_res2s_av( video_latents, audio_latents, sigmas_list[n_full_steps] ) video_latents = denoised_video - 
audio_latents = denoised_audio + if not audio_frozen: + audio_latents = denoised_audio mx.eval(video_latents, audio_latents) return video_latents, audio_latents @@ -1443,6 +1459,8 @@ def generate_video( lora_strength: float = 1.0, lora_strength_stage_1: Optional[float] = None, lora_strength_stage_2: Optional[float] = None, + audio_file: Optional[str] = None, + audio_start_time: float = 0.0, ): """Generate video using LTX-2 models. @@ -1496,8 +1514,16 @@ def generate_video( num_frames = adjusted_num_frames is_i2v = image is not None + is_a2v = audio_file is not None + if is_a2v and audio: + raise ValueError("Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one.") + # A2V implicitly enables audio path through the transformer + if is_a2v: + audio = True mode_str = "I2V" if is_i2v else "T2V" - if audio: + if is_a2v: + mode_str = "A2V" + ("+I2V" if is_i2v else "") + elif audio: mode_str += "+Audio" pipeline_names = { @@ -1599,6 +1625,62 @@ def generate_video( stg_blocks = [29] console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]") + # ========================================================================== + # A2V: Encode input audio to frozen latents + # ========================================================================== + a2v_audio_latents = None + a2v_waveform = None + a2v_sr = None + if is_a2v: + from mlx_video.models.ltx.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel + from mlx_video.convert import convert_audio_encoder + from mlx_video.models.ltx.audio_vae import AudioEncoder + + with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"): + video_duration = num_frames / fps + + # Load audio + waveform, sr = load_audio( + audio_file, + target_sr=AUDIO_LATENT_SAMPLE_RATE, + start_time=audio_start_time, + max_duration=video_duration, + ) + waveform = ensure_stereo(waveform) + a2v_waveform = 
waveform.copy() + a2v_sr = sr + + # Compute mel-spectrogram + mel = waveform_to_mel(waveform, sample_rate=sr, n_fft=1024, hop_length=AUDIO_HOP_LENGTH, n_mels=64) + + # Convert audio encoder weights if needed, then load + encoder_dir = convert_audio_encoder(model_path, source_repo="Lightricks/LTX-2") + audio_encoder = AudioEncoder.from_pretrained(encoder_dir) + mx.eval(audio_encoder.parameters()) + + # Encode: (1, 2, time, 64) -> normalized latents + encoded = audio_encoder(mel) + mx.eval(encoded) + + # encoded is in MLX format (B, T', mel_bins', z_channels) = (1, T', 16, 8) + # Convert to PyTorch-style format for consistency: (B, C, T, mel_bins) + a2v_audio_latents = mx.transpose(encoded, (0, 3, 1, 2)).astype(model_dtype) + + # Trim/pad to match expected audio_frames + t_encoded = a2v_audio_latents.shape[2] + if t_encoded > audio_frames: + a2v_audio_latents = a2v_audio_latents[:, :, :audio_frames, :] + elif t_encoded < audio_frames: + pad_size = audio_frames - t_encoded + padding = mx.zeros((1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), dtype=model_dtype) + a2v_audio_latents = mx.concatenate([a2v_audio_latents, padding], axis=2) + mx.eval(a2v_audio_latents) + + del audio_encoder + mx.clear_cache() + + console.print(f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})") + # ========================================================================== # Pipeline-specific generation logic # ========================================================================== @@ -1636,9 +1718,9 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) - # Always init audio latents/positions - PyTorch unconditionally generates audio + # Init audio latents/positions: use encoded A2V latents or random audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) + audio_latents = 
a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) mx.eval(audio_positions, audio_latents) # Apply I2V conditioning @@ -1671,6 +1753,7 @@ def generate_video( latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1, audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, + audio_frozen=is_a2v, ) # Upsample latents @@ -1723,7 +1806,7 @@ def generate_video( mx.eval(latents) # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) - if audio_latents is not None: + if audio_latents is not None and not is_a2v: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) @@ -1735,6 +1818,7 @@ def generate_video( verbose=verbose, state=state2, audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, + audio_frozen=is_a2v, ) elif pipeline == PipelineType.DEV: @@ -1770,7 +1854,7 @@ def generate_video( # Always init audio latents/positions - PyTorch unconditionally generates audio audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) mx.eval(audio_positions, audio_latents) # Initialize latents with optional I2V conditioning @@ -1811,6 +1895,7 @@ def generate_video( use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, stg_scale=stg_scale, stg_video_blocks=stg_blocks, stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + audio_frozen=is_a2v, ) # Load VAE decoder (for dev pipeline, 
loaded here instead of during upsampling) @@ -1858,7 +1943,7 @@ def generate_video( # Always init audio latents/positions - PyTorch unconditionally generates audio audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) mx.eval(audio_positions, audio_latents) # Apply I2V conditioning for stage 1 @@ -1899,6 +1984,7 @@ def generate_video( use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold, stg_scale=stg_scale, stg_video_blocks=stg_blocks, stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + audio_frozen=is_a2v, ) mx.eval(audio_latents) @@ -1969,7 +2055,7 @@ def generate_video( mx.eval(latents) # Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch) - if audio_latents is not None: + if audio_latents is not None and not is_a2v: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) @@ -1981,6 +2067,7 @@ def generate_video( verbose=verbose, state=state2, audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings_pos, + audio_frozen=is_a2v, ) elif pipeline == PipelineType.DEV_TWO_STAGE_HQ: @@ -2045,7 +2132,7 @@ def generate_video( mx.eval(positions) audio_positions = create_audio_position_grid(1, audio_frames) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) mx.eval(audio_positions, audio_latents) # Apply I2V conditioning for 
stage 1 @@ -2087,6 +2174,7 @@ def generate_video( stg_scale=stg_scale, stg_video_blocks=stg_blocks, stg_audio_blocks=stg_blocks, modality_scale=modality_scale, noise_seed=seed, + audio_frozen=is_a2v, ) mx.eval(audio_latents) @@ -2148,7 +2236,7 @@ def generate_video( mx.eval(latents) # Re-noise audio at sigma=0.909375 for joint refinement - if audio_latents is not None: + if audio_latents is not None and not is_a2v: audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) @@ -2165,6 +2253,7 @@ def generate_video( audio_cfg_scale=1.0, cfg_rescale=0.0, verbose=verbose, video_state=state2, noise_seed=seed + 1, + audio_frozen=is_a2v, ) del transformer @@ -2279,29 +2368,38 @@ def generate_video( # Decode and save audio if enabled audio_np = None + vocoder_sample_rate = AUDIO_SAMPLE_RATE if audio and audio_latents is not None: - with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"): - audio_decoder = load_audio_decoder(model_path, pipeline) - vocoder = load_vocoder_model(model_path, pipeline) - mx.eval(audio_decoder.parameters(), vocoder.parameters()) + if is_a2v and a2v_waveform is not None: + # A2V: use original input audio waveform (no VAE decoding needed) + audio_np = a2v_waveform + if audio_np.ndim == 1: + audio_np = audio_np[np.newaxis, :] + vocoder_sample_rate = a2v_sr or AUDIO_LATENT_SAMPLE_RATE + console.print("[green]✓[/] Using original input audio (A2V)") + else: + with console.status("[blue]Decoding audio...[/]", spinner="dots"): + audio_decoder = load_audio_decoder(model_path, pipeline) + vocoder = load_vocoder_model(model_path, pipeline) + mx.eval(audio_decoder.parameters(), vocoder.parameters()) - mel_spectrogram = audio_decoder(audio_latents) - mx.eval(mel_spectrogram) - console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, 
std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]") + mel_spectrogram = audio_decoder(audio_latents) + mx.eval(mel_spectrogram) + console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]") - audio_waveform = vocoder(mel_spectrogram) - mx.eval(audio_waveform) + audio_waveform = vocoder(mel_spectrogram) + mx.eval(audio_waveform) - audio_np = np.array(audio_waveform.astype(mx.float32)) - if audio_np.ndim == 3: - audio_np = audio_np[0] + audio_np = np.array(audio_waveform.astype(mx.float32)) + if audio_np.ndim == 3: + audio_np = audio_np[0] - # Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE) - vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE) + # Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE) + vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE) - del audio_decoder, vocoder - mx.clear_cache() - console.print("[green]✓[/] Audio decoded") + del audio_decoder, vocoder + mx.clear_cache() + console.print("[green]✓[/] Audio decoded") audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav') save_audio(audio_np, audio_path, vocoder_sample_rate) @@ -2398,6 +2496,8 @@ Examples: help="Tiling mode for VAE decoding") parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") + parser.add_argument("--audio-file", type=str, default=None, help="Path to audio file for A2V (audio-to-video) conditioning") + parser.add_argument("--audio-start-time", type=float, default=0.0, help="Start time in seconds for audio file (default: 0.0)") parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") parser.add_argument("--apg", 
action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") @@ -2457,6 +2557,8 @@ Examples: lora_strength=args.lora_strength, lora_strength_stage_1=args.lora_strength_stage_1, lora_strength_stage_2=args.lora_strength_stage_2, + audio_file=args.audio_file, + audio_start_time=args.audio_start_time, ) diff --git a/mlx_video/models/ltx/audio_vae/__init__.py b/mlx_video/models/ltx/audio_vae/__init__.py index 3a9e262..79a1679 100644 --- a/mlx_video/models/ltx/audio_vae/__init__.py +++ b/mlx_video/models/ltx/audio_vae/__init__.py @@ -1,7 +1,8 @@ """Audio VAE module for LTX-2 audio generation.""" from .attention import AttentionType, AttnBlock, make_attn -from .audio_vae import AudioDecoder, decode_audio +from .audio_vae import AudioDecoder, AudioEncoder, decode_audio +from .audio_processor import load_audio, ensure_stereo, waveform_to_mel from .causal_conv_2d import CausalConv2d, make_conv2d from ..config import CausalityAxis from .downsample import Downsample, build_downsampling_path @@ -13,10 +14,15 @@ from .vocoder import Vocoder, load_vocoder __all__ = [ # Main components + "AudioEncoder", "AudioDecoder", "Vocoder", "load_vocoder", "decode_audio", + # Audio processing + "load_audio", + "ensure_stereo", + "waveform_to_mel", # Ops "AudioLatentShape", "AudioPatchifier", diff --git a/mlx_video/models/ltx/audio_vae/audio_processor.py b/mlx_video/models/ltx/audio_vae/audio_processor.py new file mode 100644 index 0000000..ed5ff7a --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/audio_processor.py @@ -0,0 +1,135 @@ +"""Audio processing utilities for loading audio files and computing mel-spectrograms. + +Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram) +using librosa for macOS/MLX compatibility. 
def load_audio(
    path: str,
    target_sr: int = 16000,
    start_time: float = 0.0,
    max_duration: float | None = None,
    mono: bool = False,
) -> tuple[np.ndarray, int]:
    """Load an audio file with librosa, resampled to ``target_sr``.

    Args:
        path: Path to the audio file (any format ``librosa.load`` can decode).
        target_sr: Target sample rate in Hz (default 16000).
        start_time: Offset in seconds at which reading starts.
        max_duration: Maximum duration in seconds. None = read to end.
        mono: If True, downmix to mono. Default False (preserve channels).

    Returns:
        ``(waveform, sample_rate)`` where ``waveform`` is a float32 numpy array
        of shape ``(channels, samples)``.
    """
    import librosa

    # librosa.load downmixes to mono by default; pass `mono` explicitly so
    # multi-channel input is preserved unless the caller asks otherwise.
    y, sr = librosa.load(
        path,
        sr=target_sr,
        mono=mono,
        offset=start_time,
        duration=max_duration,
    )

    # A 1-D result becomes a single-channel 2-D array: (1, samples)
    if y.ndim == 1:
        y = y[np.newaxis, :]

    return y.astype(np.float32), sr


def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
    """Return ``waveform`` as exactly two channels, shape ``(2, samples)``.

    A 1-D or single-channel input is duplicated into both channels; input with
    more than two channels keeps only the first two. Two-channel input is
    returned unchanged.
    """
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]

    num_channels = waveform.shape[0]
    if num_channels == 1:
        return np.concatenate([waveform, waveform], axis=0)
    if num_channels > 2:
        return waveform[:2]
    return waveform


def waveform_to_mel(
    waveform: np.ndarray,
    sample_rate: int = 16000,
    n_fft: int = 1024,
    hop_length: int = 160,
    win_length: int = 1024,
    n_mels: int = 64,
    fmin: float = 0.0,
    fmax: float = 8000.0,
) -> mx.array:
    """Convert a waveform to a log-mel spectrogram.

    Intended to mirror this PyTorch reference configuration:
        MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
                       f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
                       mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")

    Args:
        waveform: ``(channels, samples)`` numpy array; a 1-D array is treated
            as a single channel.
        sample_rate: Sample rate of the waveform.
        n_fft: FFT size.
        hop_length: Hop length.
        win_length: Window length.
        n_mels: Number of mel bins.
        fmin: Minimum frequency for mel filterbank.
        fmax: Maximum frequency for mel filterbank.

    Returns:
        Log-mel spectrogram as a float32 ``mx.array`` of shape
        ``(1, channels, time, n_mels)``.
    """
    import librosa

    # Ensure 2D
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]

    # The filterbank depends only on the arguments, not on the audio, so it is
    # built once instead of once per channel.
    mel_basis = librosa.filters.mel(
        sr=sample_rate,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        norm="slaney",
    )

    mels = []
    for channel in waveform:
        # Magnitude spectrogram (power=1.0)
        magnitude = np.abs(librosa.stft(
            channel,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=True,
            pad_mode="reflect",
        ))

        # Project onto mel bins, then log-compress with a floor of 1e-5 so
        # silent bins do not produce -inf.
        log_mel = np.log(np.clip(mel_basis @ magnitude, a_min=1e-5, a_max=None))

        # Transpose: (n_mels, time) -> (time, n_mels)
        mels.append(log_mel.T)

    # Stack channels: (channels, time, n_mels), then add batch dim:
    # (1, channels, time, n_mels)
    mel_spec = np.stack(mels, axis=0)[np.newaxis, ...]

    return mx.array(mel_spec, dtype=mx.float32)
config.causality_axis + self.attn_type = config.attn_type + + self.conv_in = make_conv2d( + config.in_channels, self.ch, kernel_size=3, stride=1, + causality_axis=self.causality_axis, + ) + + self.down, block_in = build_downsampling_path( + ch=config.ch, + ch_mult=config.ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=config.num_res_blocks, + resolution=config.resolution, + temb_channels=self.temb_ch, + dropout=config.dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=config.attn_resolutions or set(), + resamp_with_conv=config.resamp_with_conv, + ) + + self.mid = build_mid_block( + channels=block_in, + temb_channels=self.temb_ch, + dropout=config.dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=config.mid_block_add_attention, + ) + + self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) + out_channels = 2 * config.z_channels if config.double_z else config.z_channels + self.conv_out = make_conv2d( + block_in, out_channels, kernel_size=3, stride=1, + causality_axis=self.causality_axis, + ) + + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Sanitize audio encoder weights from PyTorch format.""" + sanitized = {} + for key, value in weights.items(): + new_key = key + if key.startswith("audio_vae.encoder."): + new_key = key.replace("audio_vae.encoder.", "") + elif key.startswith("encoder."): + new_key = key.replace("encoder.", "") + elif key.startswith("audio_vae.per_channel_statistics."): + if "mean-of-means" in key: + new_key = "per_channel_statistics.mean_of_means" + elif "std-of-means" in key: + new_key = "per_channel_statistics.std_of_means" + else: + continue + elif "per_channel_statistics" in key: + if "mean-of-means" in key or "latents_mean" in key: + new_key = "per_channel_statistics.mean_of_means" + elif "std-of-means" in key or "latents_std" in key: + 
new_key = "per_channel_statistics.std_of_means" + else: + continue + elif key == "latents_mean": + new_key = "per_channel_statistics.mean_of_means" + elif key == "latents_std": + new_key = "per_channel_statistics.std_of_means" + else: + continue + + if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: + value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1)) + + sanitized[new_key] = value + return sanitized + + @classmethod + def from_pretrained(cls, model_path: Path) -> "AudioEncoder": + """Load audio encoder from pretrained weights.""" + from mlx_video.models.ltx.config import AudioEncoderModelConfig + import json + + model_path = Path(model_path) + config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json"))) + encoder = cls(config) + weights = mx.load(str(model_path / "model.safetensors")) + encoder.load_weights(list(weights.items()), strict=True) + return encoder + + def __call__(self, spectrogram: mx.array) -> mx.array: + """Encode audio spectrogram into normalized latent representation. + + Args: + spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format. + Returns: + Normalized latent (B, T', F', z_channels) in MLX channels-last format. 
+ """ + if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels: + spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1)) + + h = self.conv_in(spectrogram) + h = self._run_downsampling_path(h) + h = run_mid_block(self.mid, h) + h = self._finalize_output(h) + return self._normalize_latents(h) + + def _run_downsampling_path(self, h: mx.array) -> mx.array: + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx in range(self.num_res_blocks): + h = stage["block"][block_idx](h, temb=None) + if block_idx in stage["attn"]: + h = stage["attn"][block_idx](h) + if level != self.num_resolutions - 1 and "downsample" in stage: + h = stage["downsample"](h) + return h + + def _finalize_output(self, h: mx.array) -> mx.array: + h = self.norm_out(h) + h = nn.silu(h) + return self.conv_out(h) + + def _normalize_latents(self, h: mx.array) -> mx.array: + """Normalize encoder output using per-channel statistics. + + Takes first half of channels (mean) when double_z=True, + then patchifies, normalizes, and unpatchifies. + """ + # h shape: (B, T', F', 2*z_channels) in MLX format + z_channels = self.z_channels + means = h[..., :z_channels] + + latent_shape = AudioLatentShape( + batch=means.shape[0], + channels=means.shape[3], + frames=means.shape[1], + mel_bins=means.shape[2], + ) + + patched = self.patchifier.patchify(means) + normalized = self.per_channel_statistics.normalize(patched) + return self.patchifier.unpatchify(normalized, latent_shape) + + class AudioDecoder(nn.Module): """ Symmetric decoder that reconstructs audio spectrograms from latent features. 
diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py
index 1cfb0a6..57c7f46 100644
--- a/mlx_video/models/ltx/config.py
+++ b/mlx_video/models/ltx/config.py
@@ -2,7 +2,7 @@
 import inspect
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, List, Optional, Tuple, Set
+from typing import Any, List, Optional, Tuple
 
 
 class LTXModelType(Enum):
@@ -252,6 +252,55 @@ class AudioDecoderModelConfig(BaseModelConfig):
         if isinstance(self.attn_type, str):
             self.attn_type = AttentionType(self.attn_type)
 
+@dataclass
+class AudioEncoderModelConfig(BaseModelConfig):
+    """Configuration for the audio VAE encoder.
+
+    String values given for ``causality_axis``, ``norm_type`` and
+    ``attn_type`` are converted to their enum types in ``__post_init__``.
+    """
+
+    ch: int = 128
+    in_channels: int = 2
+    ch_mult: Tuple[int, ...] = (1, 2, 4)
+    num_res_blocks: int = 2
+    attn_resolutions: Optional[List[int]] = None
+    resolution: int = 256
+    z_channels: int = 8
+    double_z: bool = True
+    n_fft: int = 1024
+    norm_type: Optional[Enum] = None
+    causality_axis: Optional[Enum] = None
+    dropout: float = 0.0
+    mid_block_add_attention: bool = True
+    sample_rate: int = 16000
+    mel_hop_length: int = 160
+    is_causal: bool = True
+    mel_bins: int = 64
+    resamp_with_conv: bool = True
+    attn_type: Optional[str] = None
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return the base dict with ``attn_resolutions`` as a plain list."""
+        result = super().to_dict()
+        if self.attn_resolutions is not None:
+            result["attn_resolutions"] = list(self.attn_resolutions)
+        return result
+
+    def __post_init__(self):
+        """Convert string enum values to proper enum types."""
+        # Imported at call time (presumably to avoid a circular import - confirm).
+        from .audio_vae.normalization import NormType
+        from .audio_vae.attention import AttentionType
+
+        if isinstance(self.causality_axis, str):
+            self.causality_axis = CausalityAxis(self.causality_axis)
+        if isinstance(self.norm_type, str):
+            self.norm_type = NormType(self.norm_type)
+        if isinstance(self.attn_type, str):
+            self.attn_type = AttentionType(self.attn_type)
+
+
 @dataclass
 class VocoderModelConfig(BaseModelConfig):
     resblock_kernel_sizes: Optional[List[int]] = None