Add audio to video conditioning
This commit is contained in:
25
README.md
25
README.md
@@ -24,6 +24,7 @@ Supported models:
|
||||
## Features
|
||||
|
||||
- Text-to-video (T2V) and Image-to-video (I2V) generation
|
||||
- Audio-to-video (A2V) conditioning — generate video from input audio
|
||||
- Four pipeline modes: Distilled, Dev, Dev Two-Stage, and Dev Two-Stage HQ
|
||||
- Synchronized audio-video generation (experimental)
|
||||
- LoRA support (including HuggingFace repos)
|
||||
@@ -85,7 +86,27 @@ uv run mlx_video.generate --prompt "A person dancing" --image photo.jpg
|
||||
uv run mlx_video.generate --pipeline dev --prompt "Waves crashing" --image beach.png --cfg-scale 3.5
|
||||
```
|
||||
|
||||
### Audio-Video (experimental)
|
||||
### Audio-to-Video (A2V)
|
||||
|
||||
Generate video conditioned on an input audio file. The audio is encoded to latent space and frozen during denoising — the transformer's cross-attention reads the audio signal to guide video generation.
|
||||
|
||||
```bash
|
||||
# A2V - generate video from audio
|
||||
uv run mlx_video.generate --audio-file music.wav --prompt "A band playing music"
|
||||
|
||||
# A2V with dev pipeline
|
||||
uv run mlx_video.generate --pipeline dev --audio-file ocean.wav --prompt "Ocean waves"
|
||||
|
||||
# A2V + I2V (audio + image conditioning)
|
||||
uv run mlx_video.generate --audio-file rain.wav --image forest.jpg --prompt "Rain in forest"
|
||||
|
||||
# A2V with custom start time
|
||||
uv run mlx_video.generate --audio-file song.mp3 --audio-start-time 30.0 --prompt "Concert"
|
||||
```
|
||||
|
||||
### Audio-Video Generation (experimental)
|
||||
|
||||
Generate synchronized audio alongside video from scratch:
|
||||
|
||||
```bash
|
||||
uv run mlx_video.generate --prompt "Ocean waves crashing" --audio
|
||||
@@ -150,6 +171,8 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
|
||||
| `--image`, `-i` | None | Conditioning image for I2V |
|
||||
| `--image-strength` | 1.0 | Conditioning strength for I2V |
|
||||
| `--audio`, `-a` | false | Enable synchronized audio generation |
|
||||
| `--audio-file` | None | Path to audio file for A2V conditioning (cannot be combined with `--audio`) |
|
||||
| `--audio-start-time` | 0.0 | Start time in seconds for audio file |
|
||||
| `--tiling` | `auto` | VAE tiling mode: `auto`, `none`, `aggressive`, `conservative` |
|
||||
| `--stream` | false | Stream frames as they decode |
|
||||
|
||||
|
||||
@@ -606,6 +606,86 @@ def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||
mx.save_safetensors(str(path / "model.safetensors"), weights)
|
||||
|
||||
|
||||
def _build_audio_encoder_config(model_path: Path) -> dict:
    """Build the audio-encoder config dict for the model stored at *model_path*.

    When ``<model_path>/audio_vae/config.json`` (the decoder config) exists,
    its fields are mirrored key by key, each with a fixed fallback default.
    Otherwise only a minimal set of overrides is returned.
    """
    decoder_config_path = model_path / "audio_vae" / "config.json"
    if not decoder_config_path.exists():
        return {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}

    with open(decoder_config_path, encoding="utf-8") as f:
        dec_cfg = json.load(f)

    # Key order is preserved on purpose: it is the order written to config.json.
    return {
        "ch": dec_cfg.get("ch", 128),
        # The encoder's input channel count is taken from the decoder's "out_ch".
        "in_channels": dec_cfg.get("out_ch", 2),
        "ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
        "num_res_blocks": dec_cfg.get("num_res_blocks", 2),
        "attn_resolutions": dec_cfg.get("attn_resolutions", []),
        "resolution": dec_cfg.get("resolution", 256),
        "z_channels": dec_cfg.get("z_channels", 8),
        # Fixed for the encoder; never read from the decoder config.
        "double_z": True,
        "n_fft": 1024,
        "norm_type": dec_cfg.get("norm_type", "pixel"),
        "causality_axis": dec_cfg.get("causality_axis", "height"),
        "dropout": dec_cfg.get("dropout", 0.0),
        "mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
        "sample_rate": dec_cfg.get("sample_rate", 16000),
        "mel_hop_length": dec_cfg.get("mel_hop_length", 160),
        "is_causal": dec_cfg.get("is_causal", True),
        # "or 64" also covers an explicit null / 0 stored in the decoder config.
        "mel_bins": dec_cfg.get("mel_bins", 64) or 64,
        "resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
        "attn_type": dec_cfg.get("attn_type", "vanilla"),
    }


def convert_audio_encoder(
    model_path: Union[str, Path],
    source_repo: str = "Lightricks/LTX-2",
) -> Path:
    """Convert and save audio encoder weights from the original HF checkpoint.

    Downloads ``audio_vae/diffusion_pytorch_model.safetensors`` from
    *source_repo*, passes the raw weights through ``AudioEncoder.sanitize``
    and writes the result plus a ``config.json`` to
    ``<model_path>/audio_vae_encoder``. What ``sanitize`` does is defined
    elsewhere; presumably it keeps the encoder weights and converts conv
    kernels to MLX layout (TODO confirm).

    The conversion is skipped when ``model.safetensors`` already exists in
    the output directory.

    Args:
        model_path: Local model directory (output location).
        source_repo: HF repo containing
            audio_vae/diffusion_pytorch_model.safetensors.

    Returns:
        Path to the audio_vae_encoder directory.
    """
    model_path = Path(model_path)
    encoder_dir = model_path / "audio_vae_encoder"
    weights_file = encoder_dir / "model.safetensors"

    # The weights file doubles as the "already converted" marker.
    if weights_file.exists():
        return encoder_dir

    # Download original audio VAE weights
    from huggingface_hub import hf_hub_download

    vae_path = hf_hub_download(
        source_repo,
        "audio_vae/diffusion_pytorch_model.safetensors",
    )
    raw_weights = mx.load(vae_path)

    from mlx_video.models.ltx.audio_vae import AudioEncoder
    from mlx_video.models.ltx.config import AudioEncoderModelConfig

    # Build config from the decoder config (same audio VAE architecture)
    enc_config = _build_audio_encoder_config(model_path)

    # Sanitize weights
    config = AudioEncoderModelConfig.from_dict(enc_config)
    encoder = AudioEncoder(config)
    sanitized = encoder.sanitize(raw_weights)

    # Save. config.json goes first and the weights last: the weights file is
    # the completion marker checked above, so an interrupted run can never
    # leave a directory that looks converted but has no config.
    encoder_dir.mkdir(parents=True, exist_ok=True)
    with open(encoder_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(enc_config, f, indent=2)
    mx.save_safetensors(str(weights_file), sanitized)

    print(f"Audio encoder weights saved to {encoder_dir}")
    return encoder_dir
|
||||
|
||||
|
||||
def load_model(
|
||||
path_or_hf_repo: str,
|
||||
lazy: bool = False,
|
||||
|
||||
@@ -454,6 +454,7 @@ def denoise_distilled(
|
||||
audio_latents: Optional[mx.array] = None,
|
||||
audio_positions: Optional[mx.array] = None,
|
||||
audio_embeddings: Optional[mx.array] = None,
|
||||
audio_frozen: bool = False,
|
||||
) -> tuple[mx.array, Optional[mx.array]]:
|
||||
"""Run denoising loop for distilled pipeline (no CFG)."""
|
||||
dtype = latents.dtype
|
||||
@@ -513,14 +514,17 @@ def denoise_distilled(
|
||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
||||
|
||||
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
|
||||
a_ts = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
||||
a_sig = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
||||
audio_modality = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=mx.full((ab, at), sigma, dtype=dtype),
|
||||
timesteps=a_ts,
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
sigma=mx.full((ab,), sigma, dtype=dtype),
|
||||
sigma=a_sig,
|
||||
)
|
||||
|
||||
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
|
||||
@@ -529,9 +533,6 @@ def denoise_distilled(
|
||||
mx.eval(audio_velocity)
|
||||
|
||||
# Compute denoised (x0) using per-token timesteps in float32
|
||||
# x0 = latent - timestep * velocity
|
||||
# For conditioned tokens (timestep=0): x0 = latent
|
||||
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
|
||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
||||
@@ -539,7 +540,7 @@ def denoise_distilled(
|
||||
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
|
||||
|
||||
audio_denoised = None
|
||||
if enable_audio and audio_velocity is not None:
|
||||
if enable_audio and audio_velocity is not None and not audio_frozen:
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
||||
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
|
||||
@@ -552,15 +553,15 @@ def denoise_distilled(
|
||||
if audio_denoised is not None:
|
||||
mx.eval(audio_denoised)
|
||||
|
||||
# Euler step in float32 (latents stay in float32)
|
||||
# Euler step in float32
|
||||
if sigma_next > 0:
|
||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
|
||||
if enable_audio and audio_denoised is not None:
|
||||
if enable_audio and audio_denoised is not None and not audio_frozen:
|
||||
audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
|
||||
else:
|
||||
latents = denoised
|
||||
if enable_audio and audio_denoised is not None:
|
||||
if enable_audio and audio_denoised is not None and not audio_frozen:
|
||||
audio_latents = audio_denoised
|
||||
|
||||
mx.eval(latents)
|
||||
@@ -785,6 +786,7 @@ def denoise_dev_av(
|
||||
stg_video_blocks: Optional[list] = None,
|
||||
stg_audio_blocks: Optional[list] = None,
|
||||
modality_scale: float = 1.0,
|
||||
audio_frozen: bool = False,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio.
|
||||
|
||||
@@ -879,11 +881,12 @@ def denoise_dev_av(
|
||||
else:
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
|
||||
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
# Positive conditioning pass
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
||||
video_modality_pos = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_pos, context_mask=None, enabled=True,
|
||||
@@ -1001,10 +1004,12 @@ def denoise_dev_av(
|
||||
video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32
|
||||
video_latents = video_latents + video_velocity_f32 * dt_f32
|
||||
|
||||
if not audio_frozen:
|
||||
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
||||
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
||||
else:
|
||||
video_latents = video_denoised_f32
|
||||
if not audio_frozen:
|
||||
audio_latents = audio_denoised_f32
|
||||
|
||||
mx.eval(video_latents, audio_latents)
|
||||
@@ -1037,6 +1042,7 @@ def denoise_res2s_av(
|
||||
noise_seed: int = 42,
|
||||
bongmath: bool = True,
|
||||
bongmath_max_iter: int = 100,
|
||||
audio_frozen: bool = False,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
|
||||
|
||||
@@ -1125,10 +1131,10 @@ def denoise_res2s_av(
|
||||
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
||||
|
||||
# Pass 1: Positive conditioning
|
||||
video_modality_pos = Modality(
|
||||
@@ -1270,17 +1276,22 @@ def denoise_res2s_av(
|
||||
|
||||
# Compute midpoint
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
|
||||
x_mid_video = x_anchor_video + h * a21 * eps_1_video
|
||||
|
||||
if not audio_frozen:
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
|
||||
else:
|
||||
eps_1_audio = None
|
||||
x_mid_audio = audio_latents # frozen: pass through unchanged
|
||||
|
||||
# SDE noise injection at substep
|
||||
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
|
||||
substep_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
substep_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
|
||||
if not audio_frozen:
|
||||
substep_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
|
||||
mx.eval(x_mid_video, x_mid_audio)
|
||||
|
||||
@@ -1291,8 +1302,12 @@ def denoise_res2s_av(
|
||||
for _ in range(bongmath_max_iter):
|
||||
x_anchor_video = x_mid_video - h * a21 * eps_1_video
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
if not audio_frozen:
|
||||
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
if audio_frozen:
|
||||
mx.eval(x_anchor_video, eps_1_video)
|
||||
else:
|
||||
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
|
||||
|
||||
# ============================================================
|
||||
@@ -1306,21 +1321,21 @@ def denoise_res2s_av(
|
||||
# Final combination with RK coefficients
|
||||
# ============================================================
|
||||
eps_2_video = denoised_video_2 - x_anchor_video
|
||||
eps_2_audio = denoised_audio_2 - x_anchor_audio
|
||||
|
||||
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
|
||||
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
|
||||
|
||||
# SDE noise injection at step level
|
||||
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
|
||||
step_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
step_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
|
||||
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
|
||||
|
||||
video_latents = x_next_video.astype(mx.float32)
|
||||
if not audio_frozen:
|
||||
eps_2_audio = denoised_audio_2 - x_anchor_audio
|
||||
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
|
||||
step_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
|
||||
audio_latents = x_next_audio.astype(mx.float32)
|
||||
|
||||
mx.eval(video_latents, audio_latents)
|
||||
progress.advance(task)
|
||||
|
||||
@@ -1330,6 +1345,7 @@ def denoise_res2s_av(
|
||||
video_latents, audio_latents, sigmas_list[n_full_steps]
|
||||
)
|
||||
video_latents = denoised_video
|
||||
if not audio_frozen:
|
||||
audio_latents = denoised_audio
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
@@ -1443,6 +1459,8 @@ def generate_video(
|
||||
lora_strength: float = 1.0,
|
||||
lora_strength_stage_1: Optional[float] = None,
|
||||
lora_strength_stage_2: Optional[float] = None,
|
||||
audio_file: Optional[str] = None,
|
||||
audio_start_time: float = 0.0,
|
||||
):
|
||||
"""Generate video using LTX-2 models.
|
||||
|
||||
@@ -1496,8 +1514,16 @@ def generate_video(
|
||||
num_frames = adjusted_num_frames
|
||||
|
||||
is_i2v = image is not None
|
||||
is_a2v = audio_file is not None
|
||||
if is_a2v and audio:
|
||||
raise ValueError("Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one.")
|
||||
# A2V implicitly enables audio path through the transformer
|
||||
if is_a2v:
|
||||
audio = True
|
||||
mode_str = "I2V" if is_i2v else "T2V"
|
||||
if audio:
|
||||
if is_a2v:
|
||||
mode_str = "A2V" + ("+I2V" if is_i2v else "")
|
||||
elif audio:
|
||||
mode_str += "+Audio"
|
||||
|
||||
pipeline_names = {
|
||||
@@ -1599,6 +1625,62 @@ def generate_video(
|
||||
stg_blocks = [29]
|
||||
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
|
||||
|
||||
# ==========================================================================
|
||||
# A2V: Encode input audio to frozen latents
|
||||
# ==========================================================================
|
||||
a2v_audio_latents = None
|
||||
a2v_waveform = None
|
||||
a2v_sr = None
|
||||
if is_a2v:
|
||||
from mlx_video.models.ltx.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from mlx_video.convert import convert_audio_encoder
|
||||
from mlx_video.models.ltx.audio_vae import AudioEncoder
|
||||
|
||||
with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"):
|
||||
video_duration = num_frames / fps
|
||||
|
||||
# Load audio
|
||||
waveform, sr = load_audio(
|
||||
audio_file,
|
||||
target_sr=AUDIO_LATENT_SAMPLE_RATE,
|
||||
start_time=audio_start_time,
|
||||
max_duration=video_duration,
|
||||
)
|
||||
waveform = ensure_stereo(waveform)
|
||||
a2v_waveform = waveform.copy()
|
||||
a2v_sr = sr
|
||||
|
||||
# Compute mel-spectrogram
|
||||
mel = waveform_to_mel(waveform, sample_rate=sr, n_fft=1024, hop_length=AUDIO_HOP_LENGTH, n_mels=64)
|
||||
|
||||
# Convert audio encoder weights if needed, then load
|
||||
encoder_dir = convert_audio_encoder(model_path, source_repo="Lightricks/LTX-2")
|
||||
audio_encoder = AudioEncoder.from_pretrained(encoder_dir)
|
||||
mx.eval(audio_encoder.parameters())
|
||||
|
||||
# Encode: (1, 2, time, 64) -> normalized latents
|
||||
encoded = audio_encoder(mel)
|
||||
mx.eval(encoded)
|
||||
|
||||
# encoded is in MLX format (B, T', mel_bins', z_channels) = (1, T', 16, 8)
|
||||
# Convert to PyTorch-style format for consistency: (B, C, T, mel_bins)
|
||||
a2v_audio_latents = mx.transpose(encoded, (0, 3, 1, 2)).astype(model_dtype)
|
||||
|
||||
# Trim/pad to match expected audio_frames
|
||||
t_encoded = a2v_audio_latents.shape[2]
|
||||
if t_encoded > audio_frames:
|
||||
a2v_audio_latents = a2v_audio_latents[:, :, :audio_frames, :]
|
||||
elif t_encoded < audio_frames:
|
||||
pad_size = audio_frames - t_encoded
|
||||
padding = mx.zeros((1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
a2v_audio_latents = mx.concatenate([a2v_audio_latents, padding], axis=2)
|
||||
mx.eval(a2v_audio_latents)
|
||||
|
||||
del audio_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
console.print(f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})")
|
||||
|
||||
# ==========================================================================
|
||||
# Pipeline-specific generation logic
|
||||
# ==========================================================================
|
||||
@@ -1636,9 +1718,9 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
# Init audio latents/positions: use encoded A2V latents or random
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning
|
||||
@@ -1671,6 +1753,7 @@ def generate_video(
|
||||
latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS,
|
||||
verbose=verbose, state=state1,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
# Upsample latents
|
||||
@@ -1723,7 +1806,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
||||
if audio_latents is not None:
|
||||
if audio_latents is not None and not is_a2v:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -1735,6 +1818,7 @@ def generate_video(
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
||||
audio_embeddings=audio_embeddings,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
elif pipeline == PipelineType.DEV:
|
||||
@@ -1770,7 +1854,7 @@ def generate_video(
|
||||
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Initialize latents with optional I2V conditioning
|
||||
@@ -1811,6 +1895,7 @@ def generate_video(
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
@@ -1858,7 +1943,7 @@ def generate_video(
|
||||
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 1
|
||||
@@ -1899,6 +1984,7 @@ def generate_video(
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
mx.eval(audio_latents)
|
||||
@@ -1969,7 +2055,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
||||
if audio_latents is not None:
|
||||
if audio_latents is not None and not is_a2v:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -1981,6 +2067,7 @@ def generate_video(
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
||||
audio_embeddings=audio_embeddings_pos,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
|
||||
@@ -2045,7 +2132,7 @@ def generate_video(
|
||||
mx.eval(positions)
|
||||
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 1
|
||||
@@ -2087,6 +2174,7 @@ def generate_video(
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
noise_seed=seed,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
mx.eval(audio_latents)
|
||||
@@ -2148,7 +2236,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement
|
||||
if audio_latents is not None:
|
||||
if audio_latents is not None and not is_a2v:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -2165,6 +2253,7 @@ def generate_video(
|
||||
audio_cfg_scale=1.0,
|
||||
cfg_rescale=0.0, verbose=verbose, video_state=state2,
|
||||
noise_seed=seed + 1,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
del transformer
|
||||
@@ -2279,8 +2368,17 @@ def generate_video(
|
||||
|
||||
# Decode and save audio if enabled
|
||||
audio_np = None
|
||||
vocoder_sample_rate = AUDIO_SAMPLE_RATE
|
||||
if audio and audio_latents is not None:
|
||||
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
|
||||
if is_a2v and a2v_waveform is not None:
|
||||
# A2V: use original input audio waveform (no VAE decoding needed)
|
||||
audio_np = a2v_waveform
|
||||
if audio_np.ndim == 1:
|
||||
audio_np = audio_np[np.newaxis, :]
|
||||
vocoder_sample_rate = a2v_sr or AUDIO_LATENT_SAMPLE_RATE
|
||||
console.print("[green]✓[/] Using original input audio (A2V)")
|
||||
else:
|
||||
with console.status("[blue]Decoding audio...[/]", spinner="dots"):
|
||||
audio_decoder = load_audio_decoder(model_path, pipeline)
|
||||
vocoder = load_vocoder_model(model_path, pipeline)
|
||||
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||
@@ -2398,6 +2496,8 @@ Examples:
|
||||
help="Tiling mode for VAE decoding")
|
||||
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
||||
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
||||
parser.add_argument("--audio-file", type=str, default=None, help="Path to audio file for A2V (audio-to-video) conditioning")
|
||||
parser.add_argument("--audio-start-time", type=float, default=0.0, help="Start time in seconds for audio file (default: 0.0)")
|
||||
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
||||
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
|
||||
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
|
||||
@@ -2457,6 +2557,8 @@ Examples:
|
||||
lora_strength=args.lora_strength,
|
||||
lora_strength_stage_1=args.lora_strength_stage_1,
|
||||
lora_strength_stage_2=args.lora_strength_stage_2,
|
||||
audio_file=args.audio_file,
|
||||
audio_start_time=args.audio_start_time,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Audio VAE module for LTX-2 audio generation."""
|
||||
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, decode_audio
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
@@ -13,10 +14,15 @@ from .vocoder import Vocoder, load_vocoder
|
||||
|
||||
__all__ = [
|
||||
# Main components
|
||||
"AudioEncoder",
|
||||
"AudioDecoder",
|
||||
"Vocoder",
|
||||
"load_vocoder",
|
||||
"decode_audio",
|
||||
# Audio processing
|
||||
"load_audio",
|
||||
"ensure_stereo",
|
||||
"waveform_to_mel",
|
||||
# Ops
|
||||
"AudioLatentShape",
|
||||
"AudioPatchifier",
|
||||
|
||||
135
mlx_video/models/ltx/audio_vae/audio_processor.py
Normal file
135
mlx_video/models/ltx/audio_vae/audio_processor.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
|
||||
|
||||
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
|
||||
using librosa for macOS/MLX compatibility.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def load_audio(
    path: str,
    target_sr: int = 16000,
    start_time: float = 0.0,
    max_duration: float | None = None,
    mono: bool = False,
) -> tuple[np.ndarray, int]:
    """Decode an audio file with librosa and resample it to *target_sr*.

    Args:
        path: Path to the audio file (any format librosa can decode).
        target_sr: Sample rate to resample to (default 16000 Hz).
        start_time: Offset in seconds at which reading starts.
        max_duration: Maximum number of seconds to read; None reads to the end.
        mono: When True, downmix to mono. Default False keeps the channels.

    Returns:
        ``(waveform, sample_rate)``; waveform is a float32 numpy array laid
        out as (channels, samples).
    """
    import librosa

    # mono is passed through explicitly because librosa.load downmixes by default.
    samples, sample_rate = librosa.load(
        path,
        sr=target_sr,
        mono=mono,
        offset=start_time,
        duration=max_duration,
    )

    # A 1-D result is promoted to (1, samples) so callers always get 2-D data.
    channel_major = samples[np.newaxis, :] if samples.ndim == 1 else samples
    return channel_major.astype(np.float32), sample_rate
|
||||
|
||||
|
||||
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
    """Return *waveform* as a two-channel ``(2, samples)`` array.

    A 1-D input is treated as a single channel. A single channel is copied
    into both output channels; an input with more than two channels keeps
    only its first two. Two-channel input is returned as-is.
    """
    stereo = waveform[np.newaxis, :] if waveform.ndim == 1 else waveform
    channel_count = stereo.shape[0]

    if channel_count == 1:
        return np.concatenate([stereo, stereo], axis=0)
    if channel_count > 2:
        return stereo[:2]
    return stereo
|
||||
|
||||
|
||||
def waveform_to_mel(
    waveform: np.ndarray,
    sample_rate: int = 16000,
    n_fft: int = 1024,
    hop_length: int = 160,
    win_length: int = 1024,
    n_mels: int = 64,
    fmin: float = 0.0,
    fmax: float = 8000.0,
) -> mx.array:
    """Convert a waveform to a log-mel spectrogram, one per channel.

    Intended to mirror this PyTorch reference configuration:
        MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
                       f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
                       mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")

    Args:
        waveform: (channels, samples) float32 numpy array; a 1-D array is
            treated as a single channel.
        sample_rate: Sample rate of the waveform.
        n_fft: FFT size.
        hop_length: Hop length.
        win_length: Window length.
        n_mels: Number of mel bins.
        fmin: Minimum frequency for mel filterbank.
        fmax: Maximum frequency for mel filterbank.

    Returns:
        Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
    """
    import librosa

    # Ensure 2D
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]

    # The filterbank depends only on the analysis parameters, not on the
    # audio, so it is built once instead of once per channel.
    mel_basis = librosa.filters.mel(
        sr=sample_rate,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        norm="slaney",
    )

    mels = []
    for channel in waveform:
        # Magnitude spectrogram (power=1.0)
        magnitude = np.abs(librosa.stft(
            channel,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=True,
            pad_mode="reflect",
        ))

        # Floor at 1e-5 before the log so silent bins do not produce -inf.
        log_mel = np.log(np.clip(mel_basis @ magnitude, a_min=1e-5, a_max=None))

        # Transpose: (n_mels, time) -> (time, n_mels)
        mels.append(log_mel.T)

    # Stack channels, then add the batch dim: (1, channels, time, n_mels)
    mel_spec = np.stack(mels, axis=0)[np.newaxis, ...]

    return mx.array(mel_spec, dtype=mx.float32)
|
||||
@@ -6,10 +6,11 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import ResnetBlock
|
||||
@@ -59,6 +60,179 @@ def run_mid_block(mid: dict, features: mx.array) -> mx.array:
|
||||
return mid["block_2"](features, temb=None)
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
    """Encoder that compresses audio spectrograms into latent representations.

    The forward pass (see ``__call__``) runs ``conv_in``, the downsampling
    stages, the mid block, and a norm/SiLU/``conv_out`` head, then normalizes
    the first ``z_channels`` output channels with per-channel statistics.
    """

    def __init__(self, config: AudioEncoderModelConfig) -> None:
        """Build the encoder layers described by ``config``."""
        super().__init__()

        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
        self.sample_rate = config.sample_rate
        self.mel_hop_length = config.mel_hop_length
        self.is_causal = config.is_causal
        self.mel_bins = config.mel_bins

        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=config.sample_rate,
            hop_length=config.mel_hop_length,
            is_causal=config.is_causal,
        )

        self.ch = config.ch
        # No timestep embedding is used; every block is called with temb=None.
        self.temb_ch = 0
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels
        self.z_channels = config.z_channels
        self.double_z = config.double_z
        self.norm_type = config.norm_type
        self.causality_axis = config.causality_axis
        self.attn_type = config.attn_type

        self.conv_in = make_conv2d(
            config.in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

        self.down, block_in = build_downsampling_path(
            ch=config.ch,
            ch_mult=config.ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=config.num_res_blocks,
            resolution=config.resolution,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=config.attn_resolutions or set(),
            resamp_with_conv=config.resamp_with_conv,
        )

        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=config.mid_block_add_attention,
        )

        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        # With double_z the head emits twice z_channels; only the first half
        # is kept later by _normalize_latents.
        out_channels = 2 * config.z_channels if config.double_z else config.z_channels
        self.conv_out = make_conv2d(
            block_in,
            out_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize audio encoder weights from PyTorch format.

        Keeps only encoder weights and the per-channel statistics, renames
        them to this module's parameter names, and moves 4-D conv weights to
        channels-last layout when ``check_array_shape`` reports they are not
        already in it. Every other key is dropped.

        Args:
            weights: Mapping from checkpoint key to array.

        Returns:
            A new mapping containing only the renamed encoder entries.
        """
        encoder_prefixes = ("audio_vae.encoder.", "encoder.")
        sanitized = {}
        for key, value in weights.items():
            new_key = key
            if key.startswith(encoder_prefixes[0]):
                # Slice off the leading prefix only; str.replace would also
                # delete any later occurrence of the same text in the key.
                new_key = key[len(encoder_prefixes[0]):]
            elif key.startswith(encoder_prefixes[1]):
                new_key = key[len(encoder_prefixes[1]):]
            elif key.startswith("audio_vae.per_channel_statistics."):
                if "mean-of-means" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue
            elif "per_channel_statistics" in key:
                if "mean-of-means" in key or "latents_mean" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key or "latents_std" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue
            elif key == "latents_mean":
                new_key = "per_channel_statistics.mean_of_means"
            elif key == "latents_std":
                new_key = "per_channel_statistics.std_of_means"
            else:
                # Not an encoder weight (e.g. decoder/vocoder entries): skip.
                continue

            if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
                value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))

            sanitized[new_key] = value
        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
        """Load audio encoder from pretrained weights.

        Args:
            model_path: Directory containing ``config.json`` and
                ``model.safetensors``.

        Returns:
            An ``AudioEncoder`` with the weights loaded (``strict=True``).
            The weights are loaded as stored; ``sanitize`` is not applied here.
        """
        from mlx_video.models.ltx.config import AudioEncoderModelConfig
        import json

        model_path = Path(model_path)
        # Use a context manager so the config file handle is closed promptly.
        with open(model_path / "config.json", encoding="utf-8") as config_file:
            config = AudioEncoderModelConfig.from_dict(json.load(config_file))
        encoder = cls(config)
        weights = mx.load(str(model_path / "model.safetensors"))
        encoder.load_weights(list(weights.items()), strict=True)
        return encoder

    def __call__(self, spectrogram: mx.array) -> mx.array:
        """Encode audio spectrogram into normalized latent representation.

        Args:
            spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.
                The layout is guessed from ``shape[1] == in_channels``.
                NOTE(review): a channels-last input whose time axis happens to
                equal ``in_channels`` would also be transposed; confirm that
                callers never pass such an input.
        Returns:
            Normalized latent (B, T', F', z_channels) in MLX channels-last format.
        """
        if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
            spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))

        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)
        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: mx.array) -> mx.array:
        """Run the residual blocks, optional attention, and downsamplers of each level."""
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage["block"][block_idx](h, temb=None)
                if block_idx in stage["attn"]:
                    h = stage["attn"][block_idx](h)
            # The last resolution level is never downsampled.
            if level != self.num_resolutions - 1 and "downsample" in stage:
                h = stage["downsample"](h)
        return h

    def _finalize_output(self, h: mx.array) -> mx.array:
        """Apply the output normalization, SiLU, and final convolution."""
        h = self.norm_out(h)
        h = nn.silu(h)
        return self.conv_out(h)

    def _normalize_latents(self, h: mx.array) -> mx.array:
        """Normalize encoder output using per-channel statistics.

        Takes first half of channels (mean) when double_z=True,
        then patchifies, normalizes, and unpatchifies.
        """
        # h shape: (B, T', F', 2*z_channels) in MLX format
        z_channels = self.z_channels
        means = h[..., :z_channels]

        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[3],
            frames=means.shape[1],
            mel_bins=means.shape[2],
        )

        patched = self.patchifier.patchify(means)
        normalized = self.per_channel_statistics.normalize(patched)
        return self.patchifier.unpatchify(normalized, latent_shape)
|
||||
|
||||
|
||||
class AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Tuple, Set
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
|
||||
class LTXModelType(Enum):
|
||||
@@ -252,6 +252,47 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
@dataclass
class AudioEncoderModelConfig(BaseModelConfig):
    """Configuration for the audio encoder.

    ``norm_type``, ``causality_axis`` and ``attn_type`` may be supplied as
    strings; ``__post_init__`` converts them to their enum types.
    """

    ch: int = 128
    in_channels: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    double_z: bool = True
    n_fft: int = 1024
    norm_type: Enum = None
    causality_axis: Enum = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int = 64
    resamp_with_conv: bool = True
    attn_type: str = None

    def to_dict(self) -> dict[str, Any]:
        """Return the base dict, with ``attn_resolutions`` stored as a list when set."""
        data = super().to_dict()
        resolutions = self.attn_resolutions
        if resolutions is None:
            return data
        data["attn_resolutions"] = list(resolutions)
        return data

    def __post_init__(self):
        """Replace string-valued enum fields with the matching enum members."""
        from .audio_vae.normalization import NormType
        from .audio_vae.attention import AttentionType

        # (field name, enum class) pairs, processed in this order.
        coercions = (
            ("causality_axis", CausalityAxis),
            ("norm_type", NormType),
            ("attn_type", AttentionType),
        )
        for field_name, enum_type in coercions:
            raw = getattr(self, field_name)
            if isinstance(raw, str):
                setattr(self, field_name, enum_type(raw))
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocoderModelConfig(BaseModelConfig):
|
||||
resblock_kernel_sizes: Optional[List[int]] = None
|
||||
|
||||
Reference in New Issue
Block a user