Add audio to video conditioning

This commit is contained in:
Prince Canuma
2026-03-16 01:42:11 +01:00
parent f53b9e0807
commit 6f6105b715
7 changed files with 623 additions and 62 deletions

View File

@@ -24,6 +24,7 @@ Supported models:
## Features
- Text-to-video (T2V) and Image-to-video (I2V) generation
- Audio-to-video (A2V) conditioning — generate video from input audio
- Four pipeline modes: Distilled, Dev, Dev Two-Stage, and Dev Two-Stage HQ
- Synchronized audio-video generation (experimental)
- LoRA support (including HuggingFace repos)
@@ -85,7 +86,27 @@ uv run mlx_video.generate --prompt "A person dancing" --image photo.jpg
uv run mlx_video.generate --pipeline dev --prompt "Waves crashing" --image beach.png --cfg-scale 3.5
```
### Audio-Video (experimental)
### Audio-to-Video (A2V)
Generate video conditioned on an input audio file. The audio is encoded to latent space and frozen during denoising — the transformer's cross-attention reads the audio signal to guide video generation.
```bash
# A2V - generate video from audio
uv run mlx_video.generate --audio-file music.wav --prompt "A band playing music"
# A2V with dev pipeline
uv run mlx_video.generate --pipeline dev --audio-file ocean.wav --prompt "Ocean waves"
# A2V + I2V (audio + image conditioning)
uv run mlx_video.generate --audio-file rain.wav --image forest.jpg --prompt "Rain in forest"
# A2V with custom start time
uv run mlx_video.generate --audio-file song.mp3 --audio-start-time 30.0 --prompt "Concert"
```
### Audio-Video Generation (experimental)
Generate synchronized audio alongside video from scratch:
```bash
uv run mlx_video.generate --prompt "Ocean waves crashing" --audio
@@ -150,6 +171,8 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
| `--image`, `-i` | None | Conditioning image for I2V |
| `--image-strength` | 1.0 | Conditioning strength for I2V |
| `--audio`, `-a` | false | Enable synchronized audio generation |
| `--audio-file` | None | Path to audio file for A2V conditioning |
| `--audio-start-time` | 0.0 | Start time in seconds for audio file |
| `--tiling` | `auto` | VAE tiling mode: `auto`, `none`, `aggressive`, `conservative` |
| `--stream` | false | Stream frames as they decode |

View File

@@ -606,6 +606,86 @@ def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
mx.save_safetensors(str(path / "model.safetensors"), weights)
def convert_audio_encoder(
    model_path: Union[str, Path],
    source_repo: str = "Lightricks/LTX-2",
) -> Path:
    """Convert and cache audio encoder weights from the original HF checkpoint.

    The audio VAE safetensors in the HF repo contains both encoder and decoder
    weights. This downloads that file, keeps the entries that
    ``AudioEncoder.sanitize`` recognises (converting conv kernels to the MLX
    layout), and writes them together with an encoder ``config.json`` to
    ``<model_path>/audio_vae_encoder`` for ``AudioEncoder.from_pretrained()``.

    Args:
        model_path: Local model directory (output location).
        source_repo: HF repo containing
            ``audio_vae/diffusion_pytorch_model.safetensors``.

    Returns:
        Path to the ``audio_vae_encoder`` directory.

    Raises:
        ValueError: If no encoder weights could be extracted from the
            downloaded checkpoint.
    """
    model_path = Path(model_path)
    encoder_dir = model_path / "audio_vae_encoder"
    weights_file = encoder_dir / "model.safetensors"
    config_file = encoder_dir / "config.json"

    # Reuse a previous conversion only when it is complete. from_pretrained()
    # needs both files, so a weights file without its config (e.g. after an
    # interrupted run) must trigger a fresh conversion, not an early return.
    if weights_file.exists() and config_file.exists():
        return encoder_dir

    # Imported lazily so that the cached path above needs neither network
    # access nor the model package.
    from huggingface_hub import hf_hub_download

    from mlx_video.models.ltx.audio_vae import AudioEncoder
    from mlx_video.models.ltx.config import AudioEncoderModelConfig

    # Download original audio VAE weights
    vae_path = hf_hub_download(
        source_repo,
        "audio_vae/diffusion_pytorch_model.safetensors",
    )
    raw_weights = mx.load(vae_path)

    # Build the encoder config from the decoder config when it is available
    # (same audio VAE architecture); otherwise fall back to a minimal config.
    decoder_config_path = model_path / "audio_vae" / "config.json"
    if decoder_config_path.exists():
        with open(decoder_config_path, encoding="utf-8") as f:
            dec_cfg = json.load(f)
        enc_config = {
            "ch": dec_cfg.get("ch", 128),
            # The encoder consumes what the decoder emits.
            "in_channels": dec_cfg.get("out_ch", 2),
            "ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
            "num_res_blocks": dec_cfg.get("num_res_blocks", 2),
            "attn_resolutions": dec_cfg.get("attn_resolutions", []),
            "resolution": dec_cfg.get("resolution", 256),
            "z_channels": dec_cfg.get("z_channels", 8),
            "double_z": True,
            "n_fft": 1024,
            "norm_type": dec_cfg.get("norm_type", "pixel"),
            "causality_axis": dec_cfg.get("causality_axis", "height"),
            "dropout": dec_cfg.get("dropout", 0.0),
            "mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
            "sample_rate": dec_cfg.get("sample_rate", 16000),
            "mel_hop_length": dec_cfg.get("mel_hop_length", 160),
            "is_causal": dec_cfg.get("is_causal", True),
            # `or 64` also covers an explicit null in the decoder config.
            "mel_bins": dec_cfg.get("mel_bins", 64) or 64,
            "resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
            "attn_type": dec_cfg.get("attn_type", "vanilla"),
        }
    else:
        enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}

    # Sanitize weights
    config = AudioEncoderModelConfig.from_dict(enc_config)
    encoder = AudioEncoder(config)
    sanitized = encoder.sanitize(raw_weights)
    if not sanitized:
        # Fail here instead of caching an unusable artifact that every later
        # call would return from the early-exit above.
        raise ValueError(
            f"No audio encoder weights found in {vae_path} (source repo: {source_repo})"
        )

    # Save. The config goes first so that the weights file, which gates the
    # cache check together with the config, is the last thing to appear.
    encoder_dir.mkdir(parents=True, exist_ok=True)
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(enc_config, f, indent=2)
    mx.save_safetensors(str(weights_file), sanitized)

    print(f"Audio encoder weights saved to {encoder_dir}")
    return encoder_dir
def load_model(
path_or_hf_repo: str,
lazy: bool = False,

View File

@@ -454,6 +454,7 @@ def denoise_distilled(
audio_latents: Optional[mx.array] = None,
audio_positions: Optional[mx.array] = None,
audio_embeddings: Optional[mx.array] = None,
audio_frozen: bool = False,
) -> tuple[mx.array, Optional[mx.array]]:
"""Run denoising loop for distilled pipeline (no CFG)."""
dtype = latents.dtype
@@ -513,14 +514,17 @@ def denoise_distilled(
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
a_ts = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
a_sig = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
audio_modality = Modality(
latent=audio_flat,
timesteps=mx.full((ab, at), sigma, dtype=dtype),
timesteps=a_ts,
positions=audio_positions,
context=audio_embeddings,
context_mask=None,
enabled=True,
sigma=mx.full((ab,), sigma, dtype=dtype),
sigma=a_sig,
)
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
@@ -529,9 +533,6 @@ def denoise_distilled(
mx.eval(audio_velocity)
# Compute denoised (x0) using per-token timesteps in float32
# x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
@@ -539,7 +540,7 @@ def denoise_distilled(
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
audio_denoised = None
if enable_audio and audio_velocity is not None:
if enable_audio and audio_velocity is not None and not audio_frozen:
ab, ac, at, af = audio_latents.shape
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
@@ -552,15 +553,15 @@ def denoise_distilled(
if audio_denoised is not None:
mx.eval(audio_denoised)
# Euler step in float32 (latents stay in float32)
# Euler step in float32
if sigma_next > 0:
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
if enable_audio and audio_denoised is not None:
if enable_audio and audio_denoised is not None and not audio_frozen:
audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
else:
latents = denoised
if enable_audio and audio_denoised is not None:
if enable_audio and audio_denoised is not None and not audio_frozen:
audio_latents = audio_denoised
mx.eval(latents)
@@ -785,6 +786,7 @@ def denoise_dev_av(
stg_video_blocks: Optional[list] = None,
stg_audio_blocks: Optional[list] = None,
modality_scale: float = 1.0,
audio_frozen: bool = False,
) -> tuple[mx.array, mx.array]:
"""Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio.
@@ -879,11 +881,12 @@ def denoise_dev_av(
else:
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
# Positive conditioning pass
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
video_modality_pos = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_pos, context_mask=None, enabled=True,
@@ -1001,11 +1004,13 @@ def denoise_dev_av(
video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32
video_latents = video_latents + video_velocity_f32 * dt_f32
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
if not audio_frozen:
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
else:
video_latents = video_denoised_f32
audio_latents = audio_denoised_f32
if not audio_frozen:
audio_latents = audio_denoised_f32
mx.eval(video_latents, audio_latents)
progress.advance(task)
@@ -1037,6 +1042,7 @@ def denoise_res2s_av(
noise_seed: int = 42,
bongmath: bool = True,
bongmath_max_iter: int = 100,
audio_frozen: bool = False,
) -> tuple[mx.array, mx.array]:
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
@@ -1125,10 +1131,10 @@ def denoise_res2s_av(
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
# Pass 1: Positive conditioning
video_modality_pos = Modality(
@@ -1270,18 +1276,23 @@ def denoise_res2s_av(
# Compute midpoint
eps_1_video = denoised_video_1 - x_anchor_video
eps_1_audio = denoised_audio_1 - x_anchor_audio
x_mid_video = x_anchor_video + h * a21 * eps_1_video
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
if not audio_frozen:
eps_1_audio = denoised_audio_1 - x_anchor_audio
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
else:
eps_1_audio = None
x_mid_audio = audio_latents # frozen: pass through unchanged
# SDE noise injection at substep
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
substep_noise_v = get_new_noise(video_latents.shape, key1)
substep_noise_a = get_new_noise(audio_latents.shape, key2)
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
if not audio_frozen:
substep_noise_a = get_new_noise(audio_latents.shape, key2)
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
mx.eval(x_mid_video, x_mid_audio)
# ============================================================
@@ -1291,9 +1302,13 @@ def denoise_res2s_av(
for _ in range(bongmath_max_iter):
x_anchor_video = x_mid_video - h * a21 * eps_1_video
eps_1_video = denoised_video_1 - x_anchor_video
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
eps_1_audio = denoised_audio_1 - x_anchor_audio
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
if not audio_frozen:
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
eps_1_audio = denoised_audio_1 - x_anchor_audio
if audio_frozen:
mx.eval(x_anchor_video, eps_1_video)
else:
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
# ============================================================
# Stage 2: Evaluate denoiser at midpoint sigma
@@ -1306,21 +1321,21 @@ def denoise_res2s_av(
# Final combination with RK coefficients
# ============================================================
eps_2_video = denoised_video_2 - x_anchor_video
eps_2_audio = denoised_audio_2 - x_anchor_audio
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
# SDE noise injection at step level
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
step_noise_v = get_new_noise(video_latents.shape, key1)
step_noise_a = get_new_noise(audio_latents.shape, key2)
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
video_latents = x_next_video.astype(mx.float32)
audio_latents = x_next_audio.astype(mx.float32)
if not audio_frozen:
eps_2_audio = denoised_audio_2 - x_anchor_audio
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
step_noise_a = get_new_noise(audio_latents.shape, key2)
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
audio_latents = x_next_audio.astype(mx.float32)
mx.eval(video_latents, audio_latents)
progress.advance(task)
@@ -1330,7 +1345,8 @@ def denoise_res2s_av(
video_latents, audio_latents, sigmas_list[n_full_steps]
)
video_latents = denoised_video
audio_latents = denoised_audio
if not audio_frozen:
audio_latents = denoised_audio
mx.eval(video_latents, audio_latents)
return video_latents, audio_latents
@@ -1443,6 +1459,8 @@ def generate_video(
lora_strength: float = 1.0,
lora_strength_stage_1: Optional[float] = None,
lora_strength_stage_2: Optional[float] = None,
audio_file: Optional[str] = None,
audio_start_time: float = 0.0,
):
"""Generate video using LTX-2 models.
@@ -1496,8 +1514,16 @@ def generate_video(
num_frames = adjusted_num_frames
is_i2v = image is not None
is_a2v = audio_file is not None
if is_a2v and audio:
raise ValueError("Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one.")
# A2V implicitly enables audio path through the transformer
if is_a2v:
audio = True
mode_str = "I2V" if is_i2v else "T2V"
if audio:
if is_a2v:
mode_str = "A2V" + ("+I2V" if is_i2v else "")
elif audio:
mode_str += "+Audio"
pipeline_names = {
@@ -1599,6 +1625,62 @@ def generate_video(
stg_blocks = [29]
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
# ==========================================================================
# A2V: Encode input audio to frozen latents
# ==========================================================================
a2v_audio_latents = None
a2v_waveform = None
a2v_sr = None
if is_a2v:
from mlx_video.models.ltx.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel
from mlx_video.convert import convert_audio_encoder
from mlx_video.models.ltx.audio_vae import AudioEncoder
with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"):
video_duration = num_frames / fps
# Load audio
waveform, sr = load_audio(
audio_file,
target_sr=AUDIO_LATENT_SAMPLE_RATE,
start_time=audio_start_time,
max_duration=video_duration,
)
waveform = ensure_stereo(waveform)
a2v_waveform = waveform.copy()
a2v_sr = sr
# Compute mel-spectrogram
mel = waveform_to_mel(waveform, sample_rate=sr, n_fft=1024, hop_length=AUDIO_HOP_LENGTH, n_mels=64)
# Convert audio encoder weights if needed, then load
encoder_dir = convert_audio_encoder(model_path, source_repo="Lightricks/LTX-2")
audio_encoder = AudioEncoder.from_pretrained(encoder_dir)
mx.eval(audio_encoder.parameters())
# Encode: (1, 2, time, 64) -> normalized latents
encoded = audio_encoder(mel)
mx.eval(encoded)
# encoded is in MLX format (B, T', mel_bins', z_channels) = (1, T', 16, 8)
# Convert to PyTorch-style format for consistency: (B, C, T, mel_bins)
a2v_audio_latents = mx.transpose(encoded, (0, 3, 1, 2)).astype(model_dtype)
# Trim/pad to match expected audio_frames
t_encoded = a2v_audio_latents.shape[2]
if t_encoded > audio_frames:
a2v_audio_latents = a2v_audio_latents[:, :, :audio_frames, :]
elif t_encoded < audio_frames:
pad_size = audio_frames - t_encoded
padding = mx.zeros((1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), dtype=model_dtype)
a2v_audio_latents = mx.concatenate([a2v_audio_latents, padding], axis=2)
mx.eval(a2v_audio_latents)
del audio_encoder
mx.clear_cache()
console.print(f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})")
# ==========================================================================
# Pipeline-specific generation logic
# ==========================================================================
@@ -1636,9 +1718,9 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
# Always init audio latents/positions - PyTorch unconditionally generates audio
# Init audio latents/positions: use encoded A2V latents or random
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning
@@ -1671,6 +1753,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS,
verbose=verbose, state=state1,
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
audio_frozen=is_a2v,
)
# Upsample latents
@@ -1723,7 +1806,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio_latents is not None:
if audio_latents is not None and not is_a2v:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1735,6 +1818,7 @@ def generate_video(
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings,
audio_frozen=is_a2v,
)
elif pipeline == PipelineType.DEV:
@@ -1770,7 +1854,7 @@ def generate_video(
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Initialize latents with optional I2V conditioning
@@ -1811,6 +1895,7 @@ def generate_video(
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
audio_frozen=is_a2v,
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
@@ -1858,7 +1943,7 @@ def generate_video(
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
@@ -1899,6 +1984,7 @@ def generate_video(
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
audio_frozen=is_a2v,
)
mx.eval(audio_latents)
@@ -1969,7 +2055,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio_latents is not None:
if audio_latents is not None and not is_a2v:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1981,6 +2067,7 @@ def generate_video(
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos,
audio_frozen=is_a2v,
)
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
@@ -2045,7 +2132,7 @@ def generate_video(
mx.eval(positions)
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
@@ -2087,6 +2174,7 @@ def generate_video(
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
noise_seed=seed,
audio_frozen=is_a2v,
)
mx.eval(audio_latents)
@@ -2148,7 +2236,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement
if audio_latents is not None:
if audio_latents is not None and not is_a2v:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -2165,6 +2253,7 @@ def generate_video(
audio_cfg_scale=1.0,
cfg_rescale=0.0, verbose=verbose, video_state=state2,
noise_seed=seed + 1,
audio_frozen=is_a2v,
)
del transformer
@@ -2279,29 +2368,38 @@ def generate_video(
# Decode and save audio if enabled
audio_np = None
vocoder_sample_rate = AUDIO_SAMPLE_RATE
if audio and audio_latents is not None:
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path, pipeline)
vocoder = load_vocoder_model(model_path, pipeline)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
if is_a2v and a2v_waveform is not None:
# A2V: use original input audio waveform (no VAE decoding needed)
audio_np = a2v_waveform
if audio_np.ndim == 1:
audio_np = audio_np[np.newaxis, :]
vocoder_sample_rate = a2v_sr or AUDIO_LATENT_SAMPLE_RATE
console.print("[green]✓[/] Using original input audio (A2V)")
else:
with console.status("[blue]Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path, pipeline)
vocoder = load_vocoder_model(model_path, pipeline)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
audio_np = np.array(audio_waveform.astype(mx.float32))
if audio_np.ndim == 3:
audio_np = audio_np[0]
audio_np = np.array(audio_waveform.astype(mx.float32))
if audio_np.ndim == 3:
audio_np = audio_np[0]
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
del audio_decoder, vocoder
mx.clear_cache()
console.print("[green]✓[/] Audio decoded")
del audio_decoder, vocoder
mx.clear_cache()
console.print("[green]✓[/] Audio decoded")
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
save_audio(audio_np, audio_path, vocoder_sample_rate)
@@ -2398,6 +2496,8 @@ Examples:
help="Tiling mode for VAE decoding")
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
parser.add_argument("--audio-file", type=str, default=None, help="Path to audio file for A2V (audio-to-video) conditioning")
parser.add_argument("--audio-start-time", type=float, default=0.0, help="Start time in seconds for audio file (default: 0.0)")
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
@@ -2457,6 +2557,8 @@ Examples:
lora_strength=args.lora_strength,
lora_strength_stage_1=args.lora_strength_stage_1,
lora_strength_stage_2=args.lora_strength_stage_2,
audio_file=args.audio_file,
audio_start_time=args.audio_start_time,
)

View File

@@ -1,7 +1,8 @@
"""Audio VAE module for LTX-2 audio generation."""
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, decode_audio
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
from .causal_conv_2d import CausalConv2d, make_conv2d
from ..config import CausalityAxis
from .downsample import Downsample, build_downsampling_path
@@ -13,10 +14,15 @@ from .vocoder import Vocoder, load_vocoder
__all__ = [
# Main components
"AudioEncoder",
"AudioDecoder",
"Vocoder",
"load_vocoder",
"decode_audio",
# Audio processing
"load_audio",
"ensure_stereo",
"waveform_to_mel",
# Ops
"AudioLatentShape",
"AudioPatchifier",

View File

@@ -0,0 +1,135 @@
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
using librosa for macOS/MLX compatibility.
"""
from pathlib import Path
import numpy as np
import mlx.core as mx
def load_audio(
    path: str,
    target_sr: int = 16000,
    start_time: float = 0.0,
    max_duration: float | None = None,
    mono: bool = False,
) -> tuple[np.ndarray, int]:
    """Read an audio file through librosa at the requested sample rate.

    Args:
        path: Path to audio file (WAV, FLAC, MP3, OGG, or video with audio track).
        target_sr: Sample rate passed to ``librosa.load`` (default 16000 Hz).
        start_time: Offset into the file, in seconds, at which reading starts.
        max_duration: Maximum duration in seconds. None = read to end.
        mono: If True, downmix to mono. Default False (preserve channels).

    Returns:
        ``(waveform, sample_rate)`` where waveform is a float32 numpy array
        laid out as (channels, samples).
    """
    import librosa

    # librosa downmixes by default, so `mono` is always forwarded explicitly
    # to keep multi-channel files intact unless the caller asks otherwise.
    samples, rate = librosa.load(
        path,
        sr=target_sr,
        mono=mono,
        offset=start_time,
        duration=max_duration,
    )

    # A 1-D result gets a leading channel axis so callers always see
    # (channels, samples).
    shaped = samples.reshape(1, -1) if samples.ndim == 1 else samples
    return shaped.astype(np.float32), rate
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
    """Return *waveform* as a two-channel (2, samples) array.

    A 1-D input is treated as a single channel. One channel is duplicated,
    more than two channels are cut down to the first two, and an input that
    already has two channels is returned unchanged.
    """
    stereo = waveform.reshape(1, -1) if waveform.ndim == 1 else waveform
    n_channels = stereo.shape[0]

    if n_channels == 1:
        # Mono: the same signal feeds both channels.
        return np.concatenate([stereo, stereo], axis=0)
    if n_channels > 2:
        # Surround / multi-track: keep only the first two channels.
        return stereo[:2]
    return stereo
def waveform_to_mel(
    waveform: np.ndarray,
    sample_rate: int = 16000,
    n_fft: int = 1024,
    hop_length: int = 160,
    win_length: int = 1024,
    n_mels: int = 64,
    fmin: float = 0.0,
    fmax: float = 8000.0,
) -> mx.array:
    """Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.

    PyTorch reference:
        MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
                       f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
                       mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")

    Args:
        waveform: (channels, samples) float32 numpy array. A 1-D array is
            treated as a single channel.
        sample_rate: Sample rate of the waveform.
        n_fft: FFT size.
        hop_length: Hop length.
        win_length: Window length.
        n_mels: Number of mel bins.
        fmin: Minimum frequency for mel filterbank.
        fmax: Maximum frequency for mel filterbank.

    Returns:
        Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
    """
    import librosa

    # Ensure 2D
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]

    # The mel filterbank depends only on the arguments, not on the signal,
    # so it is built once and shared by every channel.
    mel_basis = librosa.filters.mel(
        sr=sample_rate,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        norm="slaney",
    )

    mels = []
    for channel in waveform:
        # Magnitude spectrogram (power=1.0)
        magnitude = np.abs(
            librosa.stft(
                channel,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                center=True,
                pad_mode="reflect",
            )
        )
        # Project onto mel bins, then take the log; the clip floor keeps
        # silent bins finite.
        log_mel = np.log(np.clip(mel_basis @ magnitude, a_min=1e-5, a_max=None))
        # (n_mels, time) -> (time, n_mels)
        mels.append(log_mel.T)

    # Stack channels and add a batch dim: (1, channels, time, n_mels)
    mel_spec = np.stack(mels, axis=0)[np.newaxis, ...]
    return mx.array(mel_spec, dtype=mx.float32)

View File

@@ -6,10 +6,11 @@ from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
@@ -59,6 +60,179 @@ def run_mid_block(mid: dict, features: mx.array) -> mx.array:
return mid["block_2"](features, temb=None)
class AudioEncoder(nn.Module):
"""Encoder that compresses audio spectrograms into latent representations."""
def __init__(self, config: AudioEncoderModelConfig) -> None:
    """Build the encoder layers described by *config*.

    Creates, in order: the per-channel statistics module, the patchifier,
    the input convolution, the downsampling path, the mid block, the output
    normalization and the output convolution. The output convolution emits
    ``2 * z_channels`` channels when ``config.double_z`` is set, otherwise
    ``z_channels``.
    """
    super().__init__()

    self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)

    # Audio framing parameters, kept on the module and forwarded to the patchifier.
    self.sample_rate = config.sample_rate
    self.mel_hop_length = config.mel_hop_length
    self.is_causal = config.is_causal
    self.mel_bins = config.mel_bins
    self.patchifier = AudioPatchifier(
        patch_size=1,
        audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
        sample_rate=config.sample_rate,
        hop_length=config.mel_hop_length,
        is_causal=config.is_causal,
    )

    # Architecture hyper-parameters mirrored from the config.
    self.ch = config.ch
    self.temb_ch = 0  # passed as temb_channels to the down path and mid block
    self.num_resolutions = len(config.ch_mult)
    self.num_res_blocks = config.num_res_blocks
    self.resolution = config.resolution
    self.in_channels = config.in_channels
    self.z_channels = config.z_channels
    self.double_z = config.double_z
    self.norm_type = config.norm_type
    self.causality_axis = config.causality_axis
    self.attn_type = config.attn_type

    # Both boundary convolutions share kernel, stride and causality settings.
    conv_kwargs = dict(kernel_size=3, stride=1, causality_axis=self.causality_axis)
    self.conv_in = make_conv2d(config.in_channels, self.ch, **conv_kwargs)

    # Settings common to the downsampling path and the mid block.
    stage_kwargs = dict(
        temb_channels=self.temb_ch,
        dropout=config.dropout,
        norm_type=self.norm_type,
        causality_axis=self.causality_axis,
        attn_type=self.attn_type,
    )

    self.down, block_in = build_downsampling_path(
        ch=config.ch,
        ch_mult=config.ch_mult,
        num_resolutions=self.num_resolutions,
        num_res_blocks=config.num_res_blocks,
        resolution=config.resolution,
        attn_resolutions=config.attn_resolutions or set(),
        resamp_with_conv=config.resamp_with_conv,
        **stage_kwargs,
    )

    self.mid = build_mid_block(
        channels=block_in,
        add_attention=config.mid_block_add_attention,
        **stage_kwargs,
    )

    self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)

    out_channels = config.z_channels * (2 if config.double_z else 1)
    self.conv_out = make_conv2d(block_in, out_channels, **conv_kwargs)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Map checkpoint keys onto this encoder's parameter names.

    Keys starting with ``audio_vae.encoder.`` or ``encoder.`` have that
    leading prefix removed. Per-channel statistics entries are renamed to
    ``per_channel_statistics.mean_of_means`` / ``.std_of_means``. Every
    other key is dropped.

    Only the *leading* prefix is stripped; text equal to the prefix that
    appears later in a key is left untouched.

    Args:
        weights: Mapping from checkpoint key to array.

    Returns:
        A new dict holding only the recognized entries under their new
        names. 4-D weights whose new key contains "conv" are transposed
        with axes ``(0, 2, 3, 1)`` unless ``check_array_shape`` accepts
        them as they are (presumably channels-first -> channels-last;
        confirm against ``check_array_shape``).
    """
    encoder_prefixes = ("audio_vae.encoder.", "encoder.")
    mean_key = "per_channel_statistics.mean_of_means"
    std_key = "per_channel_statistics.std_of_means"

    sanitized = {}
    for key, value in weights.items():
        if key.startswith(encoder_prefixes[0]):
            # Slice instead of str.replace so only the prefix is removed.
            new_key = key[len(encoder_prefixes[0]):]
        elif key.startswith(encoder_prefixes[1]):
            new_key = key[len(encoder_prefixes[1]):]
        elif key.startswith("audio_vae.per_channel_statistics."):
            if "mean-of-means" in key:
                new_key = mean_key
            elif "std-of-means" in key:
                new_key = std_key
            else:
                continue
        elif "per_channel_statistics" in key:
            if "mean-of-means" in key or "latents_mean" in key:
                new_key = mean_key
            elif "std-of-means" in key or "latents_std" in key:
                new_key = std_key
            else:
                continue
        elif key == "latents_mean":
            new_key = mean_key
        elif key == "latents_std":
            new_key = std_key
        else:
            # Not an encoder or statistics entry.
            continue

        if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
            if not check_array_shape(value):
                value = mx.transpose(value, (0, 2, 3, 1))
        sanitized[new_key] = value
    return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
    """Load an audio encoder from a directory of pretrained files.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``.

    Returns:
        An encoder built from the config with the stored weights loaded
        strictly (``strict=True``).
    """
    from mlx_video.models.ltx.config import AudioEncoderModelConfig
    import json

    model_path = Path(model_path)
    # Context manager so the config file handle is closed deterministically.
    with open(model_path / "config.json", encoding="utf-8") as config_file:
        config = AudioEncoderModelConfig.from_dict(json.load(config_file))

    encoder = cls(config)
    weights = mx.load(str(model_path / "model.safetensors"))
    encoder.load_weights(list(weights.items()), strict=True)
    return encoder
def __call__(self, spectrogram: mx.array) -> mx.array:
    """Encode an audio spectrogram into a normalized latent.

    A 4-D input whose axis 1 equals ``self.in_channels`` is treated as
    channels-first and transposed to channels-last before encoding.

    Args:
        spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.

    Returns:
        Normalized latent (B, T', F', z_channels) in MLX channels-last format.
    """
    looks_channels_first = (
        spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels
    )
    if looks_channels_first:
        spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))

    features = self.conv_in(spectrogram)
    features = self._run_downsampling_path(features)
    features = run_mid_block(self.mid, features)
    projected = self._finalize_output(features)
    return self._normalize_latents(projected)
def _run_downsampling_path(self, h: mx.array) -> mx.array:
for level in range(self.num_resolutions):
stage = self.down[level]
for block_idx in range(self.num_res_blocks):
h = stage["block"][block_idx](h, temb=None)
if block_idx in stage["attn"]:
h = stage["attn"][block_idx](h)
if level != self.num_resolutions - 1 and "downsample" in stage:
h = stage["downsample"](h)
return h
def _finalize_output(self, h: mx.array) -> mx.array:
    """Apply the output norm, SiLU, and the output convolution, in that order."""
    activated = nn.silu(self.norm_out(h))
    return self.conv_out(activated)
def _normalize_latents(self, h: mx.array) -> mx.array:
    """Normalize the encoder output with the per-channel statistics.

    Keeps the first ``z_channels`` entries of the last axis (the mean half
    when ``double_z=True``), then patchifies, normalizes, and unpatchifies
    back to the recorded latent shape.
    """
    # h is expected as (B, T', F', 2*z_channels) in MLX channels-last layout.
    means = h[..., : self.z_channels]
    dims = means.shape
    latent_shape = AudioLatentShape(
        batch=dims[0],
        channels=dims[3],
        frames=dims[1],
        mel_bins=dims[2],
    )
    patchifier = self.patchifier
    normalized = self.per_channel_statistics.normalize(patchifier.patchify(means))
    return patchifier.unpatchify(normalized, latent_shape)
class AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.

View File

@@ -2,7 +2,7 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional, Tuple, Set
from typing import Any, List, Optional, Tuple
class LTXModelType(Enum):
@@ -252,6 +252,47 @@ class AudioDecoderModelConfig(BaseModelConfig):
if isinstance(self.attn_type, str):
self.attn_type = AttentionType(self.attn_type)
@dataclass
class AudioEncoderModelConfig(BaseModelConfig):
    """Hyper-parameters for the audio encoder.

    ``norm_type``, ``causality_axis`` and ``attn_type`` may be given as
    strings; ``__post_init__`` converts string values to ``NormType``,
    ``CausalityAxis`` and ``AttentionType`` respectively.
    """

    ch: int = 128
    in_channels: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    double_z: bool = True
    n_fft: int = 1024
    # Optional because the default is None; a str is converted in __post_init__.
    norm_type: Optional[Enum] = None
    causality_axis: Optional[Enum] = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int = 64
    resamp_with_conv: bool = True
    # Accepts a str on input; holds an AttentionType after __post_init__.
    attn_type: Optional[Any] = None

    def to_dict(self) -> dict[str, Any]:
        """Return the base-class dict, with ``attn_resolutions`` as a list when set."""
        result = super().to_dict()
        if self.attn_resolutions is not None:
            result["attn_resolutions"] = list(self.attn_resolutions)
        return result

    def __post_init__(self):
        """Convert string enum values to proper enum types."""
        from .audio_vae.normalization import NormType
        from .audio_vae.attention import AttentionType

        # NOTE(review): CausalityAxis is not imported here, so it must already
        # be available at module scope — confirm.
        if isinstance(self.causality_axis, str):
            self.causality_axis = CausalityAxis(self.causality_axis)
        if isinstance(self.norm_type, str):
            self.norm_type = NormType(self.norm_type)
        if isinstance(self.attn_type, str):
            self.attn_type = AttentionType(self.attn_type)
@dataclass
class VocoderModelConfig(BaseModelConfig):
resblock_kernel_sizes: Optional[List[int]] = None