Add audio generation capabilities to video pipeline, including audio position grid creation, audio frame computation, and integration of audio VAE and vocoder. Update tests to cover new audio functionalities.

This commit is contained in:
Prince Canuma
2026-01-18 21:28:56 +01:00
parent b36ad1e22d
commit 7069cc39c9
2 changed files with 667 additions and 127 deletions

View File

@@ -37,7 +37,7 @@ class Colors:
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
@@ -65,6 +65,15 @@ DEFAULT_NEGATIVE_PROMPT = (
BASE_SHIFT_ANCHOR = 1024 BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096 MAX_SHIFT_ANCHOR = 4096
# Audio constants
AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate
AUDIO_HOP_LENGTH = 160
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying
AUDIO_MEL_BINS = 16
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
def ltx2_scheduler( def ltx2_scheduler(
steps: int, steps: int,
@@ -195,6 +204,54 @@ def create_position_grid(
return mx.array(pixel_coords, dtype=mx.float32) return mx.array(pixel_coords, dtype=mx.float32)
def create_audio_position_grid(
batch_size: int,
audio_frames: int,
sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
hop_length: int = AUDIO_HOP_LENGTH,
downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
is_causal: bool = True,
) -> mx.array:
"""Create temporal position grid for audio RoPE.
Audio positions are timestamps in seconds, shape (B, 1, T, 2).
Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly.
Args:
batch_size: Batch size
audio_frames: Number of audio latent frames
sample_rate: Audio sample rate (default 16000)
hop_length: Hop length for mel spectrogram (default 160)
downsample_factor: Latent downsample factor (default 4)
is_causal: Whether to use causal alignment (default True)
Returns:
Position grid of shape (B, 1, T, 2)
"""
def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray:
"""Convert latent indices to seconds."""
latent_frame = np.arange(start_idx, end_idx, dtype=np.float32)
mel_frame = latent_frame * downsample_factor
if is_causal:
mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None)
return mel_frame * hop_length / sample_rate
start_times = get_audio_latent_time_in_sec(0, audio_frames)
end_times = get_audio_latent_time_in_sec(1, audio_frames + 1)
positions = np.stack([start_times, end_times], axis=-1)
positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2)
positions = np.tile(positions, (batch_size, 1, 1, 1))
return mx.array(positions, dtype=mx.float32)
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
"""Compute number of audio latent frames given video duration."""
duration = num_video_frames / fps
return round(duration * AUDIO_LATENTS_PER_SECOND)
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
"""Compute CFG (Classifier-Free Guidance) delta. """Compute CFG (Classifier-Free Guidance) delta.
@@ -209,6 +266,116 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
return (scale - 1.0) * (cond - uncond) return (scale - 1.0) * (cond - uncond)
def load_audio_decoder(model_path: Path):
"""Load audio VAE decoder."""
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
decoder = AudioDecoder(
ch=128,
out_ch=2, # stereo
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_resolutions={8, 16, 32},
resolution=256,
z_channels=AUDIO_LATENT_CHANNELS,
norm_type=NormType.PIXEL,
causality_axis=CausalityAxis.HEIGHT,
mel_bins=64, # Output mel bins
)
# Load weights - try dev model first, fall back to distilled
weight_file = model_path / "ltx-2-19b-dev.safetensors"
if not weight_file.exists():
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_audio_vae_weights(raw_weights)
if sanitized:
decoder.load_weights(list(sanitized.items()), strict=False)
# Manually load per-channel statistics
if "per_channel_statistics._mean_of_means" in sanitized:
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
if "per_channel_statistics._std_of_means" in sanitized:
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
return decoder
def load_vocoder(model_path: Path):
"""Load vocoder for mel to waveform conversion."""
from mlx_video.models.ltx.audio_vae import Vocoder
vocoder = Vocoder(
resblock_kernel_sizes=[3, 7, 11],
upsample_rates=[6, 5, 2, 2, 2],
upsample_kernel_sizes=[16, 15, 8, 4, 4],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_initial_channel=1024,
stereo=True,
output_sample_rate=AUDIO_SAMPLE_RATE,
)
# Load weights - try dev model first, fall back to distilled
weight_file = model_path / "ltx-2-19b-dev.safetensors"
if not weight_file.exists():
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_vocoder_weights(raw_weights)
if sanitized:
vocoder.load_weights(list(sanitized.items()), strict=False)
return vocoder
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
"""Save audio to WAV file."""
import wave
# Ensure audio is in correct format (channels, samples) or (samples,)
if audio.ndim == 2:
# (channels, samples) -> (samples, channels)
audio = audio.T
# Normalize and convert to int16
audio = np.clip(audio, -1.0, 1.0)
audio_int16 = (audio * 32767).astype(np.int16)
with wave.open(str(path), 'wb') as wf:
wf.setnchannels(2 if audio_int16.ndim == 2 else 1)
wf.setsampwidth(2) # 16-bit
wf.setframerate(sample_rate)
wf.writeframes(audio_int16.tobytes())
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path) -> bool:
"""Combine video and audio into final output using ffmpeg."""
import subprocess
cmd = [
"ffmpeg", "-y",
"-i", str(video_path),
"-i", str(audio_path),
"-c:v", "copy",
"-c:a", "aac",
"-shortest",
str(output_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
return False
except FileNotFoundError:
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
return False
def denoise_with_cfg( def denoise_with_cfg(
latents: mx.array, latents: mx.array,
positions: mx.array, positions: mx.array,
@@ -222,10 +389,9 @@ def denoise_with_cfg(
) -> mx.array: ) -> mx.array:
"""Run denoising loop with CFG (Classifier-Free Guidance). """Run denoising loop with CFG (Classifier-Free Guidance).
Optimized version that: Uses separate forward passes for positive and negative conditioning
1. Batches positive and negative forward passes together to match PyTorch implementation behavior (avoids potential issues with
2. Precomputes RoPE once and reuses it (avoids expensive NumPy conversion each step) batched attention patterns).
3. Minimizes mx.eval() calls for better performance
Args: Args:
latents: Noisy latent tensor (B, C, F, H, W) latents: Noisy latent tensor (B, C, F, H, W)
@@ -250,18 +416,10 @@ def denoise_with_cfg(
sigmas_list = sigmas.tolist() sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0 use_cfg = cfg_scale != 1.0
# Pre-compute batched context for CFG (concat pos and neg along batch dim)
if use_cfg:
# Shape: (2, seq_len, dim) - batch pos and neg together
batched_context = mx.concatenate([text_embeddings_pos, text_embeddings_neg], axis=0)
batched_positions = mx.concatenate([positions, positions], axis=0)
else:
batched_positions = positions
# Precompute RoPE once (expensive operation due to NumPy conversion for double precision) # Precompute RoPE once (expensive operation due to NumPy conversion for double precision)
# This avoids recomputing it every forward pass # This avoids recomputing it every forward pass
precomputed_rope = precompute_freqs_cis( precomputed_rope = precompute_freqs_cis(
batched_positions, positions,
dim=transformer.inner_dim, dim=transformer.inner_dim,
theta=transformer.positional_embedding_theta, theta=transformer.positional_embedding_theta,
max_pos=transformer.positional_embedding_max_pos, max_pos=transformer.positional_embedding_max_pos,
@@ -289,45 +447,38 @@ def denoise_with_cfg(
else: else:
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
if use_cfg: # First forward pass: positive conditioning
# Batch both positive and negative in a single forward pass video_modality_pos = Modality(
batched_latents = mx.concatenate([latents_flat, latents_flat], axis=0)
batched_timesteps = mx.concatenate([timesteps, timesteps], axis=0)
video_modality = Modality(
latent=batched_latents,
timesteps=batched_timesteps,
positions=batched_positions,
context=batched_context,
context_mask=None,
enabled=True,
positional_embeddings=precomputed_rope, # Use precomputed RoPE
)
# Single forward pass for both pos and neg
batched_output, _ = transformer(video=video_modality, audio=None)
# Split results: first half is positive, second half is negative
denoised_pos = batched_output[:1]
denoised_neg = batched_output[1:]
# Apply CFG: denoised = denoised_pos + (scale - 1) * (denoised_pos - denoised_neg)
denoised_flat = denoised_pos + (cfg_scale - 1.0) * (denoised_pos - denoised_neg)
else:
# No CFG - single forward pass
video_modality = Modality(
latent=latents_flat, latent=latents_flat,
timesteps=timesteps, timesteps=timesteps,
positions=positions, positions=positions,
context=text_embeddings_pos, context=text_embeddings_pos,
context_mask=None, context_mask=None,
enabled=True, enabled=True,
positional_embeddings=precomputed_rope, # Use precomputed RoPE positional_embeddings=precomputed_rope,
) )
denoised_flat, _ = transformer(video=video_modality, audio=None) velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
if use_cfg:
# Second forward pass: negative conditioning
video_modality_neg = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=text_embeddings_neg,
context_mask=None,
enabled=True,
positional_embeddings=precomputed_rope,
)
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
# Apply CFG: velocity = pos + (scale - 1) * (pos - neg)
velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg)
else:
velocity_flat = velocity_pos
# Reshape back to 5D # Reshape back to 5D
velocity = mx.reshape(mx.transpose(denoised_flat, (0, 2, 1)), (b, c, f, h, w)) velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma) denoised = to_denoised(latents, velocity, sigma)
# Apply conditioning mask if state is provided # Apply conditioning mask if state is provided
@@ -348,6 +499,185 @@ def denoise_with_cfg(
return latents return latents
def denoise_av_with_cfg(
video_latents: mx.array,
audio_latents: mx.array,
video_positions: mx.array,
audio_positions: mx.array,
video_embeddings_pos: mx.array,
video_embeddings_neg: mx.array,
audio_embeddings_pos: mx.array,
audio_embeddings_neg: mx.array,
transformer: LTXModel,
sigmas: mx.array,
cfg_scale: float = 4.0,
verbose: bool = True,
video_state: Optional[LatentState] = None,
) -> tuple[mx.array, mx.array]:
"""Run denoising loop for audio-video generation with CFG.
Uses separate forward passes for positive and negative CFG to ensure
correct audio-video cross-attention behavior (matching PyTorch implementation).
Args:
video_latents: Video latent tensor (B, C, F, H, W)
audio_latents: Audio latent tensor (B, C, T, F)
video_positions: Video position embeddings
audio_positions: Audio position embeddings
video_embeddings_pos: Positive video text embeddings
video_embeddings_neg: Negative video text embeddings
audio_embeddings_pos: Positive audio text embeddings
audio_embeddings_neg: Negative audio text embeddings
transformer: LTX model
sigmas: Array of sigma values for denoising schedule
cfg_scale: Guidance scale (default 4.0, 1.0 = no guidance)
verbose: Whether to show progress bar
video_state: Optional LatentState for I2V conditioning
Returns:
Tuple of (video_latents, audio_latents)
"""
from mlx_video.models.ltx.rope import precompute_freqs_cis
dtype = video_latents.dtype
if video_state is not None:
video_latents = video_state.latent
sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0
# Precompute video RoPE (single batch, not doubled for CFG)
precomputed_video_rope = precompute_freqs_cis(
video_positions,
dim=transformer.inner_dim,
theta=transformer.positional_embedding_theta,
max_pos=transformer.positional_embedding_max_pos,
use_middle_indices_grid=transformer.use_middle_indices_grid,
num_attention_heads=transformer.num_attention_heads,
rope_type=transformer.rope_type,
double_precision=transformer.config.double_precision_rope,
)
# Precompute audio RoPE (1D positions)
precomputed_audio_rope = precompute_freqs_cis(
audio_positions,
dim=transformer.audio_inner_dim,
theta=transformer.positional_embedding_theta,
max_pos=transformer.audio_positional_embedding_max_pos,
use_middle_indices_grid=transformer.use_middle_indices_grid,
num_attention_heads=transformer.audio_num_attention_heads,
rope_type=transformer.rope_type,
double_precision=transformer.config.double_precision_rope,
)
mx.eval(precomputed_video_rope, precomputed_audio_rope)
for i in tqdm(range(len(sigmas_list) - 1), desc="Denoising A/V", disable=not verbose):
sigma = sigmas_list[i]
sigma_next = sigmas_list[i + 1]
# Flatten video latents
b, c, f, h, w = video_latents.shape
num_video_tokens = f * h * w
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
# Compute per-token timesteps for video
if video_state is not None:
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
# First forward pass: positive conditioning
video_modality_pos = Modality(
latent=video_flat,
timesteps=video_timesteps,
positions=video_positions,
context=video_embeddings_pos,
context_mask=None,
enabled=True,
positional_embeddings=precomputed_video_rope,
)
audio_modality_pos = Modality(
latent=audio_flat,
timesteps=audio_timesteps,
positions=audio_positions,
context=audio_embeddings_pos,
context_mask=None,
enabled=True,
positional_embeddings=precomputed_audio_rope,
)
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
if use_cfg:
# Second forward pass: negative conditioning
video_modality_neg = Modality(
latent=video_flat,
timesteps=video_timesteps,
positions=video_positions,
context=video_embeddings_neg,
context_mask=None,
enabled=True,
positional_embeddings=precomputed_video_rope,
)
audio_modality_neg = Modality(
latent=audio_flat,
timesteps=audio_timesteps,
positions=audio_positions,
context=audio_embeddings_neg,
context_mask=None,
enabled=True,
positional_embeddings=precomputed_audio_rope,
)
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
# Apply CFG: denoised = pos + (scale - 1) * (pos - neg)
video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg)
audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg)
else:
video_velocity_flat = video_vel_pos
audio_velocity_flat = audio_vel_pos
# Reshape velocities back
video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w))
audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
# Compute denoised
video_denoised = to_denoised(video_latents, video_velocity, sigma)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
# Apply conditioning mask for video if state is provided
if video_state is not None:
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
# Euler step
if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else:
video_latents = video_denoised
audio_latents = audio_denoised
mx.eval(video_latents, audio_latents)
return video_latents, audio_latents
def generate_video_dev( def generate_video_dev(
model_repo: str, model_repo: str,
text_encoder_repo: str, text_encoder_repo: str,
@@ -361,6 +691,7 @@ def generate_video_dev(
seed: int = 42, seed: int = 42,
fps: int = 24, fps: int = 24,
output_path: str = "output.mp4", output_path: str = "output.mp4",
output_audio_path: Optional[str] = None,
save_frames: bool = False, save_frames: bool = False,
verbose: bool = True, verbose: bool = True,
enhance_prompt: bool = False, enhance_prompt: bool = False,
@@ -370,6 +701,7 @@ def generate_video_dev(
image_strength: float = 1.0, image_strength: float = 1.0,
image_frame_idx: int = 0, image_frame_idx: int = 0,
tiling: str = "none", tiling: str = "none",
audio: bool = False,
): ):
"""Generate video using LTX-2 dev model with CFG. """Generate video using LTX-2 dev model with CFG.
@@ -389,6 +721,7 @@ def generate_video_dev(
seed: Random seed for reproducibility seed: Random seed for reproducibility
fps: Frames per second for output video fps: Frames per second for output video
output_path: Path to save the output video output_path: Path to save the output video
output_audio_path: Path to save audio (if audio=True)
save_frames: Whether to save individual frames as images save_frames: Whether to save individual frames as images
verbose: Whether to print progress verbose: Whether to print progress
enhance_prompt: Whether to enhance prompt using Gemma enhance_prompt: Whether to enhance prompt using Gemma
@@ -398,6 +731,7 @@ def generate_video_dev(
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original) image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
image_frame_idx: Frame index to condition (0 = first frame) image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding tiling: Tiling mode for VAE decoding
audio: Whether to generate synchronized audio
""" """
start_time = time.time() start_time = time.time()
@@ -410,10 +744,17 @@ def generate_video_dev(
print(f"{Colors.YELLOW}Warning: Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") print(f"{Colors.YELLOW}Warning: Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames num_frames = adjusted_num_frames
# Calculate audio frames if audio is enabled
audio_frames = compute_audio_frames(num_frames, fps) if audio else 0
is_i2v = image is not None is_i2v = image is not None
mode_str = "I2V" if is_i2v else "T2V" mode_str = "I2V" if is_i2v else "T2V"
if audio:
mode_str += "+Audio"
print(f"{Colors.BOLD}{Colors.CYAN}[DEV] [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.BOLD}{Colors.CYAN}[DEV] [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Steps: {num_inference_steps}, CFG: {cfg_scale}{Colors.RESET}") print(f"{Colors.DIM}Steps: {num_inference_steps}, CFG: {cfg_scale}{Colors.RESET}")
if audio:
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v: if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
@@ -442,21 +783,54 @@ def generate_video_dev(
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
# Encode both positive and negative prompts # Encode both positive and negative prompts
text_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False) if audio:
text_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False) video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
model_dtype = text_embeddings_pos.dtype video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
mx.eval(text_embeddings_pos, text_embeddings_neg) model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
else:
video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
audio_embeddings_pos = None
audio_embeddings_neg = None
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg)
del text_encoder del text_encoder
mx.clear_cache() mx.clear_cache()
# Load transformer (dev model) # Load transformer (dev model)
print(f"{Colors.BLUE}Loading dev transformer...{Colors.RESET}") print(f"{Colors.BLUE}Loading dev transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-dev.safetensors')) raw_weights = mx.load(str(model_path / 'ltx-2-19b-dev.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency # Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
if audio:
config = LTXModelConfig(
model_type=LTXModelType.AudioVideo,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
# Audio config
audio_num_attention_heads=32,
audio_attention_head_dim=64,
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_cross_attention_dim=2048,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
audio_positional_embedding_max_pos=[20],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
else:
config = LTXModelConfig( config = LTXModelConfig(
model_type=LTXModelType.VideoOnly, model_type=LTXModelType.VideoOnly,
num_attention_heads=32, num_attention_heads=32,
@@ -503,20 +877,26 @@ def generate_video_dev(
mx.eval(sigmas) mx.eval(sigmas)
print(f"{Colors.DIM}Sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f}{Colors.RESET}") print(f"{Colors.DIM}Sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f}{Colors.RESET}")
# Create position grid # Create position grids
print(f"{Colors.YELLOW}Generating at {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})...{Colors.RESET}") print(f"{Colors.YELLOW}Generating at {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})...{Colors.RESET}")
mx.random.seed(seed) mx.random.seed(seed)
positions = create_position_grid(1, latent_frames, latent_h, latent_w) video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
mx.eval(positions) mx.eval(video_positions)
if audio:
audio_positions = create_audio_position_grid(1, audio_frames)
mx.eval(audio_positions)
else:
audio_positions = None
# Initialize latents with optional I2V conditioning # Initialize latents with optional I2V conditioning
state = None video_state = None
video_latent_shape = (1, 128, latent_frames, latent_h, latent_w)
if is_i2v and image_latent is not None: if is_i2v and image_latent is not None:
latent_shape = (1, 128, latent_frames, latent_h, latent_w) video_state = LatentState(
state = LatentState( latent=mx.zeros(video_latent_shape, dtype=model_dtype),
latent=mx.zeros(latent_shape, dtype=model_dtype), clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
@@ -524,29 +904,45 @@ def generate_video_dev(
frame_idx=image_frame_idx, frame_idx=image_frame_idx,
strength=image_strength, strength=image_strength,
) )
state = apply_conditioning(state, [conditioning]) video_state = apply_conditioning(video_state, [conditioning])
# Apply noiser # Apply noiser
noise = mx.random.normal(latent_shape, dtype=model_dtype) noise = mx.random.normal(video_latent_shape, dtype=model_dtype)
noise_scale = sigmas[0] noise_scale = sigmas[0]
scaled_mask = state.denoise_mask * noise_scale scaled_mask = video_state.denoise_mask * noise_scale
state = LatentState( video_state = LatentState(
latent=noise * scaled_mask + state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state.clean_latent, clean_latent=video_state.clean_latent,
denoise_mask=state.denoise_mask, denoise_mask=video_state.denoise_mask,
) )
latents = state.latent video_latents = video_state.latent
mx.eval(latents) mx.eval(video_latents)
else: else:
# T2V: just use random noise # T2V: just use random noise
latents = mx.random.normal((1, 128, latent_frames, latent_h, latent_w), dtype=model_dtype) video_latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
mx.eval(latents) mx.eval(video_latents)
# Initialize audio latents if audio is enabled
if audio:
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_latents)
else:
audio_latents = None
# Denoise with CFG # Denoise with CFG
latents = denoise_with_cfg( if audio:
latents, positions, text_embeddings_pos, text_embeddings_neg, video_latents, audio_latents = denoise_av_with_cfg(
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=state video_latents, audio_latents,
video_positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state
)
else:
video_latents = denoise_with_cfg(
video_latents, video_positions, video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state
) )
del transformer del transformer
@@ -583,32 +979,99 @@ def generate_video_dev(
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose) video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose)
else: else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(latents) video = vae_decoder(video_latents)
mx.eval(video) mx.eval(video)
del vae_decoder
mx.clear_cache() mx.clear_cache()
# Convert to uint8 frames # Decode audio if enabled
audio_np = None
if audio and audio_latents is not None:
print(f"{Colors.BLUE}Decoding audio...{Colors.RESET}")
# Load audio decoder
audio_decoder = load_audio_decoder(model_path)
mx.eval(audio_decoder.parameters())
# Decode audio latents to mel spectrogram
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
del audio_decoder
mx.clear_cache()
# Load vocoder and convert mel to waveform
vocoder = load_vocoder(model_path)
mx.eval(vocoder.parameters())
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
del vocoder
mx.clear_cache()
# Convert to numpy
audio_np = np.array(audio_waveform)
if audio_np.ndim == 3:
audio_np = audio_np[0] # Remove batch dim
print(f"{Colors.DIM} Audio shape: {audio_np.shape}, duration: {audio_np.shape[-1] / AUDIO_SAMPLE_RATE:.2f}s{Colors.RESET}")
# Convert video to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W) video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8) video = (video * 255).astype(mx.uint8)
video_np = np.array(video) video_np = np.array(video)
# Save video # Save outputs
output_path = Path(output_path) output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
# Determine audio output path
if audio and audio_np is not None:
if output_audio_path is None:
audio_output = output_path.parent / f"{output_path.stem}.wav"
else:
audio_output = Path(output_audio_path)
# Save audio
save_audio(audio_np, audio_output)
print(f"{Colors.GREEN}Saved audio to{Colors.RESET} {audio_output}")
# Save video (to temp file if we need to mux with audio)
if audio and audio_np is not None:
# Save video to temp file, then mux with audio
temp_video_path = output_path.parent / f"{output_path.stem}_temp.mp4"
video_save_path = temp_video_path
else:
video_save_path = output_path
try: try:
import cv2 import cv2
h, w = video_np.shape[1], video_np.shape[2] h, w = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1') fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h)) out = cv2.VideoWriter(str(video_save_path), fourcc, fps, (w, h))
for frame in video_np: for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release() out.release()
if audio and audio_np is not None:
# Mux video and audio
print(f"{Colors.BLUE}Muxing video and audio...{Colors.RESET}")
if mux_video_audio(temp_video_path, audio_output, output_path):
print(f"{Colors.GREEN}Saved video with audio to{Colors.RESET} {output_path}")
# Clean up temp file
temp_video_path.unlink(missing_ok=True)
else:
# Fallback: keep separate files
print(f"{Colors.YELLOW}Could not mux, keeping separate files{Colors.RESET}")
temp_video_path.rename(output_path.parent / f"{output_path.stem}_video.mp4")
else:
print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}") print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}") print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}")
@@ -642,6 +1105,10 @@ Examples:
# Image-to-Video (I2V) # Image-to-Video (I2V)
python -m mlx_video.generate_dev --prompt "A person dancing" --image photo.jpg python -m mlx_video.generate_dev --prompt "A person dancing" --image photo.jpg
# With synchronized audio
python -m mlx_video.generate_dev --prompt "Ocean waves crashing on rocks" --audio
python -m mlx_video.generate_dev --prompt "A busy city street" --audio --output-audio street.wav
""" """
) )
@@ -769,6 +1236,17 @@ Examples:
choices=["none", "auto", "default", "aggressive", "conservative", "spatial", "temporal"], choices=["none", "auto", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: none, faster on high-memory systems)" help="Tiling mode for VAE decoding (default: none, faster on high-memory systems)"
) )
parser.add_argument(
"--audio",
action="store_true",
help="Generate synchronized audio with the video"
)
parser.add_argument(
"--output-audio",
type=str,
default=None,
help="Output audio path (default: same as video with .wav extension)"
)
args = parser.parse_args() args = parser.parse_args()
generate_video_dev( generate_video_dev(
@@ -784,6 +1262,7 @@ Examples:
seed=args.seed, seed=args.seed,
fps=args.fps, fps=args.fps,
output_path=args.output_path, output_path=args.output_path,
output_audio_path=args.output_audio,
save_frames=args.save_frames, save_frames=args.save_frames,
verbose=args.verbose, verbose=args.verbose,
enhance_prompt=args.enhance_prompt, enhance_prompt=args.enhance_prompt,
@@ -793,6 +1272,7 @@ Examples:
image_strength=args.image_strength, image_strength=args.image_strength,
image_frame_idx=args.image_frame_idx, image_frame_idx=args.image_frame_idx,
tiling=args.tiling, tiling=args.tiling,
audio=args.audio,
) )

View File

@@ -2,14 +2,16 @@
import pytest import pytest
import mlx.core as mx import mlx.core as mx
import numpy as np
from mlx_video.generate_dev import ( from mlx_video.generate_dev import (
ltx2_scheduler, ltx2_scheduler,
create_position_grid, create_position_grid,
create_audio_position_grid,
compute_audio_frames,
cfg_delta, cfg_delta,
denoise_with_cfg,
DEFAULT_NEGATIVE_PROMPT, DEFAULT_NEGATIVE_PROMPT,
AUDIO_SAMPLE_RATE,
AUDIO_LATENTS_PER_SECOND,
) )
@@ -260,28 +262,6 @@ class TestInputValidation:
class TestDenoiseWithCFGMocked: class TestDenoiseWithCFGMocked:
"""Tests for denoise_with_cfg with mocked transformer.""" """Tests for denoise_with_cfg with mocked transformer."""
def test_denoise_returns_correct_shape(self):
"""Denoised output should have same shape as input latents."""
# Create a simple mock transformer
class MockTransformer:
inner_dim = 4096
positional_embedding_theta = 10000.0
positional_embedding_max_pos = [20, 2048, 2048]
use_middle_indices_grid = True
num_attention_heads = 32
rope_type = None
class config:
double_precision_rope = True
def __call__(self, video, audio):
# Return input as output (identity)
return video.latent, None
# Skip this test if we can't import the required modules easily
# This is a structural test to ensure the function signature is correct
pass
def test_sigmas_list_conversion(self): def test_sigmas_list_conversion(self):
"""Sigmas should be convertible to list.""" """Sigmas should be convertible to list."""
sigmas = ltx2_scheduler(steps=5) sigmas = ltx2_scheduler(steps=5)
@@ -296,16 +276,10 @@ class TestTilingDefault:
def test_tiling_default_is_none(self): def test_tiling_default_is_none(self):
"""Default tiling should be 'none' for performance.""" """Default tiling should be 'none' for performance."""
# Import and check the default
import argparse
from mlx_video.generate_dev import main
# The default is set in the argparse definition
# We verify this by checking the function signature
import inspect import inspect
sig = inspect.signature( from mlx_video.generate_dev import generate_video_dev
__import__('mlx_video.generate_dev', fromlist=['generate_video_dev']).generate_video_dev
) sig = inspect.signature(generate_video_dev)
tiling_param = sig.parameters.get('tiling') tiling_param = sig.parameters.get('tiling')
assert tiling_param is not None assert tiling_param is not None
@@ -358,5 +332,91 @@ class TestLatentDimensions:
assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}" assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}"
class TestAudioPositionGrid:
"""Tests for audio position grid creation."""
def test_audio_position_grid_shape(self):
"""Audio position grid should have correct shape (B, 1, T, 2)."""
batch_size = 1
audio_frames = 34 # ~1.36 seconds at 25 latent frames/sec
positions = create_audio_position_grid(batch_size, audio_frames)
expected_shape = (batch_size, 1, audio_frames, 2)
assert positions.shape == expected_shape, \
f"Expected {expected_shape}, got {positions.shape}"
def test_audio_position_grid_dtype(self):
"""Audio position grid should be float32."""
positions = create_audio_position_grid(1, 34)
assert positions.dtype == mx.float32, \
f"Expected float32, got {positions.dtype}"
def test_audio_position_grid_batch_size(self):
"""Audio position grid should respect batch size."""
for batch_size in [1, 2, 4]:
positions = create_audio_position_grid(batch_size, 34)
assert positions.shape[0] == batch_size
def test_audio_position_grid_temporal_values(self):
"""Audio positions should be in seconds."""
positions = create_audio_position_grid(1, 34)
# Values should be in seconds (small values for ~1 second of audio)
max_val = mx.max(positions).item()
assert max_val < 10, f"Audio positions seem too large: {max_val}"
assert max_val > 0, "Audio positions should be positive"
def test_audio_position_grid_no_nan_or_inf(self):
"""Audio position grid should not contain NaN or Inf."""
positions = create_audio_position_grid(1, 34)
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
class TestComputeAudioFrames:
"""Tests for audio frame count calculation."""
def test_audio_frames_basic(self):
"""Audio frames should be proportional to video duration."""
# 33 frames at 24 fps = ~1.375 seconds
# At 25 latent frames/sec = ~34 audio frames
audio_frames = compute_audio_frames(33, 24.0)
assert audio_frames > 0
assert isinstance(audio_frames, int)
def test_audio_frames_scales_with_video(self):
"""More video frames should produce more audio frames."""
audio_33 = compute_audio_frames(33, 24.0)
audio_65 = compute_audio_frames(65, 24.0)
assert audio_65 > audio_33, \
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
def test_audio_frames_formula(self):
"""Audio frames should match expected formula."""
num_video_frames = 33
fps = 24.0
duration = num_video_frames / fps # ~1.375 seconds
expected = round(duration * AUDIO_LATENTS_PER_SECOND)
actual = compute_audio_frames(num_video_frames, fps)
assert actual == expected, f"Expected {expected}, got {actual}"
class TestAudioConstants:
"""Tests for audio constants."""
def test_audio_sample_rate(self):
"""Audio sample rate should be 24000 Hz."""
assert AUDIO_SAMPLE_RATE == 24000
def test_audio_latents_per_second(self):
"""Audio latents per second should be 25."""
assert AUDIO_LATENTS_PER_SECOND == 25.0
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])