Add audio to video conditioning
This commit is contained in:
@@ -454,6 +454,7 @@ def denoise_distilled(
|
||||
audio_latents: Optional[mx.array] = None,
|
||||
audio_positions: Optional[mx.array] = None,
|
||||
audio_embeddings: Optional[mx.array] = None,
|
||||
audio_frozen: bool = False,
|
||||
) -> tuple[mx.array, Optional[mx.array]]:
|
||||
"""Run denoising loop for distilled pipeline (no CFG)."""
|
||||
dtype = latents.dtype
|
||||
@@ -513,14 +514,17 @@ def denoise_distilled(
|
||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
||||
|
||||
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
|
||||
a_ts = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
||||
a_sig = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
||||
audio_modality = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=mx.full((ab, at), sigma, dtype=dtype),
|
||||
timesteps=a_ts,
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
sigma=mx.full((ab,), sigma, dtype=dtype),
|
||||
sigma=a_sig,
|
||||
)
|
||||
|
||||
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
|
||||
@@ -529,9 +533,6 @@ def denoise_distilled(
|
||||
mx.eval(audio_velocity)
|
||||
|
||||
# Compute denoised (x0) using per-token timesteps in float32
|
||||
# x0 = latent - timestep * velocity
|
||||
# For conditioned tokens (timestep=0): x0 = latent
|
||||
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
|
||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
||||
@@ -539,7 +540,7 @@ def denoise_distilled(
|
||||
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
|
||||
|
||||
audio_denoised = None
|
||||
if enable_audio and audio_velocity is not None:
|
||||
if enable_audio and audio_velocity is not None and not audio_frozen:
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
||||
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
|
||||
@@ -552,15 +553,15 @@ def denoise_distilled(
|
||||
if audio_denoised is not None:
|
||||
mx.eval(audio_denoised)
|
||||
|
||||
# Euler step in float32 (latents stay in float32)
|
||||
# Euler step in float32
|
||||
if sigma_next > 0:
|
||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
|
||||
if enable_audio and audio_denoised is not None:
|
||||
if enable_audio and audio_denoised is not None and not audio_frozen:
|
||||
audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
|
||||
else:
|
||||
latents = denoised
|
||||
if enable_audio and audio_denoised is not None:
|
||||
if enable_audio and audio_denoised is not None and not audio_frozen:
|
||||
audio_latents = audio_denoised
|
||||
|
||||
mx.eval(latents)
|
||||
@@ -785,6 +786,7 @@ def denoise_dev_av(
|
||||
stg_video_blocks: Optional[list] = None,
|
||||
stg_audio_blocks: Optional[list] = None,
|
||||
modality_scale: float = 1.0,
|
||||
audio_frozen: bool = False,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio.
|
||||
|
||||
@@ -879,11 +881,12 @@ def denoise_dev_av(
|
||||
else:
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
|
||||
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
# Positive conditioning pass
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
||||
video_modality_pos = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_pos, context_mask=None, enabled=True,
|
||||
@@ -1001,11 +1004,13 @@ def denoise_dev_av(
|
||||
video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32
|
||||
video_latents = video_latents + video_velocity_f32 * dt_f32
|
||||
|
||||
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
||||
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
||||
if not audio_frozen:
|
||||
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
||||
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
||||
else:
|
||||
video_latents = video_denoised_f32
|
||||
audio_latents = audio_denoised_f32
|
||||
if not audio_frozen:
|
||||
audio_latents = audio_denoised_f32
|
||||
|
||||
mx.eval(video_latents, audio_latents)
|
||||
progress.advance(task)
|
||||
@@ -1037,6 +1042,7 @@ def denoise_res2s_av(
|
||||
noise_seed: int = 42,
|
||||
bongmath: bool = True,
|
||||
bongmath_max_iter: int = 100,
|
||||
audio_frozen: bool = False,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
|
||||
|
||||
@@ -1125,10 +1131,10 @@ def denoise_res2s_av(
|
||||
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
||||
|
||||
# Pass 1: Positive conditioning
|
||||
video_modality_pos = Modality(
|
||||
@@ -1270,18 +1276,23 @@ def denoise_res2s_av(
|
||||
|
||||
# Compute midpoint
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
|
||||
x_mid_video = x_anchor_video + h * a21 * eps_1_video
|
||||
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
|
||||
|
||||
if not audio_frozen:
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
|
||||
else:
|
||||
eps_1_audio = None
|
||||
x_mid_audio = audio_latents # frozen: pass through unchanged
|
||||
|
||||
# SDE noise injection at substep
|
||||
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
|
||||
substep_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
substep_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
|
||||
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
|
||||
if not audio_frozen:
|
||||
substep_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
|
||||
mx.eval(x_mid_video, x_mid_audio)
|
||||
|
||||
# ============================================================
|
||||
@@ -1291,9 +1302,13 @@ def denoise_res2s_av(
|
||||
for _ in range(bongmath_max_iter):
|
||||
x_anchor_video = x_mid_video - h * a21 * eps_1_video
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
|
||||
if not audio_frozen:
|
||||
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
if audio_frozen:
|
||||
mx.eval(x_anchor_video, eps_1_video)
|
||||
else:
|
||||
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
|
||||
|
||||
# ============================================================
|
||||
# Stage 2: Evaluate denoiser at midpoint sigma
|
||||
@@ -1306,21 +1321,21 @@ def denoise_res2s_av(
|
||||
# Final combination with RK coefficients
|
||||
# ============================================================
|
||||
eps_2_video = denoised_video_2 - x_anchor_video
|
||||
eps_2_audio = denoised_audio_2 - x_anchor_audio
|
||||
|
||||
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
|
||||
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
|
||||
|
||||
# SDE noise injection at step level
|
||||
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
|
||||
step_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
step_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
|
||||
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
|
||||
|
||||
video_latents = x_next_video.astype(mx.float32)
|
||||
audio_latents = x_next_audio.astype(mx.float32)
|
||||
if not audio_frozen:
|
||||
eps_2_audio = denoised_audio_2 - x_anchor_audio
|
||||
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
|
||||
step_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
|
||||
audio_latents = x_next_audio.astype(mx.float32)
|
||||
|
||||
mx.eval(video_latents, audio_latents)
|
||||
progress.advance(task)
|
||||
|
||||
@@ -1330,7 +1345,8 @@ def denoise_res2s_av(
|
||||
video_latents, audio_latents, sigmas_list[n_full_steps]
|
||||
)
|
||||
video_latents = denoised_video
|
||||
audio_latents = denoised_audio
|
||||
if not audio_frozen:
|
||||
audio_latents = denoised_audio
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
return video_latents, audio_latents
|
||||
@@ -1443,6 +1459,8 @@ def generate_video(
|
||||
lora_strength: float = 1.0,
|
||||
lora_strength_stage_1: Optional[float] = None,
|
||||
lora_strength_stage_2: Optional[float] = None,
|
||||
audio_file: Optional[str] = None,
|
||||
audio_start_time: float = 0.0,
|
||||
):
|
||||
"""Generate video using LTX-2 models.
|
||||
|
||||
@@ -1496,8 +1514,16 @@ def generate_video(
|
||||
num_frames = adjusted_num_frames
|
||||
|
||||
is_i2v = image is not None
|
||||
is_a2v = audio_file is not None
|
||||
if is_a2v and audio:
|
||||
raise ValueError("Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one.")
|
||||
# A2V implicitly enables audio path through the transformer
|
||||
if is_a2v:
|
||||
audio = True
|
||||
mode_str = "I2V" if is_i2v else "T2V"
|
||||
if audio:
|
||||
if is_a2v:
|
||||
mode_str = "A2V" + ("+I2V" if is_i2v else "")
|
||||
elif audio:
|
||||
mode_str += "+Audio"
|
||||
|
||||
pipeline_names = {
|
||||
@@ -1599,6 +1625,62 @@ def generate_video(
|
||||
stg_blocks = [29]
|
||||
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
|
||||
|
||||
# ==========================================================================
|
||||
# A2V: Encode input audio to frozen latents
|
||||
# ==========================================================================
|
||||
a2v_audio_latents = None
|
||||
a2v_waveform = None
|
||||
a2v_sr = None
|
||||
if is_a2v:
|
||||
from mlx_video.models.ltx.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from mlx_video.convert import convert_audio_encoder
|
||||
from mlx_video.models.ltx.audio_vae import AudioEncoder
|
||||
|
||||
with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"):
|
||||
video_duration = num_frames / fps
|
||||
|
||||
# Load audio
|
||||
waveform, sr = load_audio(
|
||||
audio_file,
|
||||
target_sr=AUDIO_LATENT_SAMPLE_RATE,
|
||||
start_time=audio_start_time,
|
||||
max_duration=video_duration,
|
||||
)
|
||||
waveform = ensure_stereo(waveform)
|
||||
a2v_waveform = waveform.copy()
|
||||
a2v_sr = sr
|
||||
|
||||
# Compute mel-spectrogram
|
||||
mel = waveform_to_mel(waveform, sample_rate=sr, n_fft=1024, hop_length=AUDIO_HOP_LENGTH, n_mels=64)
|
||||
|
||||
# Convert audio encoder weights if needed, then load
|
||||
encoder_dir = convert_audio_encoder(model_path, source_repo="Lightricks/LTX-2")
|
||||
audio_encoder = AudioEncoder.from_pretrained(encoder_dir)
|
||||
mx.eval(audio_encoder.parameters())
|
||||
|
||||
# Encode: (1, 2, time, 64) -> normalized latents
|
||||
encoded = audio_encoder(mel)
|
||||
mx.eval(encoded)
|
||||
|
||||
# encoded is in MLX format (B, T', mel_bins', z_channels) = (1, T', 16, 8)
|
||||
# Convert to PyTorch-style format for consistency: (B, C, T, mel_bins)
|
||||
a2v_audio_latents = mx.transpose(encoded, (0, 3, 1, 2)).astype(model_dtype)
|
||||
|
||||
# Trim/pad to match expected audio_frames
|
||||
t_encoded = a2v_audio_latents.shape[2]
|
||||
if t_encoded > audio_frames:
|
||||
a2v_audio_latents = a2v_audio_latents[:, :, :audio_frames, :]
|
||||
elif t_encoded < audio_frames:
|
||||
pad_size = audio_frames - t_encoded
|
||||
padding = mx.zeros((1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
a2v_audio_latents = mx.concatenate([a2v_audio_latents, padding], axis=2)
|
||||
mx.eval(a2v_audio_latents)
|
||||
|
||||
del audio_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
console.print(f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})")
|
||||
|
||||
# ==========================================================================
|
||||
# Pipeline-specific generation logic
|
||||
# ==========================================================================
|
||||
@@ -1636,9 +1718,9 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
# Init audio latents/positions: use encoded A2V latents or random
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning
|
||||
@@ -1671,6 +1753,7 @@ def generate_video(
|
||||
latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS,
|
||||
verbose=verbose, state=state1,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
# Upsample latents
|
||||
@@ -1723,7 +1806,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
||||
if audio_latents is not None:
|
||||
if audio_latents is not None and not is_a2v:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -1735,6 +1818,7 @@ def generate_video(
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
||||
audio_embeddings=audio_embeddings,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
elif pipeline == PipelineType.DEV:
|
||||
@@ -1770,7 +1854,7 @@ def generate_video(
|
||||
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Initialize latents with optional I2V conditioning
|
||||
@@ -1811,6 +1895,7 @@ def generate_video(
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||
@@ -1858,7 +1943,7 @@ def generate_video(
|
||||
|
||||
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 1
|
||||
@@ -1899,6 +1984,7 @@ def generate_video(
|
||||
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
mx.eval(audio_latents)
|
||||
@@ -1969,7 +2055,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
||||
if audio_latents is not None:
|
||||
if audio_latents is not None and not is_a2v:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -1981,6 +2067,7 @@ def generate_video(
|
||||
verbose=verbose, state=state2,
|
||||
audio_latents=audio_latents, audio_positions=audio_positions,
|
||||
audio_embeddings=audio_embeddings_pos,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
|
||||
@@ -2045,7 +2132,7 @@ def generate_video(
|
||||
mx.eval(positions)
|
||||
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 1
|
||||
@@ -2087,6 +2174,7 @@ def generate_video(
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
noise_seed=seed,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
mx.eval(audio_latents)
|
||||
@@ -2148,7 +2236,7 @@ def generate_video(
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement
|
||||
if audio_latents is not None:
|
||||
if audio_latents is not None and not is_a2v:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
@@ -2165,6 +2253,7 @@ def generate_video(
|
||||
audio_cfg_scale=1.0,
|
||||
cfg_rescale=0.0, verbose=verbose, video_state=state2,
|
||||
noise_seed=seed + 1,
|
||||
audio_frozen=is_a2v,
|
||||
)
|
||||
|
||||
del transformer
|
||||
@@ -2279,29 +2368,38 @@ def generate_video(
|
||||
|
||||
# Decode and save audio if enabled
|
||||
audio_np = None
|
||||
vocoder_sample_rate = AUDIO_SAMPLE_RATE
|
||||
if audio and audio_latents is not None:
|
||||
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
|
||||
audio_decoder = load_audio_decoder(model_path, pipeline)
|
||||
vocoder = load_vocoder_model(model_path, pipeline)
|
||||
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||
if is_a2v and a2v_waveform is not None:
|
||||
# A2V: use original input audio waveform (no VAE decoding needed)
|
||||
audio_np = a2v_waveform
|
||||
if audio_np.ndim == 1:
|
||||
audio_np = audio_np[np.newaxis, :]
|
||||
vocoder_sample_rate = a2v_sr or AUDIO_LATENT_SAMPLE_RATE
|
||||
console.print("[green]✓[/] Using original input audio (A2V)")
|
||||
else:
|
||||
with console.status("[blue]Decoding audio...[/]", spinner="dots"):
|
||||
audio_decoder = load_audio_decoder(model_path, pipeline)
|
||||
vocoder = load_vocoder_model(model_path, pipeline)
|
||||
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||
|
||||
mel_spectrogram = audio_decoder(audio_latents)
|
||||
mx.eval(mel_spectrogram)
|
||||
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
|
||||
mel_spectrogram = audio_decoder(audio_latents)
|
||||
mx.eval(mel_spectrogram)
|
||||
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
|
||||
|
||||
audio_waveform = vocoder(mel_spectrogram)
|
||||
mx.eval(audio_waveform)
|
||||
audio_waveform = vocoder(mel_spectrogram)
|
||||
mx.eval(audio_waveform)
|
||||
|
||||
audio_np = np.array(audio_waveform.astype(mx.float32))
|
||||
if audio_np.ndim == 3:
|
||||
audio_np = audio_np[0]
|
||||
audio_np = np.array(audio_waveform.astype(mx.float32))
|
||||
if audio_np.ndim == 3:
|
||||
audio_np = audio_np[0]
|
||||
|
||||
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
|
||||
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
|
||||
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
|
||||
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
|
||||
|
||||
del audio_decoder, vocoder
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] Audio decoded")
|
||||
del audio_decoder, vocoder
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] Audio decoded")
|
||||
|
||||
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
|
||||
save_audio(audio_np, audio_path, vocoder_sample_rate)
|
||||
@@ -2398,6 +2496,8 @@ Examples:
|
||||
help="Tiling mode for VAE decoding")
|
||||
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
||||
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
||||
parser.add_argument("--audio-file", type=str, default=None, help="Path to audio file for A2V (audio-to-video) conditioning")
|
||||
parser.add_argument("--audio-start-time", type=float, default=0.0, help="Start time in seconds for audio file (default: 0.0)")
|
||||
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
||||
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
|
||||
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
|
||||
@@ -2457,6 +2557,8 @@ Examples:
|
||||
lora_strength=args.lora_strength,
|
||||
lora_strength_stage_1=args.lora_strength_stage_1,
|
||||
lora_strength_stage_2=args.lora_strength_stage_2,
|
||||
audio_file=args.audio_file,
|
||||
audio_start_time=args.audio_start_time,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user