Add audio generation capabilities to the video pipeline, including audio position grid creation, audio frame computation, and integration of the audio VAE and vocoder. Update tests to cover the new audio functionality.
This commit is contained in:
@@ -37,7 +37,7 @@ class Colors:
|
|||||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||||
from mlx_video.models.ltx.ltx import LTXModel
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
from mlx_video.models.ltx.transformer import Modality
|
from mlx_video.models.ltx.transformer import Modality
|
||||||
from mlx_video.convert import sanitize_transformer_weights
|
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
|
||||||
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
|
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
|
||||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||||
@@ -65,6 +65,15 @@ DEFAULT_NEGATIVE_PROMPT = (
|
|||||||
BASE_SHIFT_ANCHOR = 1024
|
BASE_SHIFT_ANCHOR = 1024
|
||||||
MAX_SHIFT_ANCHOR = 4096
|
MAX_SHIFT_ANCHOR = 4096
|
||||||
|
|
||||||
|
# Audio constants
AUDIO_SAMPLE_RATE = 24000  # Output audio sample rate
AUDIO_LATENT_SAMPLE_RATE = 16000  # VAE internal sample rate
# Hop length in samples; create_audio_position_grid uses it (as its default
# hop_length) to convert mel-frame indices into seconds.
AUDIO_HOP_LENGTH = 160
# Mel frames per audio latent frame (default downsample_factor of
# create_audio_position_grid).
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
AUDIO_LATENT_CHANNELS = 8  # Latent channels before patchifying
# Size of the last axis of the audio latents; presumably the latent-space
# mel-bin count (the decoder is configured with 64 output mel bins) - TODO confirm.
AUDIO_MEL_BINS = 16
# True division, so this is the float 25.0 (16000 / 160 / 4) latent frames
# per second of audio.
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR  # 25
|
||||||
|
|
||||||
|
|
||||||
def ltx2_scheduler(
|
def ltx2_scheduler(
|
||||||
steps: int,
|
steps: int,
|
||||||
@@ -195,6 +204,54 @@ def create_position_grid(
|
|||||||
return mx.array(pixel_coords, dtype=mx.float32)
|
return mx.array(pixel_coords, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def create_audio_position_grid(
    batch_size: int,
    audio_frames: int,
    sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
    hop_length: int = AUDIO_HOP_LENGTH,
    downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
    is_causal: bool = True,
) -> mx.array:
    """Build the per-frame (start, end) timestamp grid used for audio RoPE.

    Every audio latent frame ``i`` is described by the time, in seconds, at
    which it starts and the time at which frame ``i + 1`` starts. The
    original implementation notes that this mirrors PyTorch's
    ``AudioPatchifier.get_patch_grid_bounds``.

    Args:
        batch_size: Number of copies of the grid along the leading axis.
        audio_frames: Number of audio latent frames (T).
        sample_rate: Sample rate used to convert samples to seconds.
        hop_length: Samples per mel frame.
        downsample_factor: Mel frames per latent frame.
        is_causal: When True, shift each boundary back by
            ``downsample_factor - 1`` mel frames and clamp at zero.

    Returns:
        float32 array of shape (batch_size, 1, T, 2) holding
        ``[start_seconds, end_seconds]`` for each latent frame.
    """
    # Frame i spans boundary i .. boundary i + 1, so T frames need T + 1
    # boundaries; computing them once replaces two separate index ranges.
    boundaries = np.arange(audio_frames + 1, dtype=np.float32)

    mel_index = boundaries * downsample_factor
    if is_causal:
        # Causal alignment: shift back, never going below the first mel frame.
        mel_index = np.maximum(mel_index + 1 - downsample_factor, 0)

    seconds = mel_index * hop_length / sample_rate

    # (T, 2): column 0 is the frame start, column 1 the frame end.
    grid = np.stack((seconds[:-1], seconds[1:]), axis=-1)
    # (T, 2) -> (1, 1, T, 2) -> (B, 1, T, 2)
    grid = np.repeat(grid[np.newaxis, np.newaxis], batch_size, axis=0)

    return mx.array(grid, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
    """Return how many audio latent frames cover a clip of the given length.

    The clip duration in seconds (``num_video_frames / fps``) is scaled by
    ``AUDIO_LATENTS_PER_SECOND`` and rounded with the built-in ``round``.
    """
    return round(num_video_frames / fps * AUDIO_LATENTS_PER_SECOND)
|
||||||
|
|
||||||
|
|
||||||
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
||||||
"""Compute CFG (Classifier-Free Guidance) delta.
|
"""Compute CFG (Classifier-Free Guidance) delta.
|
||||||
|
|
||||||
@@ -209,6 +266,116 @@ def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
|||||||
return (scale - 1.0) * (cond - uncond)
|
return (scale - 1.0) * (cond - uncond)
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio_decoder(model_path: Path):
    """Build the audio VAE decoder and load its weights from a checkpoint.

    Looks for ``ltx-2-19b-dev.safetensors`` in ``model_path`` and falls back
    to ``ltx-2-19b-distilled.safetensors``. The checkpoint is filtered through
    ``sanitize_audio_vae_weights`` before loading.

    Args:
        model_path: Directory containing the LTX-2 checkpoint file.

    Returns:
        The ``AudioDecoder`` instance. If no checkpoint (or no audio VAE
        tensors) is found, the decoder is returned with its initial weights
        and a ``RuntimeWarning`` is emitted instead of failing silently.
    """
    import warnings

    from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType

    decoder = AudioDecoder(
        ch=128,
        out_ch=2,  # stereo
        ch_mult=(1, 2, 4),
        num_res_blocks=2,
        attn_resolutions={8, 16, 32},
        resolution=256,
        z_channels=AUDIO_LATENT_CHANNELS,
        norm_type=NormType.PIXEL,
        causality_axis=CausalityAxis.HEIGHT,
        mel_bins=64,  # Output mel bins
    )

    # Load weights - try dev model first, fall back to distilled
    weight_file = model_path / "ltx-2-19b-dev.safetensors"
    if not weight_file.exists():
        weight_file = model_path / "ltx-2-19b-distilled.safetensors"

    if not weight_file.exists():
        # Keep the best-effort behaviour (still return a decoder) but make
        # the missing checkpoint visible: untrained weights produce noise.
        warnings.warn(
            f"No LTX-2 checkpoint found in {model_path}; "
            "audio decoder weights are left uninitialized.",
            RuntimeWarning,
            stacklevel=2,
        )
        return decoder

    sanitized = sanitize_audio_vae_weights(mx.load(str(weight_file)))
    if not sanitized:
        warnings.warn(
            f"No audio VAE weights found in {weight_file}; "
            "audio decoder weights are left uninitialized.",
            RuntimeWarning,
            stacklevel=2,
        )
        return decoder

    decoder.load_weights(list(sanitized.items()), strict=False)

    # Manually load per-channel statistics. Presumably load_weights() does
    # not populate these underscore-prefixed attributes - TODO confirm.
    for stat_name in ("_mean_of_means", "_std_of_means"):
        key = f"per_channel_statistics.{stat_name}"
        if key in sanitized:
            setattr(decoder.per_channel_statistics, stat_name, sanitized[key])

    return decoder
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocoder(model_path: Path):
    """Build the vocoder (mel spectrogram -> waveform) and load its weights.

    Looks for ``ltx-2-19b-dev.safetensors`` in ``model_path`` and falls back
    to ``ltx-2-19b-distilled.safetensors``. The checkpoint is filtered through
    ``sanitize_vocoder_weights`` before loading.

    Args:
        model_path: Directory containing the LTX-2 checkpoint file.

    Returns:
        The ``Vocoder`` instance. If no checkpoint (or no vocoder tensors) is
        found, the vocoder is returned with its initial weights and a
        ``RuntimeWarning`` is emitted instead of failing silently.
    """
    import warnings

    from mlx_video.models.ltx.audio_vae import Vocoder

    vocoder = Vocoder(
        resblock_kernel_sizes=[3, 7, 11],
        upsample_rates=[6, 5, 2, 2, 2],
        upsample_kernel_sizes=[16, 15, 8, 4, 4],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_initial_channel=1024,
        stereo=True,
        output_sample_rate=AUDIO_SAMPLE_RATE,
    )

    # Load weights - try dev model first, fall back to distilled
    weight_file = model_path / "ltx-2-19b-dev.safetensors"
    if not weight_file.exists():
        weight_file = model_path / "ltx-2-19b-distilled.safetensors"

    if not weight_file.exists():
        # Keep the best-effort behaviour (still return a vocoder) but make
        # the missing checkpoint visible: untrained weights produce noise.
        warnings.warn(
            f"No LTX-2 checkpoint found in {model_path}; "
            "vocoder weights are left uninitialized.",
            RuntimeWarning,
            stacklevel=2,
        )
        return vocoder

    sanitized = sanitize_vocoder_weights(mx.load(str(weight_file)))
    if sanitized:
        vocoder.load_weights(list(sanitized.items()), strict=False)
    else:
        warnings.warn(
            f"No vocoder weights found in {weight_file}; "
            "vocoder weights are left uninitialized.",
            RuntimeWarning,
            stacklevel=2,
        )

    return vocoder
|
||||||
|
|
||||||
|
|
||||||
|
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
    """Write a floating-point waveform to a 16-bit PCM WAV file.

    Args:
        audio: Waveform with values nominally in [-1.0, 1.0]. A 2-D array is
            interpreted as (channels, samples); any other rank is written as
            a single mono stream in C order.
        path: Destination file path.
        sample_rate: Frame rate stored in the WAV header.

    Samples outside [-1.0, 1.0] are clipped (not rescaled) before conversion.
    """
    import wave

    if audio.ndim == 2:
        # (channels, samples) -> (samples, channels): WAV frames interleave
        # one sample per channel.
        audio = audio.T
        # Take the channel count from the data; hard-coding 2 would write a
        # (1, samples) mono array with a stereo header.
        num_channels = audio.shape[1]
    else:
        num_channels = 1

    # Clip to the valid range, then scale to int16. WAV PCM is little-endian,
    # so request that byte order explicitly rather than the platform's native one.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")

    with wave.open(str(path), "wb") as wf:
        wf.setnchannels(num_channels)
        wf.setsampwidth(2)  # 16-bit
        wf.setframerate(sample_rate)
        # tobytes() serialises in C order, i.e. frame by frame, even though
        # the transposed array is a non-contiguous view.
        wf.writeframes(pcm.tobytes())
|
||||||
|
|
||||||
|
|
||||||
|
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path) -> bool:
    """Combine a video file and an audio file into one output using ffmpeg.

    The video stream is copied as-is, the audio is encoded to AAC, and the
    output is cut to the shorter of the two inputs (``-shortest``). An
    existing output file is overwritten (``-y``).

    Args:
        video_path: Input video file.
        audio_path: Input audio file.
        output_path: Destination file.

    Returns:
        True if ffmpeg exited successfully; False if ffmpeg failed or is not
        installed (an error message is printed in both cases).
    """
    import subprocess

    cmd = [
        "ffmpeg", "-y",
        "-i", str(video_path),
        "-i", str(audio_path),
        "-c:v", "copy",
        "-c:a", "aac",
        "-shortest",
        str(output_path),
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # ffmpeg's stderr is raw bytes and not guaranteed to be valid UTF-8;
        # a strict decode could raise here and hide the real failure.
        stderr_text = e.stderr.decode(errors="replace") if e.stderr else ""
        print(f"{Colors.RED}FFmpeg error: {stderr_text}{Colors.RESET}")
        return False
    except FileNotFoundError:
        print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
def denoise_with_cfg(
|
def denoise_with_cfg(
|
||||||
latents: mx.array,
|
latents: mx.array,
|
||||||
positions: mx.array,
|
positions: mx.array,
|
||||||
@@ -222,10 +389,9 @@ def denoise_with_cfg(
|
|||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Run denoising loop with CFG (Classifier-Free Guidance).
|
"""Run denoising loop with CFG (Classifier-Free Guidance).
|
||||||
|
|
||||||
Optimized version that:
|
Uses separate forward passes for positive and negative conditioning
|
||||||
1. Batches positive and negative forward passes together
|
to match PyTorch implementation behavior (avoids potential issues with
|
||||||
2. Precomputes RoPE once and reuses it (avoids expensive NumPy conversion each step)
|
batched attention patterns).
|
||||||
3. Minimizes mx.eval() calls for better performance
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
latents: Noisy latent tensor (B, C, F, H, W)
|
latents: Noisy latent tensor (B, C, F, H, W)
|
||||||
@@ -250,18 +416,10 @@ def denoise_with_cfg(
|
|||||||
sigmas_list = sigmas.tolist()
|
sigmas_list = sigmas.tolist()
|
||||||
use_cfg = cfg_scale != 1.0
|
use_cfg = cfg_scale != 1.0
|
||||||
|
|
||||||
# Pre-compute batched context for CFG (concat pos and neg along batch dim)
|
|
||||||
if use_cfg:
|
|
||||||
# Shape: (2, seq_len, dim) - batch pos and neg together
|
|
||||||
batched_context = mx.concatenate([text_embeddings_pos, text_embeddings_neg], axis=0)
|
|
||||||
batched_positions = mx.concatenate([positions, positions], axis=0)
|
|
||||||
else:
|
|
||||||
batched_positions = positions
|
|
||||||
|
|
||||||
# Precompute RoPE once (expensive operation due to NumPy conversion for double precision)
|
# Precompute RoPE once (expensive operation due to NumPy conversion for double precision)
|
||||||
# This avoids recomputing it every forward pass
|
# This avoids recomputing it every forward pass
|
||||||
precomputed_rope = precompute_freqs_cis(
|
precomputed_rope = precompute_freqs_cis(
|
||||||
batched_positions,
|
positions,
|
||||||
dim=transformer.inner_dim,
|
dim=transformer.inner_dim,
|
||||||
theta=transformer.positional_embedding_theta,
|
theta=transformer.positional_embedding_theta,
|
||||||
max_pos=transformer.positional_embedding_max_pos,
|
max_pos=transformer.positional_embedding_max_pos,
|
||||||
@@ -289,45 +447,38 @@ def denoise_with_cfg(
|
|||||||
else:
|
else:
|
||||||
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
||||||
|
|
||||||
if use_cfg:
|
# First forward pass: positive conditioning
|
||||||
# Batch both positive and negative in a single forward pass
|
video_modality_pos = Modality(
|
||||||
batched_latents = mx.concatenate([latents_flat, latents_flat], axis=0)
|
|
||||||
batched_timesteps = mx.concatenate([timesteps, timesteps], axis=0)
|
|
||||||
|
|
||||||
video_modality = Modality(
|
|
||||||
latent=batched_latents,
|
|
||||||
timesteps=batched_timesteps,
|
|
||||||
positions=batched_positions,
|
|
||||||
context=batched_context,
|
|
||||||
context_mask=None,
|
|
||||||
enabled=True,
|
|
||||||
positional_embeddings=precomputed_rope, # Use precomputed RoPE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Single forward pass for both pos and neg
|
|
||||||
batched_output, _ = transformer(video=video_modality, audio=None)
|
|
||||||
|
|
||||||
# Split results: first half is positive, second half is negative
|
|
||||||
denoised_pos = batched_output[:1]
|
|
||||||
denoised_neg = batched_output[1:]
|
|
||||||
|
|
||||||
# Apply CFG: denoised = denoised_pos + (scale - 1) * (denoised_pos - denoised_neg)
|
|
||||||
denoised_flat = denoised_pos + (cfg_scale - 1.0) * (denoised_pos - denoised_neg)
|
|
||||||
else:
|
|
||||||
# No CFG - single forward pass
|
|
||||||
video_modality = Modality(
|
|
||||||
latent=latents_flat,
|
latent=latents_flat,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
context=text_embeddings_pos,
|
context=text_embeddings_pos,
|
||||||
context_mask=None,
|
context_mask=None,
|
||||||
enabled=True,
|
enabled=True,
|
||||||
positional_embeddings=precomputed_rope, # Use precomputed RoPE
|
positional_embeddings=precomputed_rope,
|
||||||
)
|
)
|
||||||
denoised_flat, _ = transformer(video=video_modality, audio=None)
|
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
|
||||||
|
|
||||||
|
if use_cfg:
|
||||||
|
# Second forward pass: negative conditioning
|
||||||
|
video_modality_neg = Modality(
|
||||||
|
latent=latents_flat,
|
||||||
|
timesteps=timesteps,
|
||||||
|
positions=positions,
|
||||||
|
context=text_embeddings_neg,
|
||||||
|
context_mask=None,
|
||||||
|
enabled=True,
|
||||||
|
positional_embeddings=precomputed_rope,
|
||||||
|
)
|
||||||
|
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
|
||||||
|
|
||||||
|
# Apply CFG: velocity = pos + (scale - 1) * (pos - neg)
|
||||||
|
velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg)
|
||||||
|
else:
|
||||||
|
velocity_flat = velocity_pos
|
||||||
|
|
||||||
# Reshape back to 5D
|
# Reshape back to 5D
|
||||||
velocity = mx.reshape(mx.transpose(denoised_flat, (0, 2, 1)), (b, c, f, h, w))
|
velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w))
|
||||||
denoised = to_denoised(latents, velocity, sigma)
|
denoised = to_denoised(latents, velocity, sigma)
|
||||||
|
|
||||||
# Apply conditioning mask if state is provided
|
# Apply conditioning mask if state is provided
|
||||||
@@ -348,6 +499,185 @@ def denoise_with_cfg(
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
def denoise_av_with_cfg(
    video_latents: mx.array,
    audio_latents: mx.array,
    video_positions: mx.array,
    audio_positions: mx.array,
    video_embeddings_pos: mx.array,
    video_embeddings_neg: mx.array,
    audio_embeddings_pos: mx.array,
    audio_embeddings_neg: mx.array,
    transformer: LTXModel,
    sigmas: mx.array,
    cfg_scale: float = 4.0,
    verbose: bool = True,
    video_state: Optional[LatentState] = None,
) -> tuple[mx.array, mx.array]:
    """Jointly denoise video and audio latents with classifier-free guidance.

    Every step runs the transformer on both modalities with the positive
    embeddings and, when ``cfg_scale != 1.0``, a second time with the
    negative embeddings. The two velocity predictions are combined as
    ``pos + (cfg_scale - 1) * (pos - neg)`` and followed by an Euler update.
    The positive and negative passes are kept separate rather than batched;
    the original implementation notes this is to keep audio-video
    cross-attention behaviour in line with the PyTorch implementation.

    Args:
        video_latents: Video latent tensor (B, C, F, H, W).
        audio_latents: Audio latent tensor (B, C, T, F).
        video_positions: Position grid for the video tokens.
        audio_positions: Position grid for the audio tokens.
        video_embeddings_pos: Positive-prompt text embeddings for video.
        video_embeddings_neg: Negative-prompt text embeddings for video.
        audio_embeddings_pos: Positive-prompt text embeddings for audio.
        audio_embeddings_neg: Negative-prompt text embeddings for audio.
        transformer: LTX model called with ``video=`` and ``audio=`` modalities.
        sigmas: Noise schedule; consecutive entries define one step each.
        cfg_scale: Guidance scale (1.0 disables the negative pass).
        verbose: Show a progress bar when True.
        video_state: Optional I2V conditioning state. When given, its
            ``latent`` replaces ``video_latents`` and its ``denoise_mask`` /
            ``clean_latent`` drive per-token timesteps and masking.

    Returns:
        Tuple ``(video_latents, audio_latents)`` after the final step.
    """
    from mlx_video.models.ltx.rope import precompute_freqs_cis

    dtype = video_latents.dtype
    if video_state is not None:
        # I2V: start from the conditioned latent carried by the state.
        video_latents = video_state.latent

    schedule = sigmas.tolist()
    guided = cfg_scale != 1.0

    # RoPE tables depend only on the positions, so build them once up front.
    rope_common = {
        "theta": transformer.positional_embedding_theta,
        "use_middle_indices_grid": transformer.use_middle_indices_grid,
        "rope_type": transformer.rope_type,
        "double_precision": transformer.config.double_precision_rope,
    }
    video_rope = precompute_freqs_cis(
        video_positions,
        dim=transformer.inner_dim,
        max_pos=transformer.positional_embedding_max_pos,
        num_attention_heads=transformer.num_attention_heads,
        **rope_common,
    )
    audio_rope = precompute_freqs_cis(
        audio_positions,
        dim=transformer.audio_inner_dim,
        max_pos=transformer.audio_positional_embedding_max_pos,
        num_attention_heads=transformer.audio_num_attention_heads,
        **rope_common,
    )
    mx.eval(video_rope, audio_rope)

    def predict_velocities(video_tokens, video_t, audio_tokens, audio_t, video_ctx, audio_ctx):
        """Run one transformer pass for the given text conditioning."""
        video_in = Modality(
            latent=video_tokens,
            timesteps=video_t,
            positions=video_positions,
            context=video_ctx,
            context_mask=None,
            enabled=True,
            positional_embeddings=video_rope,
        )
        audio_in = Modality(
            latent=audio_tokens,
            timesteps=audio_t,
            positions=audio_positions,
            context=audio_ctx,
            context_mask=None,
            enabled=True,
            positional_embeddings=audio_rope,
        )
        return transformer(video=video_in, audio=audio_in)

    # Materialise the (sigma, sigma_next) pairs so tqdm still knows the total.
    step_pairs = list(zip(schedule[:-1], schedule[1:]))
    for sigma, sigma_next in tqdm(step_pairs, desc="Denoising A/V", disable=not verbose):
        # Video: (B, C, F, H, W) -> (B, F*H*W, C)
        b, c, f, h, w = video_latents.shape
        num_video_tokens = f * h * w
        video_tokens = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))

        # Audio: (B, C, T, F) -> (B, T, C*F)
        ab, ac, at, af = audio_latents.shape
        audio_tokens = mx.reshape(
            mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af)
        )

        # Per-token video timesteps: sigma scaled by the denoise mask when
        # conditioning is present, plain sigma otherwise.
        if video_state is None:
            video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
        else:
            token_mask = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
            token_mask = mx.broadcast_to(token_mask, (b, 1, f, h, w))
            token_mask = mx.reshape(token_mask, (b, num_video_tokens))
            video_timesteps = mx.array(sigma, dtype=dtype) * token_mask

        audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)

        # Positive pass, plus a negative pass only when guidance is active.
        video_vel, audio_vel = predict_velocities(
            video_tokens, video_timesteps, audio_tokens, audio_timesteps,
            video_embeddings_pos, audio_embeddings_pos,
        )
        if guided:
            video_vel_neg, audio_vel_neg = predict_velocities(
                video_tokens, video_timesteps, audio_tokens, audio_timesteps,
                video_embeddings_neg, audio_embeddings_neg,
            )
            # CFG: pos + (scale - 1) * (pos - neg)
            video_vel = video_vel + (cfg_scale - 1.0) * (video_vel - video_vel_neg)
            audio_vel = audio_vel + (cfg_scale - 1.0) * (audio_vel - audio_vel_neg)

        # Undo the token flattening.
        video_velocity = mx.reshape(mx.transpose(video_vel, (0, 2, 1)), (b, c, f, h, w))
        audio_velocity = mx.transpose(
            mx.reshape(audio_vel, (ab, at, ac, af)), (0, 2, 1, 3)
        )  # (B, C, T, F)

        video_clean = to_denoised(video_latents, video_velocity, sigma)
        audio_clean = to_denoised(audio_latents, audio_velocity, sigma)

        if video_state is not None:
            video_clean = apply_denoise_mask(
                video_clean, video_state.clean_latent, video_state.denoise_mask
            )

        # Euler step towards sigma_next; at sigma_next <= 0 take the clean
        # estimate directly.
        if sigma_next > 0:
            sigma_next_arr = mx.array(sigma_next, dtype=dtype)
            sigma_arr = mx.array(sigma, dtype=dtype)
            video_latents = video_clean + sigma_next_arr * (video_latents - video_clean) / sigma_arr
            audio_latents = audio_clean + sigma_next_arr * (audio_latents - audio_clean) / sigma_arr
        else:
            video_latents, audio_latents = video_clean, audio_clean

        mx.eval(video_latents, audio_latents)

    return video_latents, audio_latents
|
||||||
|
|
||||||
|
|
||||||
def generate_video_dev(
|
def generate_video_dev(
|
||||||
model_repo: str,
|
model_repo: str,
|
||||||
text_encoder_repo: str,
|
text_encoder_repo: str,
|
||||||
@@ -361,6 +691,7 @@ def generate_video_dev(
|
|||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
fps: int = 24,
|
fps: int = 24,
|
||||||
output_path: str = "output.mp4",
|
output_path: str = "output.mp4",
|
||||||
|
output_audio_path: Optional[str] = None,
|
||||||
save_frames: bool = False,
|
save_frames: bool = False,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
enhance_prompt: bool = False,
|
enhance_prompt: bool = False,
|
||||||
@@ -370,6 +701,7 @@ def generate_video_dev(
|
|||||||
image_strength: float = 1.0,
|
image_strength: float = 1.0,
|
||||||
image_frame_idx: int = 0,
|
image_frame_idx: int = 0,
|
||||||
tiling: str = "none",
|
tiling: str = "none",
|
||||||
|
audio: bool = False,
|
||||||
):
|
):
|
||||||
"""Generate video using LTX-2 dev model with CFG.
|
"""Generate video using LTX-2 dev model with CFG.
|
||||||
|
|
||||||
@@ -389,6 +721,7 @@ def generate_video_dev(
|
|||||||
seed: Random seed for reproducibility
|
seed: Random seed for reproducibility
|
||||||
fps: Frames per second for output video
|
fps: Frames per second for output video
|
||||||
output_path: Path to save the output video
|
output_path: Path to save the output video
|
||||||
|
output_audio_path: Path to save audio (if audio=True)
|
||||||
save_frames: Whether to save individual frames as images
|
save_frames: Whether to save individual frames as images
|
||||||
verbose: Whether to print progress
|
verbose: Whether to print progress
|
||||||
enhance_prompt: Whether to enhance prompt using Gemma
|
enhance_prompt: Whether to enhance prompt using Gemma
|
||||||
@@ -398,6 +731,7 @@ def generate_video_dev(
|
|||||||
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
|
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
|
||||||
image_frame_idx: Frame index to condition (0 = first frame)
|
image_frame_idx: Frame index to condition (0 = first frame)
|
||||||
tiling: Tiling mode for VAE decoding
|
tiling: Tiling mode for VAE decoding
|
||||||
|
audio: Whether to generate synchronized audio
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -410,10 +744,17 @@ def generate_video_dev(
|
|||||||
print(f"{Colors.YELLOW}Warning: Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
print(f"{Colors.YELLOW}Warning: Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
||||||
num_frames = adjusted_num_frames
|
num_frames = adjusted_num_frames
|
||||||
|
|
||||||
|
# Calculate audio frames if audio is enabled
|
||||||
|
audio_frames = compute_audio_frames(num_frames, fps) if audio else 0
|
||||||
|
|
||||||
is_i2v = image is not None
|
is_i2v = image is not None
|
||||||
mode_str = "I2V" if is_i2v else "T2V"
|
mode_str = "I2V" if is_i2v else "T2V"
|
||||||
|
if audio:
|
||||||
|
mode_str += "+Audio"
|
||||||
print(f"{Colors.BOLD}{Colors.CYAN}[DEV] [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
print(f"{Colors.BOLD}{Colors.CYAN}[DEV] [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||||
print(f"{Colors.DIM}Steps: {num_inference_steps}, CFG: {cfg_scale}{Colors.RESET}")
|
print(f"{Colors.DIM}Steps: {num_inference_steps}, CFG: {cfg_scale}{Colors.RESET}")
|
||||||
|
if audio:
|
||||||
|
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
|
||||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||||
if is_i2v:
|
if is_i2v:
|
||||||
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
|
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
|
||||||
@@ -442,21 +783,54 @@ def generate_video_dev(
|
|||||||
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
# Encode both positive and negative prompts
|
# Encode both positive and negative prompts
|
||||||
text_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
|
if audio:
|
||||||
text_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
|
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
|
||||||
model_dtype = text_embeddings_pos.dtype
|
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
|
||||||
mx.eval(text_embeddings_pos, text_embeddings_neg)
|
model_dtype = video_embeddings_pos.dtype
|
||||||
|
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
|
||||||
|
else:
|
||||||
|
video_embeddings_pos, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||||
|
video_embeddings_neg, _ = text_encoder(negative_prompt, return_audio_embeddings=False)
|
||||||
|
audio_embeddings_pos = None
|
||||||
|
audio_embeddings_neg = None
|
||||||
|
model_dtype = video_embeddings_pos.dtype
|
||||||
|
mx.eval(video_embeddings_pos, video_embeddings_neg)
|
||||||
|
|
||||||
del text_encoder
|
del text_encoder
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# Load transformer (dev model)
|
# Load transformer (dev model)
|
||||||
print(f"{Colors.BLUE}Loading dev transformer...{Colors.RESET}")
|
print(f"{Colors.BLUE}Loading dev transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}")
|
||||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-dev.safetensors'))
|
raw_weights = mx.load(str(model_path / 'ltx-2-19b-dev.safetensors'))
|
||||||
sanitized = sanitize_transformer_weights(raw_weights)
|
sanitized = sanitize_transformer_weights(raw_weights)
|
||||||
# Convert transformer weights to bfloat16 for memory efficiency
|
# Convert transformer weights to bfloat16 for memory efficiency
|
||||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||||
|
|
||||||
|
if audio:
|
||||||
|
config = LTXModelConfig(
|
||||||
|
model_type=LTXModelType.AudioVideo,
|
||||||
|
num_attention_heads=32,
|
||||||
|
attention_head_dim=128,
|
||||||
|
in_channels=128,
|
||||||
|
out_channels=128,
|
||||||
|
num_layers=48,
|
||||||
|
cross_attention_dim=4096,
|
||||||
|
caption_channels=3840,
|
||||||
|
# Audio config
|
||||||
|
audio_num_attention_heads=32,
|
||||||
|
audio_attention_head_dim=64,
|
||||||
|
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
|
||||||
|
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
|
||||||
|
audio_cross_attention_dim=2048,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision_rope=True,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
audio_positional_embedding_max_pos=[20],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
timestep_scale_multiplier=1000,
|
||||||
|
)
|
||||||
|
else:
|
||||||
config = LTXModelConfig(
|
config = LTXModelConfig(
|
||||||
model_type=LTXModelType.VideoOnly,
|
model_type=LTXModelType.VideoOnly,
|
||||||
num_attention_heads=32,
|
num_attention_heads=32,
|
||||||
@@ -503,20 +877,26 @@ def generate_video_dev(
|
|||||||
mx.eval(sigmas)
|
mx.eval(sigmas)
|
||||||
print(f"{Colors.DIM}Sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f}{Colors.RESET}")
|
print(f"{Colors.DIM}Sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f}{Colors.RESET}")
|
||||||
|
|
||||||
# Create position grid
|
# Create position grids
|
||||||
print(f"{Colors.YELLOW}Generating at {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})...{Colors.RESET}")
|
print(f"{Colors.YELLOW}Generating at {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale})...{Colors.RESET}")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
||||||
mx.eval(positions)
|
mx.eval(video_positions)
|
||||||
|
|
||||||
|
if audio:
|
||||||
|
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||||
|
mx.eval(audio_positions)
|
||||||
|
else:
|
||||||
|
audio_positions = None
|
||||||
|
|
||||||
# Initialize latents with optional I2V conditioning
|
# Initialize latents with optional I2V conditioning
|
||||||
state = None
|
video_state = None
|
||||||
|
video_latent_shape = (1, 128, latent_frames, latent_h, latent_w)
|
||||||
if is_i2v and image_latent is not None:
|
if is_i2v and image_latent is not None:
|
||||||
latent_shape = (1, 128, latent_frames, latent_h, latent_w)
|
video_state = LatentState(
|
||||||
state = LatentState(
|
latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||||
latent=mx.zeros(latent_shape, dtype=model_dtype),
|
clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||||
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
|
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||||
)
|
)
|
||||||
conditioning = VideoConditionByLatentIndex(
|
conditioning = VideoConditionByLatentIndex(
|
||||||
@@ -524,29 +904,45 @@ def generate_video_dev(
|
|||||||
frame_idx=image_frame_idx,
|
frame_idx=image_frame_idx,
|
||||||
strength=image_strength,
|
strength=image_strength,
|
||||||
)
|
)
|
||||||
state = apply_conditioning(state, [conditioning])
|
video_state = apply_conditioning(video_state, [conditioning])
|
||||||
|
|
||||||
# Apply noiser
|
# Apply noiser
|
||||||
noise = mx.random.normal(latent_shape, dtype=model_dtype)
|
noise = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
||||||
noise_scale = sigmas[0]
|
noise_scale = sigmas[0]
|
||||||
scaled_mask = state.denoise_mask * noise_scale
|
scaled_mask = video_state.denoise_mask * noise_scale
|
||||||
|
|
||||||
state = LatentState(
|
video_state = LatentState(
|
||||||
latent=noise * scaled_mask + state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||||
clean_latent=state.clean_latent,
|
clean_latent=video_state.clean_latent,
|
||||||
denoise_mask=state.denoise_mask,
|
denoise_mask=video_state.denoise_mask,
|
||||||
)
|
)
|
||||||
latents = state.latent
|
video_latents = video_state.latent
|
||||||
mx.eval(latents)
|
mx.eval(video_latents)
|
||||||
else:
|
else:
|
||||||
# T2V: just use random noise
|
# T2V: just use random noise
|
||||||
latents = mx.random.normal((1, 128, latent_frames, latent_h, latent_w), dtype=model_dtype)
|
video_latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
||||||
mx.eval(latents)
|
mx.eval(video_latents)
|
||||||
|
|
||||||
|
# Initialize audio latents if audio is enabled
|
||||||
|
if audio:
|
||||||
|
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||||
|
mx.eval(audio_latents)
|
||||||
|
else:
|
||||||
|
audio_latents = None
|
||||||
|
|
||||||
# Denoise with CFG
|
# Denoise with CFG
|
||||||
latents = denoise_with_cfg(
|
if audio:
|
||||||
latents, positions, text_embeddings_pos, text_embeddings_neg,
|
video_latents, audio_latents = denoise_av_with_cfg(
|
||||||
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=state
|
video_latents, audio_latents,
|
||||||
|
video_positions, audio_positions,
|
||||||
|
video_embeddings_pos, video_embeddings_neg,
|
||||||
|
audio_embeddings_pos, audio_embeddings_neg,
|
||||||
|
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, video_state=video_state
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
video_latents = denoise_with_cfg(
|
||||||
|
video_latents, video_positions, video_embeddings_pos, video_embeddings_neg,
|
||||||
|
transformer, sigmas, cfg_scale=cfg_scale, verbose=verbose, state=video_state
|
||||||
)
|
)
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
@@ -583,32 +979,99 @@ def generate_video_dev(
|
|||||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||||
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose)
|
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose)
|
||||||
else:
|
else:
|
||||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||||
video = vae_decoder(latents)
|
video = vae_decoder(video_latents)
|
||||||
mx.eval(video)
|
mx.eval(video)
|
||||||
|
|
||||||
|
del vae_decoder
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
|
|
||||||
# Convert to uint8 frames
|
# Decode audio if enabled
|
||||||
|
audio_np = None
|
||||||
|
if audio and audio_latents is not None:
|
||||||
|
print(f"{Colors.BLUE}Decoding audio...{Colors.RESET}")
|
||||||
|
|
||||||
|
# Load audio decoder
|
||||||
|
audio_decoder = load_audio_decoder(model_path)
|
||||||
|
mx.eval(audio_decoder.parameters())
|
||||||
|
|
||||||
|
# Decode audio latents to mel spectrogram
|
||||||
|
mel_spectrogram = audio_decoder(audio_latents)
|
||||||
|
mx.eval(mel_spectrogram)
|
||||||
|
|
||||||
|
del audio_decoder
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Load vocoder and convert mel to waveform
|
||||||
|
vocoder = load_vocoder(model_path)
|
||||||
|
mx.eval(vocoder.parameters())
|
||||||
|
|
||||||
|
audio_waveform = vocoder(mel_spectrogram)
|
||||||
|
mx.eval(audio_waveform)
|
||||||
|
|
||||||
|
del vocoder
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Convert to numpy
|
||||||
|
audio_np = np.array(audio_waveform)
|
||||||
|
if audio_np.ndim == 3:
|
||||||
|
audio_np = audio_np[0] # Remove batch dim
|
||||||
|
|
||||||
|
print(f"{Colors.DIM} Audio shape: {audio_np.shape}, duration: {audio_np.shape[-1] / AUDIO_SAMPLE_RATE:.2f}s{Colors.RESET}")
|
||||||
|
|
||||||
|
# Convert video to uint8 frames
|
||||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||||
video = (video * 255).astype(mx.uint8)
|
video = (video * 255).astype(mx.uint8)
|
||||||
video_np = np.array(video)
|
video_np = np.array(video)
|
||||||
|
|
||||||
# Save video
|
# Save outputs
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Determine audio output path
|
||||||
|
if audio and audio_np is not None:
|
||||||
|
if output_audio_path is None:
|
||||||
|
audio_output = output_path.parent / f"{output_path.stem}.wav"
|
||||||
|
else:
|
||||||
|
audio_output = Path(output_audio_path)
|
||||||
|
|
||||||
|
# Save audio
|
||||||
|
save_audio(audio_np, audio_output)
|
||||||
|
print(f"{Colors.GREEN}Saved audio to{Colors.RESET} {audio_output}")
|
||||||
|
|
||||||
|
# Save video (to temp file if we need to mux with audio)
|
||||||
|
if audio and audio_np is not None:
|
||||||
|
# Save video to temp file, then mux with audio
|
||||||
|
temp_video_path = output_path.parent / f"{output_path.stem}_temp.mp4"
|
||||||
|
video_save_path = temp_video_path
|
||||||
|
else:
|
||||||
|
video_save_path = output_path
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
h, w = video_np.shape[1], video_np.shape[2]
|
h, w = video_np.shape[1], video_np.shape[2]
|
||||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||||
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
|
out = cv2.VideoWriter(str(video_save_path), fourcc, fps, (w, h))
|
||||||
for frame in video_np:
|
for frame in video_np:
|
||||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
out.release()
|
out.release()
|
||||||
|
|
||||||
|
if audio and audio_np is not None:
|
||||||
|
# Mux video and audio
|
||||||
|
print(f"{Colors.BLUE}Muxing video and audio...{Colors.RESET}")
|
||||||
|
if mux_video_audio(temp_video_path, audio_output, output_path):
|
||||||
|
print(f"{Colors.GREEN}Saved video with audio to{Colors.RESET} {output_path}")
|
||||||
|
# Clean up temp file
|
||||||
|
temp_video_path.unlink(missing_ok=True)
|
||||||
|
else:
|
||||||
|
# Fallback: keep separate files
|
||||||
|
print(f"{Colors.YELLOW}Could not mux, keeping separate files{Colors.RESET}")
|
||||||
|
temp_video_path.rename(output_path.parent / f"{output_path.stem}_video.mp4")
|
||||||
|
else:
|
||||||
print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}")
|
print(f"{Colors.GREEN}Saved video to{Colors.RESET} {output_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}")
|
print(f"{Colors.RED}Could not save video: {e}{Colors.RESET}")
|
||||||
@@ -642,6 +1105,10 @@ Examples:
|
|||||||
|
|
||||||
# Image-to-Video (I2V)
|
# Image-to-Video (I2V)
|
||||||
python -m mlx_video.generate_dev --prompt "A person dancing" --image photo.jpg
|
python -m mlx_video.generate_dev --prompt "A person dancing" --image photo.jpg
|
||||||
|
|
||||||
|
# With synchronized audio
|
||||||
|
python -m mlx_video.generate_dev --prompt "Ocean waves crashing on rocks" --audio
|
||||||
|
python -m mlx_video.generate_dev --prompt "A busy city street" --audio --output-audio street.wav
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -769,6 +1236,17 @@ Examples:
|
|||||||
choices=["none", "auto", "default", "aggressive", "conservative", "spatial", "temporal"],
|
choices=["none", "auto", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||||
help="Tiling mode for VAE decoding (default: none, faster on high-memory systems)"
|
help="Tiling mode for VAE decoding (default: none, faster on high-memory systems)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio",
|
||||||
|
action="store_true",
|
||||||
|
help="Generate synchronized audio with the video"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-audio",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Output audio path (default: same as video with .wav extension)"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
generate_video_dev(
|
generate_video_dev(
|
||||||
@@ -784,6 +1262,7 @@ Examples:
|
|||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
fps=args.fps,
|
fps=args.fps,
|
||||||
output_path=args.output_path,
|
output_path=args.output_path,
|
||||||
|
output_audio_path=args.output_audio,
|
||||||
save_frames=args.save_frames,
|
save_frames=args.save_frames,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
enhance_prompt=args.enhance_prompt,
|
enhance_prompt=args.enhance_prompt,
|
||||||
@@ -793,6 +1272,7 @@ Examples:
|
|||||||
image_strength=args.image_strength,
|
image_strength=args.image_strength,
|
||||||
image_frame_idx=args.image_frame_idx,
|
image_frame_idx=args.image_frame_idx,
|
||||||
tiling=args.tiling,
|
tiling=args.tiling,
|
||||||
|
audio=args.audio,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,14 +2,16 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mlx_video.generate_dev import (
|
from mlx_video.generate_dev import (
|
||||||
ltx2_scheduler,
|
ltx2_scheduler,
|
||||||
create_position_grid,
|
create_position_grid,
|
||||||
|
create_audio_position_grid,
|
||||||
|
compute_audio_frames,
|
||||||
cfg_delta,
|
cfg_delta,
|
||||||
denoise_with_cfg,
|
|
||||||
DEFAULT_NEGATIVE_PROMPT,
|
DEFAULT_NEGATIVE_PROMPT,
|
||||||
|
AUDIO_SAMPLE_RATE,
|
||||||
|
AUDIO_LATENTS_PER_SECOND,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -260,28 +262,6 @@ class TestInputValidation:
|
|||||||
class TestDenoiseWithCFGMocked:
|
class TestDenoiseWithCFGMocked:
|
||||||
"""Tests for denoise_with_cfg with mocked transformer."""
|
"""Tests for denoise_with_cfg with mocked transformer."""
|
||||||
|
|
||||||
def test_denoise_returns_correct_shape(self):
|
|
||||||
"""Denoised output should have same shape as input latents."""
|
|
||||||
# Create a simple mock transformer
|
|
||||||
class MockTransformer:
|
|
||||||
inner_dim = 4096
|
|
||||||
positional_embedding_theta = 10000.0
|
|
||||||
positional_embedding_max_pos = [20, 2048, 2048]
|
|
||||||
use_middle_indices_grid = True
|
|
||||||
num_attention_heads = 32
|
|
||||||
rope_type = None
|
|
||||||
|
|
||||||
class config:
|
|
||||||
double_precision_rope = True
|
|
||||||
|
|
||||||
def __call__(self, video, audio):
|
|
||||||
# Return input as output (identity)
|
|
||||||
return video.latent, None
|
|
||||||
|
|
||||||
# Skip this test if we can't import the required modules easily
|
|
||||||
# This is a structural test to ensure the function signature is correct
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_sigmas_list_conversion(self):
|
def test_sigmas_list_conversion(self):
|
||||||
"""Sigmas should be convertible to list."""
|
"""Sigmas should be convertible to list."""
|
||||||
sigmas = ltx2_scheduler(steps=5)
|
sigmas = ltx2_scheduler(steps=5)
|
||||||
@@ -296,16 +276,10 @@ class TestTilingDefault:
|
|||||||
|
|
||||||
def test_tiling_default_is_none(self):
|
def test_tiling_default_is_none(self):
|
||||||
"""Default tiling should be 'none' for performance."""
|
"""Default tiling should be 'none' for performance."""
|
||||||
# Import and check the default
|
|
||||||
import argparse
|
|
||||||
from mlx_video.generate_dev import main
|
|
||||||
|
|
||||||
# The default is set in the argparse definition
|
|
||||||
# We verify this by checking the function signature
|
|
||||||
import inspect
|
import inspect
|
||||||
sig = inspect.signature(
|
from mlx_video.generate_dev import generate_video_dev
|
||||||
__import__('mlx_video.generate_dev', fromlist=['generate_video_dev']).generate_video_dev
|
|
||||||
)
|
sig = inspect.signature(generate_video_dev)
|
||||||
|
|
||||||
tiling_param = sig.parameters.get('tiling')
|
tiling_param = sig.parameters.get('tiling')
|
||||||
assert tiling_param is not None
|
assert tiling_param is not None
|
||||||
@@ -358,5 +332,91 @@ class TestLatentDimensions:
|
|||||||
assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}"
|
assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAudioPositionGrid:
|
||||||
|
"""Tests for audio position grid creation."""
|
||||||
|
|
||||||
|
def test_audio_position_grid_shape(self):
|
||||||
|
"""Audio position grid should have correct shape (B, 1, T, 2)."""
|
||||||
|
batch_size = 1
|
||||||
|
audio_frames = 34 # ~1.36 seconds at 25 latent frames/sec
|
||||||
|
|
||||||
|
positions = create_audio_position_grid(batch_size, audio_frames)
|
||||||
|
expected_shape = (batch_size, 1, audio_frames, 2)
|
||||||
|
|
||||||
|
assert positions.shape == expected_shape, \
|
||||||
|
f"Expected {expected_shape}, got {positions.shape}"
|
||||||
|
|
||||||
|
def test_audio_position_grid_dtype(self):
|
||||||
|
"""Audio position grid should be float32."""
|
||||||
|
positions = create_audio_position_grid(1, 34)
|
||||||
|
assert positions.dtype == mx.float32, \
|
||||||
|
f"Expected float32, got {positions.dtype}"
|
||||||
|
|
||||||
|
def test_audio_position_grid_batch_size(self):
|
||||||
|
"""Audio position grid should respect batch size."""
|
||||||
|
for batch_size in [1, 2, 4]:
|
||||||
|
positions = create_audio_position_grid(batch_size, 34)
|
||||||
|
assert positions.shape[0] == batch_size
|
||||||
|
|
||||||
|
def test_audio_position_grid_temporal_values(self):
|
||||||
|
"""Audio positions should be in seconds."""
|
||||||
|
positions = create_audio_position_grid(1, 34)
|
||||||
|
|
||||||
|
# Values should be in seconds (small values for ~1 second of audio)
|
||||||
|
max_val = mx.max(positions).item()
|
||||||
|
assert max_val < 10, f"Audio positions seem too large: {max_val}"
|
||||||
|
assert max_val > 0, "Audio positions should be positive"
|
||||||
|
|
||||||
|
def test_audio_position_grid_no_nan_or_inf(self):
|
||||||
|
"""Audio position grid should not contain NaN or Inf."""
|
||||||
|
positions = create_audio_position_grid(1, 34)
|
||||||
|
|
||||||
|
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
|
||||||
|
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeAudioFrames:
|
||||||
|
"""Tests for audio frame count calculation."""
|
||||||
|
|
||||||
|
def test_audio_frames_basic(self):
|
||||||
|
"""Audio frames should be proportional to video duration."""
|
||||||
|
# 33 frames at 24 fps = ~1.375 seconds
|
||||||
|
# At 25 latent frames/sec = ~34 audio frames
|
||||||
|
audio_frames = compute_audio_frames(33, 24.0)
|
||||||
|
assert audio_frames > 0
|
||||||
|
assert isinstance(audio_frames, int)
|
||||||
|
|
||||||
|
def test_audio_frames_scales_with_video(self):
|
||||||
|
"""More video frames should produce more audio frames."""
|
||||||
|
audio_33 = compute_audio_frames(33, 24.0)
|
||||||
|
audio_65 = compute_audio_frames(65, 24.0)
|
||||||
|
|
||||||
|
assert audio_65 > audio_33, \
|
||||||
|
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||||
|
|
||||||
|
def test_audio_frames_formula(self):
|
||||||
|
"""Audio frames should match expected formula."""
|
||||||
|
num_video_frames = 33
|
||||||
|
fps = 24.0
|
||||||
|
|
||||||
|
duration = num_video_frames / fps # ~1.375 seconds
|
||||||
|
expected = round(duration * AUDIO_LATENTS_PER_SECOND)
|
||||||
|
|
||||||
|
actual = compute_audio_frames(num_video_frames, fps)
|
||||||
|
assert actual == expected, f"Expected {expected}, got {actual}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAudioConstants:
|
||||||
|
"""Tests for audio constants."""
|
||||||
|
|
||||||
|
def test_audio_sample_rate(self):
|
||||||
|
"""Audio sample rate should be 24000 Hz."""
|
||||||
|
assert AUDIO_SAMPLE_RATE == 24000
|
||||||
|
|
||||||
|
def test_audio_latents_per_second(self):
|
||||||
|
"""Audio latents per second should be 25."""
|
||||||
|
assert AUDIO_LATENTS_PER_SECOND == 25.0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
Reference in New Issue
Block a user