Add audio generation capabilities to video pipeline, including audio position grid creation, audio frame computation, and integration of audio VAE and vocoder. Update tests to cover new audio functionalities.
This commit is contained in:
@@ -37,7 +37,7 @@ class Colors:
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
from mlx_video.convert import sanitize_transformer_weights
|
||||
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
|
||||
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||
@@ -65,6 +65,15 @@ DEFAULT_NEGATIVE_PROMPT = (
|
||||
BASE_SHIFT_ANCHOR = 1024
|
||||
MAX_SHIFT_ANCHOR = 4096
|
||||
|
||||
# Audio constants
|
||||
AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate
|
||||
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate
|
||||
AUDIO_HOP_LENGTH = 160
|
||||
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying
|
||||
AUDIO_MEL_BINS = 16
|
||||
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
|
||||
|
||||
|
||||
def ltx2_scheduler(
|
||||
steps: int,
|
||||
@@ -195,6 +204,54 @@ def create_position_grid(
|
||||
return mx.array(pixel_coords, dtype=mx.float32)
|
||||
|
||||
|
||||
def create_audio_position_grid(
|
||||
batch_size: int,
|
||||
audio_frames: int,
|
||||
sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
|
||||
hop_length: int = AUDIO_HOP_LENGTH,
|
||||
downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
|
||||
is_causal: bool = True,
|
||||
) -> mx.array:
|
||||
"""Create temporal position grid for audio RoPE.
|
||||
|
||||
Audio positions are timestamps in seconds, shape (B, 1, T, 2).
|
||||
Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly.
|
||||
|
||||
Args:
|
||||
batch_size: Batch size
|
||||
audio_frames: Number of audio latent frames
|
||||
sample_rate: Audio sample rate (default 16000)
|
||||
hop_length: Hop length for mel spectrogram (default 160)
|
||||
downsample_factor: Latent downsample factor (default 4)
|
||||
is_causal: Whether to use causal alignment (default True)
|
||||
|
||||
Returns:
|
||||
Position grid of shape (B, 1, T, 2)
|
||||
"""
|
||||
def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray:
|
||||
"""Convert latent indices to seconds."""
|
||||
latent_frame = np.arange(start_idx, end_idx, dtype=np.float32)
|
||||
mel_frame = latent_frame * downsample_factor
|
||||
if is_causal:
|
||||
mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None)
|
||||
return mel_frame * hop_length / sample_rate
|
||||
|
||||
start_times = get_audio_latent_time_in_sec(0, audio_frames)
|
||||
end_times = get_audio_latent_time_in_sec(1, audio_frames + 1)
|
||||
|
||||
positions = np.stack([start_times, end_times], axis=-1)
|
||||
positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2)
|
||||
positions = np.tile(positions, (batch_size, 1, 1, 1))
|
||||
|
||||
return mx.array(positions, dtype=mx.float32)
|
||||
|
||||
|
||||
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
|
||||
"""Compute number of audio latent frames given video duration."""
|
||||
duration = num_video_frames / fps
|
||||
return round(duration * AUDIO_LATENTS_PER_SECOND)
|
||||
|
||||
|
||||
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
||||
"""Compute CFG (Classifier-Free Guidance) delta.
|
||||
|
||||
@@ -209,6 +266,116 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
||||
return (scale - 1.0) * (cond - uncond)
|
||||
|
||||
|
||||
def load_audio_decoder(model_path: Path):
|
||||
"""Load audio VAE decoder."""
|
||||
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
|
||||
|
||||
decoder = AudioDecoder(
|
||||
ch=128,
|
||||
out_ch=2, # stereo
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_resolutions={8, 16, 32},
|
||||
resolution=256,
|
||||
z_channels=AUDIO_LATENT_CHANNELS,
|
||||
norm_type=NormType.PIXEL,
|
||||
causality_axis=CausalityAxis.HEIGHT,
|
||||
mel_bins=64, # Output mel bins
|
||||
)
|
||||
|
||||
# Load weights - try dev model first, fall back to distilled
|
||||
weight_file = model_path / "ltx-2-19b-dev.safetensors"
|
||||
if not weight_file.exists():
|
||||
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
|
||||
if weight_file.exists():
|
||||
raw_weights = mx.load(str(weight_file))
|
||||
sanitized = sanitize_audio_vae_weights(raw_weights)
|
||||
if sanitized:
|
||||
decoder.load_weights(list(sanitized.items()), strict=False)
|
||||
|
||||
# Manually load per-channel statistics
|
||||
if "per_channel_statistics._mean_of_means" in sanitized:
|
||||
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
|
||||
if "per_channel_statistics._std_of_means" in sanitized:
|
||||
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
|
||||
|
||||
return decoder
|
||||
|
||||
|
||||
def load_vocoder(model_path: Path):
|
||||
"""Load vocoder for mel to waveform conversion."""
|
||||
from mlx_video.models.ltx.audio_vae import Vocoder
|
||||
|
||||
vocoder = Vocoder(
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
upsample_rates=[6, 5, 2, 2, 2],
|
||||
upsample_kernel_sizes=[16, 15, 8, 4, 4],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_initial_channel=1024,
|
||||
stereo=True,
|
||||
output_sample_rate=AUDIO_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
# Load weights - try dev model first, fall back to distilled
|
||||
weight_file = model_path / "ltx-2-19b-dev.safetensors"
|
||||
if not weight_file.exists():
|
||||
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
|
||||
if weight_file.exists():
|
||||
raw_weights = mx.load(str(weight_file))
|
||||
sanitized = sanitize_vocoder_weights(raw_weights)
|
||||
if sanitized:
|
||||
vocoder.load_weights(list(sanitized.items()), strict=False)
|
||||
|
||||
return vocoder
|
||||
|
||||
|
||||
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
|
||||
"""Save audio to WAV file."""
|
||||
import wave
|
||||
|
||||
# Ensure audio is in correct format (channels, samples) or (samples,)
|
||||
if audio.ndim == 2:
|
||||
# (channels, samples) -> (samples, channels)
|
||||
audio = audio.T
|
||||
|
||||
# Normalize and convert to int16
|
||||
audio = np.clip(audio, -1.0, 1.0)
|
||||
audio_int16 = (audio * 32767).astype(np.int16)
|
||||
|
||||
with wave.open(str(path), 'wb') as wf:
|
||||
wf.setnchannels(2 if audio_int16.ndim == 2 else 1)
|
||||
wf.setsampwidth(2) # 16-bit
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio_int16.tobytes())
|
||||
|
||||
|
||||
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path) -> bool:
|
||||
"""Combine video and audio into final output using ffmpeg."""
|
||||
import subprocess
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", str(video_path),
|
||||
"-i", str(audio_path),
|
||||
"-c:v", "copy",
|
||||
"-c:a", "aac",
|
||||
"-shortest",
|
||||
str(output_path)
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
|
||||
return False
|
||||
except FileNotFoundError:
|
||||
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
|
||||
return False
|
||||
|
||||
|
||||
def denoise_with_cfg(
|
||||
latents: mx.array,
|
||||
positions: mx.array,
|
||||
@@ -222,10 +389,9 @@ def denoise_with_cfg(
|
||||
) -> mx.array:
|
||||
"""Run denoising loop with CFG (Classifier-Free Guidance).
|
||||
|
||||
Optimized version that:
|
||||
1. Batches positive and negative forward passes together
|
||||
2. Precomputes RoPE once and reuses it (avoids expensive NumPy conversion each step)
|
||||
3. Minimizes mx.eval() calls for better performance
|
||||
Uses separate forward passes for positive and negative conditioning
|
||||
to match PyTorch implementation behavior (avoids potential issues with
|
||||
batched attention patterns).
|
||||
|
||||
Args:
|
||||
latents: Noisy latent tensor (B, C, F, H, W)
|
||||
@@ -250,18 +416,10 @@ def denoise_with_cfg(
|
||||
sigmas_list = sigmas.tolist()
|
||||
use_cfg = cfg_scale != 1.0
|
||||
|
||||
# Pre-compute batched context for CFG (concat pos and neg along batch dim)
|
||||
if use_cfg:
|
||||
# Shape: (2, seq_len, dim) - batch pos and neg together
|
||||
batched_context = mx.concatenate([text_embeddings_pos, text_embeddings_neg], axis=0)
|
||||
batched_positions = mx.concatenate([positions, positions], axis=0)
|
||||
else:
|
||||
batched_positions = positions
|
||||
|
||||
# Precompute RoPE once (expensive operation due to NumPy conversion for double precision)
|
||||
# This avoids recomputing it every forward pass
|
||||
precomputed_rope = precompute_freqs_cis(
|
||||
batched_positions,
|
||||
positions,
|
||||
dim=transformer.inner_dim,
|
||||
theta=transformer.positional_embedding_theta,
|
||||
max_pos=transformer.positional_embedding_max_pos,
|
||||
@@ -289,45 +447,38 @@ def denoise_with_cfg(
|
||||
else:
|
||||
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
||||
|
||||
# First forward pass: positive conditioning
|
||||
video_modality_pos = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
positions=positions,
|
||||
context=text_embeddings_pos,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_rope,
|
||||
)
|
||||
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
|
||||
|
||||
if use_cfg:
|
||||
# Batch both positive and negative in a single forward pass
|
||||
batched_latents = mx.concatenate([latents_flat, latents_flat], axis=0)
|
||||
batched_timesteps = mx.concatenate([timesteps, timesteps], axis=0)
|
||||
|
||||
video_modality = Modality(
|
||||
latent=batched_latents,
|
||||
timesteps=batched_timesteps,
|
||||
positions=batched_positions,
|
||||
context=batched_context,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_rope, # Use precomputed RoPE
|
||||
)
|
||||
|
||||
# Single forward pass for both pos and neg
|
||||
batched_output, _ = transformer(video=video_modality, audio=None)
|
||||
|
||||
# Split results: first half is positive, second half is negative
|
||||
denoised_pos = batched_output[:1]
|
||||
denoised_neg = batched_output[1:]
|
||||
|
||||
# Apply CFG: denoised = denoised_pos + (scale - 1) * (denoised_pos - denoised_neg)
|
||||
denoised_flat = denoised_pos + (cfg_scale - 1.0) * (denoised_pos - denoised_neg)
|
||||
else:
|
||||
# No CFG - single forward pass
|
||||
video_modality = Modality(
|
||||
# Second forward pass: negative conditioning
|
||||
video_modality_neg = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
positions=positions,
|
||||
context=text_embeddings_pos,
|
||||
context=text_embeddings_neg,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_rope, # Use precomputed RoPE
|
||||
positional_embeddings=precomputed_rope,
|
||||
)
|
||||
denoised_flat, _ = transformer(video=video_modality, audio=None)
|
||||
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
|
||||
|
||||
# Apply CFG: velocity = pos + (scale - 1) * (pos - neg)
|
||||
velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg)
|
||||
else:
|
||||
velocity_flat = velocity_pos
|
||||
|
||||
# Reshape back to 5D
|
||||
velocity = mx.reshape(mx.transpose(denoised_flat, (0, 2, 1)), (b, c, f, h, w))
|
||||
velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w))
|
||||
denoised = to_denoised(latents, velocity, sigma)
|
||||
|
||||
# Apply conditioning mask if state is provided
|
||||
@@ -348,6 +499,185 @@ def denoise_with_cfg(
|
||||
return latents
|
||||
|
||||
|
||||
def denoise_av_with_cfg(
|
||||
video_latents: mx.array,
|
||||
audio_latents: mx.array,
|
||||
video_positions: mx.array,
|
||||
audio_positions: mx.array,
|
||||
video_embeddings_pos: mx.array,
|
||||
video_embeddings_neg: mx.array,
|
||||
audio_embeddings_pos: mx.array,
|
||||
audio_embeddings_neg: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigmas: mx.array,
|
||||
cfg_scale: float = 4.0,
|
||||
verbose: bool = True,
|
||||
video_state: Optional[LatentState] = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run denoising loop for audio-video generation with CFG.
|
||||
|
||||
Uses separate forward passes for positive and negative CFG to ensure
|
||||
correct audio-video cross-attention behavior (matching PyTorch implementation).
|
||||
|
||||
Args:
|
||||
video_latents: Video latent tensor (B, C, F, H, W)
|
||||
audio_latents: Audio latent tensor (B, C, T, F)
|
||||
video_positions: Video position embeddings
|
||||
audio_positions: Audio position embeddings
|
||||
video_embeddings_pos: Positive video text embeddings
|
||||
video_embeddings_neg: Negative video text embeddings
|
||||
audio_embeddings_pos: Positive audio text embeddings
|
||||
audio_embeddings_neg: Negative audio text embeddings
|
||||
transformer: LTX model
|
||||
sigmas: Array of sigma values for denoising schedule
|
||||
cfg_scale: Guidance scale (default 4.0, 1.0 = no guidance)
|
||||
verbose: Whether to show progress bar
|
||||
video_state: Optional LatentState for I2V conditioning
|
||||
|
||||
Returns:
|
||||
Tuple of (video_latents, audio_latents)
|
||||
"""
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
|
||||
dtype = video_latents.dtype
|
||||
if video_state is not None:
|
||||
video_latents = video_state.latent
|
||||
|
||||
sigmas_list = sigmas.tolist()
|
||||
use_cfg = cfg_scale != 1.0
|
||||
|
||||
# Precompute video RoPE (single batch, not doubled for CFG)
|
||||
precomputed_video_rope = precompute_freqs_cis(
|
||||
video_positions,
|
||||
dim=transformer.inner_dim,
|
||||
theta=transformer.positional_embedding_theta,
|
||||
max_pos=transformer.positional_embedding_max_pos,
|
||||
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
||||
num_attention_heads=transformer.num_attention_heads,
|
||||
rope_type=transformer.rope_type,
|
||||
double_precision=transformer.config.double_precision_rope,
|
||||
)
|
||||
|
||||
# Precompute audio RoPE (1D positions)
|
||||
precomputed_audio_rope = precompute_freqs_cis(
|
||||
audio_positions,
|
||||
dim=transformer.audio_inner_dim,
|
||||
theta=transformer.positional_embedding_theta,
|
||||
max_pos=transformer.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
||||
num_attention_heads=transformer.audio_num_attention_heads,
|
||||
rope_type=transformer.rope_type,
|
||||
double_precision=transformer.config.double_precision_rope,
|
||||
)
|
||||
mx.eval(precomputed_video_rope, precomputed_audio_rope)
|
||||
|
||||
for i in tqdm(range(len(sigmas_list) - 1), desc="Denoising A/V", disable=not verbose):
|
||||
sigma = sigmas_list[i]
|
||||
sigma_next = sigmas_list[i + 1]
|
||||
|
||||
# Flatten video latents
|
||||
b, c, f, h, w = video_latents.shape
|
||||
num_video_tokens = f * h * w
|
||||
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
||||
|
||||
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
|
||||
ab, ac, at, af = audio_latents.shape
|
||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
||||
|
||||
# Compute per-token timesteps for video
|
||||
if video_state is not None:
|
||||
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
||||
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
# First forward pass: positive conditioning
|
||||
video_modality_pos = Modality(
|
||||
latent=video_flat,
|
||||
timesteps=video_timesteps,
|
||||
positions=video_positions,
|
||||
context=video_embeddings_pos,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_video_rope,
|
||||
)
|
||||
|
||||
audio_modality_pos = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=audio_timesteps,
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings_pos,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope,
|
||||
)
|
||||
|
||||
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
||||
|
||||
if use_cfg:
|
||||
# Second forward pass: negative conditioning
|
||||
video_modality_neg = Modality(
|
||||
latent=video_flat,
|
||||
timesteps=video_timesteps,
|
||||
positions=video_positions,
|
||||
context=video_embeddings_neg,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_video_rope,
|
||||
)
|
||||
|
||||
audio_modality_neg = Modality(
|
||||
latent=audio_flat,
|
||||
timesteps=audio_timesteps,
|
||||
positions=audio_positions,
|
||||
context=audio_embeddings_neg,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope,
|
||||
)
|
||||
|
||||
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
||||
|
||||
# Apply CFG: denoised = pos + (scale - 1) * (pos - neg)
|
||||
video_velocity_flat = video_vel_pos + (cfg_scale - 1.0) * (video_vel_pos - video_vel_neg)
|
||||
audio_velocity_flat = audio_vel_pos + (cfg_scale - 1.0) * (audio_vel_pos - audio_vel_neg)
|
||||
else:
|
||||
video_velocity_flat = video_vel_pos
|
||||
audio_velocity_flat = audio_vel_pos
|
||||
|
||||
# Reshape velocities back
|
||||
video_velocity = mx.reshape(mx.transpose(video_velocity_flat, (0, 2, 1)), (b, c, f, h, w))
|
||||
audio_velocity = mx.reshape(audio_velocity_flat, (ab, at, ac, af))
|
||||
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
|
||||
|
||||
# Compute denoised
|
||||
video_denoised = to_denoised(video_latents, video_velocity, sigma)
|
||||
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
||||
|
||||
# Apply conditioning mask for video if state is provided
|
||||
if video_state is not None:
|
||||
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
|
||||
|
||||
# Euler step
|
||||
if sigma_next > 0:
|
||||
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
|
||||
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
|
||||
else:
|
||||
video_latents = video_denoised
|
||||
audio_latents = audio_denoised
|
||||
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
return video_latents, audio_latents
|
||||
|
||||
|
||||
def generate_video_dev(
|
||||
model_repo: str,
|
||||
text_encoder_repo: str,
|
||||
@@ -361,6 +691,7 @@ def generate_video_dev(
|
||||
seed: int = 42,
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
output_audio_path: Optional[str] = None,
|
||||
save_frames: bool = False,
|
||||
verbose: bool = True,
|
||||
enhance_prompt: bool = False,
|
||||
@@ -370,6 +701,7 @@ def generate_video_dev(
|
||||
image_strength: float = 1.0,
|
||||
image_frame_idx: int = 0,
|
||||
tiling: str = "none",
|
||||
audio: bool = False,
|
||||
):
|
||||
"""Generate video using LTX-2 dev model with CFG.
|
||||
|
||||
@@ -389,6 +721,7 @@ def generate_video_dev(
|
||||
seed: Random seed for reproducibility
|
||||
fps: Frames per second for output video
|
||||
output_path: Path to save the output video
|
||||
output_audio_path: Path to save audio (if audio=True)
|
||||
save_frames: Whether to save individual frames as images
|
||||
verbose: Whether to print progress
|
||||
enhance_prompt: Whether to enhance prompt using Gemma
|
||||
@@ -398,6 +731,7 @@ def generate_video_dev(
|
||||
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
|
||||
image_frame_idx: Frame index to condition (0 = first frame)
|
||||
tiling: Tiling mode for VAE decoding
|
||||
audio: Whether to generate synchronized audio
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -410,10 +744,17 @@ def generate_video_dev(
|
||||
print(f"{Colors.YELLOW}Warning: Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
||||
num_frames = adjusted_num_frames
|
||||
|
||||
# Calculate audio frames if audio is enabled
|
||||
audio_frames = compute_audio_frames(num_frames, fps) if audio else 0
|
||||
|
||||
is_i2v = image is not None
|
||||
mode_str = "I2V" if is_i2v else "T2V"
|
||||
if audio:
|
||||
mode_str += "+Audio"
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}[DEV] [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Steps: {num_inference_steps}, CFG: {cfg_scale}{Colors.RESET}")
|
||||
if audio:
|
||||
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||
if is_i2v:
|
||||
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
|
||||
@@ -442,37 +783,70 @@ def generate_video_dev(
|
||||
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||
|
||||
# Encode both positive and negative prompts
|
||||
text_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||
text_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
|
||||
model_dtype = text_embeddings_pos.dtype
|
||||
mx.eval(text_embeddings_pos, text_embeddings_neg)
|
||||
if audio:
|
||||
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
|
||||
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
|
||||
model_dtype = video_embeddings_pos.dtype
|
||||
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
|
||||
else:
|
||||
video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||
video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
|
||||
audio_embeddings_pos = None
|
||||
audio_embeddings_neg = None
|
||||
model_dtype = video_embeddings_pos.dtype
|
||||
mx.eval(video_embeddings_pos, video_embeddings_neg)
|
||||
|
||||
del text_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Load transformer (dev model)
|
||||
print(f"{Colors.BLUE}Loading dev transformer...{Colors.RESET}")
|
||||
print(f"{Colors.BLUE}Loading dev transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-dev.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
# Convert transformer weights to bfloat16 for memory efficiency
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.VideoOnly,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
if audio:
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.AudioVideo,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
# Audio config
|
||||
audio_num_attention_heads=32,
|
||||
audio_attention_head_dim=64,
|
||||
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
|
||||
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
|
||||
audio_cross_attention_dim=2048,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
audio_positional_embedding_max_pos=[20],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
else:
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.VideoOnly,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
@@ -503,20 +877,26 @@ def generate_video_dev(
|
||||
mx.eval(sigmas)
|
||||
print(f"{Colors.DIM}Sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f}{Colors.RESET}")
|
||||
|
||||
# Create position grid
|
||||
# Create position grids
|
||||
print(f"{Colors.YELLOW}Generating at {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})...{Colors.RESET}")
|
||||
mx.random.seed(seed)
|
||||
|
||||
positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
||||
mx.eval(positions)
|
||||
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
||||
mx.eval(video_positions)
|
||||
|
||||
if audio:
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
mx.eval(audio_positions)
|
||||
else:
|
||||
audio_positions = None
|
||||
|
||||
# Initialize latents with optional I2V conditioning
|
||||
state = None
|
||||
video_state = None
|
||||
video_latent_shape = (1, 128, latent_frames, latent_h, latent_w)
|
||||
if is_i2v and image_latent is not None:
|
||||
latent_shape = (1, 128, latent_frames, latent_h, latent_w)
|
||||
state = LatentState(
|
||||
latent=mx.zeros(latent_shape, dtype=model_dtype),
|
||||
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
|
||||
video_state = LatentState(
|
||||
latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||
clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(
|
||||
@@ -524,30 +904,46 @@ def generate_video_dev(
|
||||
frame_idx=image_frame_idx,
|
||||
strength=image_strength,
|
||||
)
|
||||
state = apply_conditioning(state, [conditioning])
|
||||
video_state = apply_conditioning(video_state, [conditioning])
|
||||
|
||||
# Apply noiser
|
||||
noise = mx.random.normal(latent_shape, dtype=model_dtype)
|
||||
noise = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
||||
noise_scale = sigmas[0]
|
||||
scaled_mask = state.denoise_mask * noise_scale
|
||||
scaled_mask = video_state.denoise_mask * noise_scale
|
||||
|
||||
state = LatentState(
|
||||
latent=noise * scaled_mask + state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=state.clean_latent,
|
||||
denoise_mask=state.denoise_mask,
|
||||
video_state = LatentState(
|
||||
latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=video_state.clean_latent,
|
||||
denoise_mask=video_state.denoise_mask,
|
||||
)
|
||||
latents = state.latent
|
||||
mx.eval(latents)
|
||||
video_latents = video_state.latent
|
||||
mx.eval(video_latents)
|
||||
else:
|
||||
# T2V: just use random noise
|
||||
latents = mx.random.normal((1, 128, latent_frames, latent_h, latent_w), dtype=model_dtype)
|
||||
mx.eval(latents)
|
||||
video_latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
||||
mx.eval(video_latents)
|
||||
|
||||
# Initialize audio latents if audio is enabled
|
||||
if audio:
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_latents)
|
||||
else:
|
||||
audio_latents = None
|
||||
|
||||
# Denoise with CFG
|
||||
latents = denoise_with_cfg(
|
||||
latents, positions, text_embeddings_pos, text_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=state
|
||||
)
|
||||
if audio:
|
||||
video_latents, audio_latents = denoise_av_with_cfg(
|
||||
video_latents, audio_latents,
|
||||
video_positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state
|
||||
)
|
||||
else:
|
||||
video_latents = denoise_with_cfg(
|
||||
video_latents, video_positions, video_embeddings_pos, video_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state
|
||||
)
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
@@ -583,33 +979,100 @@ def generate_video_dev(
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose)
|
||||
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose)
|
||||
else:
|
||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||
video = vae_decoder(latents)
|
||||
video = vae_decoder(video_latents)
|
||||
mx.eval(video)
|
||||
|
||||
del vae_decoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Convert to uint8 frames
|
||||
# Decode audio if enabled
|
||||
audio_np = None
|
||||
if audio and audio_latents is not None:
|
||||
print(f"{Colors.BLUE}Decoding audio...{Colors.RESET}")
|
||||
|
||||
# Load audio decoder
|
||||
audio_decoder = load_audio_decoder(model_path)
|
||||
mx.eval(audio_decoder.parameters())
|
||||
|
||||
# Decode audio latents to mel spectrogram
|
||||
mel_spectrogram = audio_decoder(audio_latents)
|
||||
mx.eval(mel_spectrogram)
|
||||
|
||||
del audio_decoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Load vocoder and convert mel to waveform
|
||||
vocoder = load_vocoder(model_path)
|
||||
mx.eval(vocoder.parameters())
|
||||
|
||||
audio_waveform = vocoder(mel_spectrogram)
|
||||
mx.eval(audio_waveform)
|
||||
|
||||
del vocoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Convert to numpy
|
||||
audio_np = np.array(audio_waveform)
|
||||
if audio_np.ndim == 3:
|
||||
audio_np = audio_np[0] # Remove batch dim
|
||||
|
||||
print(f"{Colors.DIM} Audio shape: {audio_np.shape}, duration: {audio_np.shape[-1] / AUDIO_SAMPLE_RATE:.2f}s{Colors.RESET}")
|
||||
|
||||
# Convert video to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
|
||||
# Save video
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Determine audio output path
|
||||
if audio and audio_np is not None:
|
||||
if output_audio_path is None:
|
||||
audio_output = output_path.parent / f"{output_path.stem}.wav"
|
||||
else:
|
||||
audio_output = Path(output_audio_path)
|
||||
|
||||
# Save audio
|
||||
save_audio(audio_np, audio_output)
|
||||
print(f"{Colors.GREEN}Saved audio to{Colors.RESET} {audio_output}")
|
||||
|
||||
# Save video (to temp file if we need to mux with audio)
|
||||
if audio and audio_np is not None:
|
||||
# Save video to temp file, then mux with audio
|
||||
temp_video_path = output_path.parent / f"{output_path.stem}_temp.mp4"
|
||||
video_save_path = temp_video_path
|
||||
else:
|
||||
video_save_path = output_path
|
||||
|
||||
try:
|
||||
import cv2
|
||||
h, w = video_np.shape[1], video_np.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
|
||||
out = cv2.VideoWriter(str(video_save_path), fourcc, fps, (w, h))
|
||||
for frame in video_np:
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}")
|
||||
|
||||
if audio and audio_np is not None:
|
||||
# Mux video and audio
|
||||
print(f"{Colors.BLUE}Muxing video and audio...{Colors.RESET}")
|
||||
if mux_video_audio(temp_video_path, audio_output, output_path):
|
||||
print(f"{Colors.GREEN}Saved video with audio to{Colors.RESET} {output_path}")
|
||||
# Clean up temp file
|
||||
temp_video_path.unlink(missing_ok=True)
|
||||
else:
|
||||
# Fallback: keep separate files
|
||||
print(f"{Colors.YELLOW}Could not mux, keeping separate files{Colors.RESET}")
|
||||
temp_video_path.rename(output_path.parent / f"{output_path.stem}_video.mp4")
|
||||
else:
|
||||
print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}")
|
||||
|
||||
@@ -642,6 +1105,10 @@ Examples:
|
||||
|
||||
# Image-to-Video (I2V)
|
||||
python -m mlx_video.generate_dev --prompt "A person dancing" --image photo.jpg
|
||||
|
||||
# With synchronized audio
|
||||
python -m mlx_video.generate_dev --prompt "Ocean waves crashing on rocks" --audio
|
||||
python -m mlx_video.generate_dev --prompt "A busy city street" --audio --output-audio street.wav
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -769,6 +1236,17 @@ Examples:
|
||||
choices=["none", "auto", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
help="Tiling mode for VAE decoding (default: none, faster on high-memory systems)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio",
|
||||
action="store_true",
|
||||
help="Generate synchronized audio with the video"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-audio",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output audio path (default: same as video with .wav extension)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video_dev(
|
||||
@@ -784,6 +1262,7 @@ Examples:
|
||||
seed=args.seed,
|
||||
fps=args.fps,
|
||||
output_path=args.output_path,
|
||||
output_audio_path=args.output_audio,
|
||||
save_frames=args.save_frames,
|
||||
verbose=args.verbose,
|
||||
enhance_prompt=args.enhance_prompt,
|
||||
@@ -793,6 +1272,7 @@ Examples:
|
||||
image_strength=args.image_strength,
|
||||
image_frame_idx=args.image_frame_idx,
|
||||
tiling=args.tiling,
|
||||
audio=args.audio,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user