Add audio to video conditioning

This commit is contained in:
Prince Canuma
2026-03-16 01:42:11 +01:00
parent f53b9e0807
commit 6f6105b715
7 changed files with 623 additions and 62 deletions

View File

@@ -454,6 +454,7 @@ def denoise_distilled(
audio_latents: Optional[mx.array] = None,
audio_positions: Optional[mx.array] = None,
audio_embeddings: Optional[mx.array] = None,
audio_frozen: bool = False,
) -> tuple[mx.array, Optional[mx.array]]:
"""Run denoising loop for distilled pipeline (no CFG)."""
dtype = latents.dtype
@@ -513,14 +514,17 @@ def denoise_distilled(
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
a_ts = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
a_sig = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
audio_modality = Modality(
latent=audio_flat,
timesteps=mx.full((ab, at), sigma, dtype=dtype),
timesteps=a_ts,
positions=audio_positions,
context=audio_embeddings,
context_mask=None,
enabled=True,
sigma=mx.full((ab,), sigma, dtype=dtype),
sigma=a_sig,
)
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
@@ -529,9 +533,6 @@ def denoise_distilled(
mx.eval(audio_velocity)
# Compute denoised (x0) using per-token timesteps in float32
# x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
@@ -539,7 +540,7 @@ def denoise_distilled(
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
audio_denoised = None
if enable_audio and audio_velocity is not None:
if enable_audio and audio_velocity is not None and not audio_frozen:
ab, ac, at, af = audio_latents.shape
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
@@ -552,15 +553,15 @@ def denoise_distilled(
if audio_denoised is not None:
mx.eval(audio_denoised)
# Euler step in float32 (latents stay in float32)
# Euler step in float32
if sigma_next > 0:
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
if enable_audio and audio_denoised is not None:
if enable_audio and audio_denoised is not None and not audio_frozen:
audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
else:
latents = denoised
if enable_audio and audio_denoised is not None:
if enable_audio and audio_denoised is not None and not audio_frozen:
audio_latents = audio_denoised
mx.eval(latents)
@@ -785,6 +786,7 @@ def denoise_dev_av(
stg_video_blocks: Optional[list] = None,
stg_audio_blocks: Optional[list] = None,
modality_scale: float = 1.0,
audio_frozen: bool = False,
) -> tuple[mx.array, mx.array]:
"""Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio.
@@ -879,11 +881,12 @@ def denoise_dev_av(
else:
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
# Positive conditioning pass
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
video_modality_pos = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_pos, context_mask=None, enabled=True,
@@ -1001,11 +1004,13 @@ def denoise_dev_av(
video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32
video_latents = video_latents + video_velocity_f32 * dt_f32
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
if not audio_frozen:
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
else:
video_latents = video_denoised_f32
audio_latents = audio_denoised_f32
if not audio_frozen:
audio_latents = audio_denoised_f32
mx.eval(video_latents, audio_latents)
progress.advance(task)
@@ -1037,6 +1042,7 @@ def denoise_res2s_av(
noise_seed: int = 42,
bongmath: bool = True,
bongmath_max_iter: int = 100,
audio_frozen: bool = False,
) -> tuple[mx.array, mx.array]:
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
@@ -1125,10 +1131,10 @@ def denoise_res2s_av(
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
# Pass 1: Positive conditioning
video_modality_pos = Modality(
@@ -1270,18 +1276,23 @@ def denoise_res2s_av(
# Compute midpoint
eps_1_video = denoised_video_1 - x_anchor_video
eps_1_audio = denoised_audio_1 - x_anchor_audio
x_mid_video = x_anchor_video + h * a21 * eps_1_video
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
if not audio_frozen:
eps_1_audio = denoised_audio_1 - x_anchor_audio
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
else:
eps_1_audio = None
x_mid_audio = audio_latents # frozen: pass through unchanged
# SDE noise injection at substep
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
substep_noise_v = get_new_noise(video_latents.shape, key1)
substep_noise_a = get_new_noise(audio_latents.shape, key2)
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
if not audio_frozen:
substep_noise_a = get_new_noise(audio_latents.shape, key2)
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
mx.eval(x_mid_video, x_mid_audio)
# ============================================================
@@ -1291,9 +1302,13 @@ def denoise_res2s_av(
for _ in range(bongmath_max_iter):
x_anchor_video = x_mid_video - h * a21 * eps_1_video
eps_1_video = denoised_video_1 - x_anchor_video
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
eps_1_audio = denoised_audio_1 - x_anchor_audio
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
if not audio_frozen:
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
eps_1_audio = denoised_audio_1 - x_anchor_audio
if audio_frozen:
mx.eval(x_anchor_video, eps_1_video)
else:
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
# ============================================================
# Stage 2: Evaluate denoiser at midpoint sigma
@@ -1306,21 +1321,21 @@ def denoise_res2s_av(
# Final combination with RK coefficients
# ============================================================
eps_2_video = denoised_video_2 - x_anchor_video
eps_2_audio = denoised_audio_2 - x_anchor_audio
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
# SDE noise injection at step level
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
step_noise_v = get_new_noise(video_latents.shape, key1)
step_noise_a = get_new_noise(audio_latents.shape, key2)
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
video_latents = x_next_video.astype(mx.float32)
audio_latents = x_next_audio.astype(mx.float32)
if not audio_frozen:
eps_2_audio = denoised_audio_2 - x_anchor_audio
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
step_noise_a = get_new_noise(audio_latents.shape, key2)
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
audio_latents = x_next_audio.astype(mx.float32)
mx.eval(video_latents, audio_latents)
progress.advance(task)
@@ -1330,7 +1345,8 @@ def denoise_res2s_av(
video_latents, audio_latents, sigmas_list[n_full_steps]
)
video_latents = denoised_video
audio_latents = denoised_audio
if not audio_frozen:
audio_latents = denoised_audio
mx.eval(video_latents, audio_latents)
return video_latents, audio_latents
@@ -1443,6 +1459,8 @@ def generate_video(
lora_strength: float = 1.0,
lora_strength_stage_1: Optional[float] = None,
lora_strength_stage_2: Optional[float] = None,
audio_file: Optional[str] = None,
audio_start_time: float = 0.0,
):
"""Generate video using LTX-2 models.
@@ -1496,8 +1514,16 @@ def generate_video(
num_frames = adjusted_num_frames
is_i2v = image is not None
is_a2v = audio_file is not None
if is_a2v and audio:
raise ValueError("Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one.")
# A2V implicitly enables audio path through the transformer
if is_a2v:
audio = True
mode_str = "I2V" if is_i2v else "T2V"
if audio:
if is_a2v:
mode_str = "A2V" + ("+I2V" if is_i2v else "")
elif audio:
mode_str += "+Audio"
pipeline_names = {
@@ -1599,6 +1625,62 @@ def generate_video(
stg_blocks = [29]
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
# ==========================================================================
# A2V: Encode input audio to frozen latents
# ==========================================================================
a2v_audio_latents = None
a2v_waveform = None
a2v_sr = None
if is_a2v:
from mlx_video.models.ltx.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel
from mlx_video.convert import convert_audio_encoder
from mlx_video.models.ltx.audio_vae import AudioEncoder
with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"):
video_duration = num_frames / fps
# Load audio
waveform, sr = load_audio(
audio_file,
target_sr=AUDIO_LATENT_SAMPLE_RATE,
start_time=audio_start_time,
max_duration=video_duration,
)
waveform = ensure_stereo(waveform)
a2v_waveform = waveform.copy()
a2v_sr = sr
# Compute mel-spectrogram
mel = waveform_to_mel(waveform, sample_rate=sr, n_fft=1024, hop_length=AUDIO_HOP_LENGTH, n_mels=64)
# Convert audio encoder weights if needed, then load
encoder_dir = convert_audio_encoder(model_path, source_repo="Lightricks/LTX-2")
audio_encoder = AudioEncoder.from_pretrained(encoder_dir)
mx.eval(audio_encoder.parameters())
# Encode: (1, 2, time, 64) -> normalized latents
encoded = audio_encoder(mel)
mx.eval(encoded)
# encoded is in MLX format (B, T', mel_bins', z_channels) = (1, T', 16, 8)
# Convert to PyTorch-style format for consistency: (B, C, T, mel_bins)
a2v_audio_latents = mx.transpose(encoded, (0, 3, 1, 2)).astype(model_dtype)
# Trim/pad to match expected audio_frames
t_encoded = a2v_audio_latents.shape[2]
if t_encoded > audio_frames:
a2v_audio_latents = a2v_audio_latents[:, :, :audio_frames, :]
elif t_encoded < audio_frames:
pad_size = audio_frames - t_encoded
padding = mx.zeros((1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), dtype=model_dtype)
a2v_audio_latents = mx.concatenate([a2v_audio_latents, padding], axis=2)
mx.eval(a2v_audio_latents)
del audio_encoder
mx.clear_cache()
console.print(f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})")
# ==========================================================================
# Pipeline-specific generation logic
# ==========================================================================
@@ -1636,9 +1718,9 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
# Always init audio latents/positions - PyTorch unconditionally generates audio
# Init audio latents/positions: use encoded A2V latents or random
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning
@@ -1671,6 +1753,7 @@ def generate_video(
latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS,
verbose=verbose, state=state1,
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
audio_frozen=is_a2v,
)
# Upsample latents
@@ -1723,7 +1806,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio_latents is not None:
if audio_latents is not None and not is_a2v:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1735,6 +1818,7 @@ def generate_video(
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings,
audio_frozen=is_a2v,
)
elif pipeline == PipelineType.DEV:
@@ -1770,7 +1854,7 @@ def generate_video(
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Initialize latents with optional I2V conditioning
@@ -1811,6 +1895,7 @@ def generate_video(
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
audio_frozen=is_a2v,
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
@@ -1858,7 +1943,7 @@ def generate_video(
# Always init audio latents/positions - PyTorch unconditionally generates audio
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
@@ -1899,6 +1984,7 @@ def generate_video(
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
audio_frozen=is_a2v,
)
mx.eval(audio_latents)
@@ -1969,7 +2055,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
if audio_latents is not None:
if audio_latents is not None and not is_a2v:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -1981,6 +2067,7 @@ def generate_video(
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions,
audio_embeddings=audio_embeddings_pos,
audio_frozen=is_a2v,
)
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
@@ -2045,7 +2132,7 @@ def generate_video(
mx.eval(positions)
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
@@ -2087,6 +2174,7 @@ def generate_video(
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
noise_seed=seed,
audio_frozen=is_a2v,
)
mx.eval(audio_latents)
@@ -2148,7 +2236,7 @@ def generate_video(
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement
if audio_latents is not None:
if audio_latents is not None and not is_a2v:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
@@ -2165,6 +2253,7 @@ def generate_video(
audio_cfg_scale=1.0,
cfg_rescale=0.0, verbose=verbose, video_state=state2,
noise_seed=seed + 1,
audio_frozen=is_a2v,
)
del transformer
@@ -2279,29 +2368,38 @@ def generate_video(
# Decode and save audio if enabled
audio_np = None
vocoder_sample_rate = AUDIO_SAMPLE_RATE
if audio and audio_latents is not None:
with console.status("[blue]🔊 Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path, pipeline)
vocoder = load_vocoder_model(model_path, pipeline)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
if is_a2v and a2v_waveform is not None:
# A2V: use original input audio waveform (no VAE decoding needed)
audio_np = a2v_waveform
if audio_np.ndim == 1:
audio_np = audio_np[np.newaxis, :]
vocoder_sample_rate = a2v_sr or AUDIO_LATENT_SAMPLE_RATE
console.print("[green]✓[/] Using original input audio (A2V)")
else:
with console.status("[blue]Decoding audio...[/]", spinner="dots"):
audio_decoder = load_audio_decoder(model_path, pipeline)
vocoder = load_vocoder_model(model_path, pipeline)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
audio_np = np.array(audio_waveform.astype(mx.float32))
if audio_np.ndim == 3:
audio_np = audio_np[0]
audio_np = np.array(audio_waveform.astype(mx.float32))
if audio_np.ndim == 3:
audio_np = audio_np[0]
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
del audio_decoder, vocoder
mx.clear_cache()
console.print("[green]✓[/] Audio decoded")
del audio_decoder, vocoder
mx.clear_cache()
console.print("[green]✓[/] Audio decoded")
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
save_audio(audio_np, audio_path, vocoder_sample_rate)
@@ -2398,6 +2496,8 @@ Examples:
help="Tiling mode for VAE decoding")
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
parser.add_argument("--audio-file", type=str, default=None, help="Path to audio file for A2V (audio-to-video) conditioning")
parser.add_argument("--audio-start-time", type=float, default=0.0, help="Start time in seconds for audio file (default: 0.0)")
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
@@ -2457,6 +2557,8 @@ Examples:
lora_strength=args.lora_strength,
lora_strength_stage_1=args.lora_strength_stage_1,
lora_strength_stage_2=args.lora_strength_stage_2,
audio_file=args.audio_file,
audio_start_time=args.audio_start_time,
)