Remove the audio-video generation pipeline from generate_av.py and integrate audio capabilities into generate.py. This includes adding audio position grid creation, audio frame computation, and updating the denoising function to handle audio latents. Enhance the command-line interface to support audio generation options and update the model configuration accordingly.

This commit is contained in:
Prince Canuma
2026-01-19 01:28:53 +01:00
parent 749762a0b9
commit cae11291a9
2 changed files with 377 additions and 856 deletions

View File

@@ -1,14 +1,15 @@
import argparse import argparse
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
# ANSI color codes # ANSI color codes
class Colors: class Colors:
CYAN = "\033[96m" CYAN = "\033[96m"
@@ -21,25 +22,33 @@ class Colors:
DIM = "\033[2m" DIM = "\033[2m"
RESET = "\033[0m" RESET = "\033[0m"
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder_weights from mlx_video.convert import sanitize_transformer_weights
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
from mlx_video.utils import get_model_path
# Distilled sigma schedules # Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0] STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
# Audio constants
AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate
AUDIO_HOP_LENGTH = 160
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying
AUDIO_MEL_BINS = 16
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
def create_position_grid( def create_position_grid(
batch_size: int, batch_size: int,
@@ -115,6 +124,43 @@ def create_position_grid(
return mx.array(pixel_coords, dtype=mx.float32) return mx.array(pixel_coords, dtype=mx.float32)
def create_audio_position_grid(
batch_size: int,
audio_frames: int,
sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
hop_length: int = AUDIO_HOP_LENGTH,
downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
is_causal: bool = True,
) -> mx.array:
"""Create temporal position grid for audio RoPE.
Audio positions are timestamps in seconds, shape (B, 1, T, 2).
Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly.
"""
def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray:
"""Convert latent indices to seconds."""
latent_frame = np.arange(start_idx, end_idx, dtype=np.float32)
mel_frame = latent_frame * downsample_factor
if is_causal:
mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None)
return mel_frame * hop_length / sample_rate
start_times = get_audio_latent_time_in_sec(0, audio_frames)
end_times = get_audio_latent_time_in_sec(1, audio_frames + 1)
positions = np.stack([start_times, end_times], axis=-1)
positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2)
positions = np.tile(positions, (batch_size, 1, 1, 1))
return mx.array(positions, dtype=mx.float32)
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
"""Compute number of audio latent frames given video duration."""
duration = num_video_frames / fps
return round(duration * AUDIO_LATENTS_PER_SECOND)
def denoise( def denoise(
latents: mx.array, latents: mx.array,
positions: mx.array, positions: mx.array,
@@ -123,27 +169,37 @@ def denoise(
sigmas: list, sigmas: list,
verbose: bool = True, verbose: bool = True,
state: Optional[LatentState] = None, state: Optional[LatentState] = None,
) -> mx.array: # Audio parameters (optional)
"""Run denoising loop with optional conditioning. audio_latents: Optional[mx.array] = None,
audio_positions: Optional[mx.array] = None,
audio_embeddings: Optional[mx.array] = None,
) -> tuple[mx.array, Optional[mx.array]]:
"""Run denoising loop with optional conditioning and optional audio.
Args: Args:
latents: Noisy latent tensor (B, C, F, H, W) latents: Noisy video latent tensor (B, C, F, H, W)
positions: Position embeddings positions: Video position embeddings
text_embeddings: Text conditioning embeddings text_embeddings: Video text conditioning embeddings
transformer: LTX model transformer: LTX model
sigmas: List of sigma values for denoising schedule sigmas: List of sigma values for denoising schedule
verbose: Whether to show progress bar verbose: Whether to show progress bar
state: Optional LatentState for I2V conditioning state: Optional LatentState for I2V conditioning
audio_latents: Optional audio latent tensor (B, C, T, F) for audio generation
audio_positions: Optional audio position embeddings
audio_embeddings: Optional audio text embeddings
Returns: Returns:
Denoised latent tensor Tuple of (video_latents, audio_latents) - audio_latents is None if audio disabled
""" """
# If state is provided, use its latent (which may have conditioning applied)
dtype = latents.dtype dtype = latents.dtype
enable_audio = audio_latents is not None
# If state is provided, use its latent (which may have conditioning applied)
if state is not None: if state is not None:
latents = state.latent latents = state.latent
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose): desc = "Denoising A/V" if enable_audio else "Denoising"
for i in tqdm(range(len(sigmas) - 1), desc=desc, disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1] sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape b, c, f, h, w = latents.shape
@@ -172,28 +228,163 @@ def denoise(
enabled=True, enabled=True,
) )
velocity, _ = transformer(video=video_modality, audio=None) # Prepare audio modality if enabled
audio_modality = None
if enable_audio:
ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
audio_modality = Modality(
latent=audio_flat,
timesteps=mx.full((ab, at), sigma, dtype=dtype),
positions=audio_positions,
context=audio_embeddings,
context_mask=None,
enabled=True,
)
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
mx.eval(velocity) mx.eval(velocity)
if audio_velocity is not None:
mx.eval(audio_velocity)
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma) denoised = to_denoised(latents, velocity, sigma)
# Handle audio velocity if enabled
audio_denoised = None
if enable_audio and audio_velocity is not None:
ab, ac, at, af = audio_latents.shape
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
# Apply conditioning mask if state is provided # Apply conditioning mask if state is provided
if state is not None: if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
mx.eval(denoised) mx.eval(denoised)
if audio_denoised is not None:
mx.eval(audio_denoised)
# Euler step (preserve dtype by converting Python floats to arrays) # Euler step (preserve dtype by converting Python floats to arrays)
if sigma_next > 0: if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype) sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype) sigma_arr = mx.array(sigma, dtype=dtype)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
if enable_audio and audio_denoised is not None:
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else: else:
latents = denoised latents = denoised
mx.eval(latents) if enable_audio and audio_denoised is not None:
audio_latents = audio_denoised
return latents mx.eval(latents)
if enable_audio:
mx.eval(audio_latents)
return latents, audio_latents if enable_audio else None
def load_audio_decoder(model_path: Path):
"""Load audio VAE decoder."""
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
from mlx_video.convert import sanitize_audio_vae_weights
decoder = AudioDecoder(
ch=128,
out_ch=2, # stereo
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_resolutions=set(),
resolution=256,
z_channels=AUDIO_LATENT_CHANNELS,
norm_type=NormType.PIXEL,
causality_axis=CausalityAxis.HEIGHT,
mel_bins=64,
)
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_audio_vae_weights(raw_weights)
if sanitized:
decoder.load_weights(list(sanitized.items()), strict=False)
if "per_channel_statistics._mean_of_means" in sanitized:
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
if "per_channel_statistics._std_of_means" in sanitized:
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
return decoder
def load_vocoder(model_path: Path):
"""Load vocoder for mel to waveform conversion."""
from mlx_video.models.ltx.audio_vae import Vocoder
from mlx_video.convert import sanitize_vocoder_weights
vocoder = Vocoder(
resblock_kernel_sizes=[3, 7, 11],
upsample_rates=[6, 5, 2, 2, 2],
upsample_kernel_sizes=[16, 15, 8, 4, 4],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_initial_channel=1024,
stereo=True,
output_sample_rate=AUDIO_SAMPLE_RATE,
)
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_vocoder_weights(raw_weights)
if sanitized:
vocoder.load_weights(list(sanitized.items()), strict=False)
return vocoder
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
"""Save audio to WAV file."""
import wave
if audio.ndim == 2:
audio = audio.T # (channels, samples) -> (samples, channels)
audio = np.clip(audio, -1.0, 1.0)
audio_int16 = (audio * 32767).astype(np.int16)
with wave.open(str(path), 'wb') as wf:
wf.setnchannels(2 if audio_int16.ndim == 2 else 1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio_int16.tobytes())
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
"""Combine video and audio into final output using ffmpeg."""
import subprocess
cmd = [
"ffmpeg", "-y",
"-i", str(video_path),
"-i", str(audio_path),
"-c:v", "copy",
"-c:a", "aac",
"-shortest",
str(output_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
return False
except FileNotFoundError:
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
return False
def generate_video( def generate_video(
@@ -216,8 +407,11 @@ def generate_video(
image_frame_idx: int = 0, image_frame_idx: int = 0,
tiling: str = "auto", tiling: str = "auto",
stream: bool = False, stream: bool = False,
# Audio options
audio: bool = False,
output_audio_path: Optional[str] = None,
): ):
"""Generate video from text prompt, optionally conditioned on an image. """Generate video from text prompt, optionally conditioned on an image and with audio.
Args: Args:
model_repo: Model repository ID model_repo: Model repository ID
@@ -245,26 +439,37 @@ def generate_video(
- "conservative": 768px spatial, 96 frame temporal (faster) - "conservative": 768px spatial, 96 frame temporal (faster)
- "spatial": Spatial tiling only - "spatial": Spatial tiling only
- "temporal": Temporal tiling only - "temporal": Temporal tiling only
stream: Stream frames to output as they're decoded (requires tiling)
audio: Enable synchronized audio generation
output_audio_path: Path to save audio file (default: same as video with .wav)
""" """
start_time = time.time() start_time = time.time()
# Validate dimensions # Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}" assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}" assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
if num_frames % 8 != 1: if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}") print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames num_frames = adjusted_num_frames
is_i2v = image is not None is_i2v = image is not None
mode_str = "I2V" if is_i2v else "T2V" mode_str = "I2V" if is_i2v else "T2V"
if audio:
mode_str += "+Audio"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v: if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
# Calculate audio frames if enabled
audio_frames = None
if audio:
audio_frames = compute_audio_frames(num_frames, fps)
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
# Get model path # Get model path
model_path = get_model_path(model_repo) model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
@@ -289,22 +494,32 @@ def generate_video(
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) # Get embeddings - with audio if enabled
if audio:
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
mx.eval(text_embeddings, audio_embeddings)
else:
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
audio_embeddings = None
mx.eval(text_embeddings)
model_dtype = text_embeddings.dtype # bfloat16 from text encoder model_dtype = text_embeddings.dtype # bfloat16 from text encoder
mx.eval(text_embeddings)
del text_encoder del text_encoder
mx.clear_cache() mx.clear_cache()
# Load transformer # Load transformer
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}") print(f"{Colors.BLUE}🤖 Loading transformer{' (A/V mode)' if audio else ''}...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency # Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig( # Configure model type based on audio flag
model_type=LTXModelType.VideoOnly, model_type = LTXModelType.AudioVideo if audio else LTXModelType.VideoOnly
config_kwargs = dict(
model_type=model_type,
num_attention_heads=32, num_attention_heads=32,
attention_head_dim=128, attention_head_dim=128,
in_channels=128, in_channels=128,
@@ -320,7 +535,19 @@ def generate_video(
timestep_scale_multiplier=1000, timestep_scale_multiplier=1000,
) )
transformer = LTXModel(config) if audio:
config_kwargs.update(
audio_num_attention_heads=32,
audio_attention_head_dim=64,
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_cross_attention_dim=2048,
audio_positional_embedding_max_pos=[20],
)
config = LTXModelConfig(**config_kwargs)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False) transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters()) mx.eval(transformer.parameters())
@@ -357,6 +584,14 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
# Create audio positions if enabled
audio_positions = None
audio_latents = None
if audio:
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning if provided # Apply I2V conditioning if provided
state1 = None state1 = None
if is_i2v and stage1_image_latent is not None: if is_i2v and stage1_image_latent is not None:
@@ -394,7 +629,11 @@ def generate_video(
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype) latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) latents, audio_latents = denoise(
latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS,
verbose=verbose, state=state1,
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
)
# Upsample latents # Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
@@ -447,6 +686,13 @@ def generate_video(
) )
latents = state2.latent latents = state2.latent
mx.eval(latents) mx.eval(latents)
# Audio also gets noise for stage 2 if enabled
if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
else: else:
# T2V: add noise to all frames for refinement # T2V: add noise to all frames for refinement
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
@@ -455,7 +701,17 @@ def generate_video(
latents = noise * noise_scale + latents * one_minus_scale latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) # Audio also gets noise for stage 2 if enabled
if audio and audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
latents, audio_latents = denoise(
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
verbose=verbose, state=state2,
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
)
del transformer del transformer
mx.clear_cache() mx.clear_cache()
@@ -496,7 +752,7 @@ def generate_video(
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height)) video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame") stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
def on_frames_ready(frames: mx.array, start_idx: int): def on_frames_ready(frames: mx.array, _start_idx: int):
"""Callback to write frames as they're finalized.""" """Callback to write frames as they're finalized."""
# frames: (B, 3, num_frames, H, W) # frames: (B, 3, num_frames, H, W)
frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W) frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W)
@@ -542,19 +798,66 @@ def generate_video(
video = (video * 255).astype(mx.uint8) video = (video * 255).astype(mx.uint8)
video_np = np.array(video) video_np = np.array(video)
# Save video normally # For audio mode, save to temp file first
if audio:
temp_video_path = output_path.with_suffix('.temp.mp4')
save_path = temp_video_path
else:
save_path = output_path
# Save video
try: try:
import cv2 import cv2
h, w = video_np.shape[1], video_np.shape[2] h, w = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1') fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h)) out = cv2.VideoWriter(str(save_path), fourcc, fps, (w, h))
for frame in video_np: for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release() out.release()
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}") if not audio:
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
# Decode and save audio if enabled
audio_np = None
if audio and audio_latents is not None:
print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}")
audio_decoder = load_audio_decoder(model_path)
vocoder = load_vocoder(model_path)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
audio_np = np.array(audio_waveform)
if audio_np.ndim == 3:
audio_np = audio_np[0]
del audio_decoder, vocoder
mx.clear_cache()
# Save audio
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}")
# Mux video and audio
print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}")
temp_video_path = output_path.with_suffix('.temp.mp4')
if mux_video_audio(temp_video_path, audio_path, output_path):
print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}")
temp_video_path.unlink()
else:
temp_video_path.rename(output_path)
print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}")
del vae_decoder
mx.clear_cache()
if save_frames: if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames" frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True) frames_dir.mkdir(exist_ok=True)
@@ -566,12 +869,14 @@ def generate_video(
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
if audio:
return video_np, audio_np
return video_np return video_np
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2 (T2V and I2V)", description="Generate videos with MLX LTX-2 (T2V, I2V, and Audio)",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
Examples: Examples:
@@ -583,6 +888,11 @@ Examples:
# Image-to-Video (I2V) # Image-to-Video (I2V)
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8 python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8
# With Audio (T2V+Audio or I2V+Audio)
python -m mlx_video.generate --prompt "Ocean waves crashing" --audio
python -m mlx_video.generate --prompt "A jazz band playing" --audio --enhance-prompt
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --audio
""" """
) )
@@ -623,7 +933,7 @@ Examples:
help="Frames per second for output video (default: 24)" help="Frames per second for output video (default: 24)"
) )
parser.add_argument( parser.add_argument(
"--output-path", "--output-path", "-o",
type=str, type=str,
default="output.mp4", default="output.mp4",
help="Output video path (default: output.mp4)" help="Output video path (default: output.mp4)"
@@ -699,10 +1009,42 @@ Examples:
action="store_true", action="store_true",
help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner." help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner."
) )
# Audio options
parser.add_argument(
"--audio", "-a",
action="store_true",
help="Enable synchronized audio generation"
)
parser.add_argument(
"--output-audio",
type=str,
default=None,
help="Output audio path (default: same as video with .wav)"
)
args = parser.parse_args() args = parser.parse_args()
generate_video( generate_video(
**vars(args) model_repo=args.model_repo,
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output_path,
save_frames=args.save_frames,
verbose=args.verbose,
enhance_prompt=args.enhance_prompt,
max_tokens=args.max_tokens,
temperature=args.temperature,
image=args.image,
image_strength=args.image_strength,
image_frame_idx=args.image_frame_idx,
tiling=args.tiling,
stream=args.stream,
audio=args.audio,
output_audio_path=args.output_audio,
) )

View File

@@ -1,821 +0,0 @@
"""Audio-Video generation pipeline for LTX-2."""
import argparse
import time
from pathlib import Path
from typing import Optional
import mlx.core as mx
import numpy as np
from tqdm import tqdm
# ANSI color codes
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
# Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
# Audio constants
AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate
AUDIO_HOP_LENGTH = 160
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying
AUDIO_MEL_BINS = 16
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
def create_video_position_grid(
batch_size: int,
num_frames: int,
height: int,
width: int,
temporal_scale: int = 8,
spatial_scale: int = 32,
fps: float = 24.0,
causal_fix: bool = True,
) -> mx.array:
"""Create position grid for video RoPE in pixel space."""
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
t_coords = np.arange(0, num_frames, patch_size_t)
h_coords = np.arange(0, height, patch_size_h)
w_coords = np.arange(0, width, patch_size_w)
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
patch_ends = patch_starts + patch_size_delta
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
num_patches = num_frames * height * width
latent_coords = latent_coords.reshape(3, num_patches, 2)
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
if causal_fix:
pixel_coords[:, 0, :, :] = np.clip(
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
a_min=0,
a_max=None
)
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
return mx.array(pixel_coords, dtype=mx.float32)
def create_audio_position_grid(
batch_size: int,
audio_frames: int,
sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
hop_length: int = AUDIO_HOP_LENGTH,
downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
is_causal: bool = True,
) -> mx.array:
"""Create temporal position grid for audio RoPE.
Audio positions are timestamps in seconds, shape (B, 1, T, 2).
Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly.
"""
def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray:
"""Convert latent indices to seconds (matching PyTorch's _get_audio_latent_time_in_sec)."""
latent_frame = np.arange(start_idx, end_idx, dtype=np.float32)
mel_frame = latent_frame * downsample_factor
if is_causal:
# Frame offset for causal alignment (PyTorch uses +1 - downsample_factor)
mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None)
return mel_frame * hop_length / sample_rate
# Start times: latent indices 0 to audio_frames
start_times = get_audio_latent_time_in_sec(0, audio_frames)
# End times: latent indices 1 to audio_frames+1 (shifted by 1)
end_times = get_audio_latent_time_in_sec(1, audio_frames + 1)
# Shape: (B, 1, T, 2)
positions = np.stack([start_times, end_times], axis=-1)
positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2)
positions = np.tile(positions, (batch_size, 1, 1, 1))
return mx.array(positions, dtype=mx.float32)
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
"""Compute number of audio latent frames given video duration."""
duration = num_video_frames / fps
return round(duration * AUDIO_LATENTS_PER_SECOND)
def denoise_av(
video_latents: mx.array,
audio_latents: mx.array,
video_positions: mx.array,
audio_positions: mx.array,
video_embeddings: mx.array,
audio_embeddings: mx.array,
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
video_state: Optional[LatentState] = None,
) -> tuple[mx.array, mx.array]:
"""Run denoising loop for audio-video generation with optional I2V conditioning.
Args:
video_latents: Video latent tensor (B, C, F, H, W)
audio_latents: Audio latent tensor (B, C, T, F)
video_positions: Video position embeddings
audio_positions: Audio position embeddings
video_embeddings: Video text embeddings
audio_embeddings: Audio text embeddings
transformer: LTX model
sigmas: List of sigma values
verbose: Whether to show progress bar
video_state: Optional LatentState for I2V conditioning
Returns:
Tuple of (video_latents, audio_latents)
"""
dtype = video_latents.dtype
# If video state is provided, use its latent
if video_state is not None:
video_latents = video_state.latent
for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
# Flatten video latents
b, c, f, h, w = video_latents.shape
num_video_tokens = f * h * w
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
# Compute per-token timesteps for video
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
if video_state is not None:
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
# Per-token timesteps: sigma * mask
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
# All tokens get the same timestep
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
video_modality = Modality(
latent=video_flat,
timesteps=video_timesteps,
positions=video_positions,
context=video_embeddings,
context_mask=None,
enabled=True,
)
audio_modality = Modality(
latent=audio_flat,
timesteps=mx.full((ab, at), sigma, dtype=dtype),
positions=audio_positions,
context=audio_embeddings,
context_mask=None,
enabled=True,
)
video_velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
mx.eval(video_velocity, audio_velocity)
# Reshape velocities back
video_velocity = mx.reshape(mx.transpose(video_velocity, (0, 2, 1)), (b, c, f, h, w))
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
# Compute denoised
video_denoised = to_denoised(video_latents, video_velocity, sigma)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
# Apply conditioning mask for video if state is provided
if video_state is not None:
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
mx.eval(video_denoised, audio_denoised)
# Euler step - use dtype-preserving arrays to avoid float32 promotion
if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else:
video_latents = video_denoised
audio_latents = audio_denoised
mx.eval(video_latents, audio_latents)
return video_latents, audio_latents
def load_audio_decoder(model_path: Path):
"""Load audio VAE decoder."""
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
decoder = AudioDecoder(
ch=128,
out_ch=2, # stereo
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_resolutions=set(), # PyTorch uses empty set (no attention in audio decoder)
resolution=256,
z_channels=AUDIO_LATENT_CHANNELS,
norm_type=NormType.PIXEL,
causality_axis=CausalityAxis.HEIGHT,
mel_bins=64, # Output mel bins
)
# Load weights from main model file
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_audio_vae_weights(raw_weights)
if sanitized:
decoder.load_weights(list(sanitized.items()), strict=False)
# Manually load per-channel statistics (they're plain mx.array, not tracked by load_weights)
if "per_channel_statistics._mean_of_means" in sanitized:
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
if "per_channel_statistics._std_of_means" in sanitized:
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
return decoder
def load_vocoder(model_path: Path):
"""Load vocoder for mel to waveform conversion."""
from mlx_video.models.ltx.audio_vae import Vocoder
vocoder = Vocoder(
resblock_kernel_sizes=[3, 7, 11],
upsample_rates=[6, 5, 2, 2, 2],
upsample_kernel_sizes=[16, 15, 8, 4, 4],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_initial_channel=1024,
stereo=True,
output_sample_rate=AUDIO_SAMPLE_RATE,
)
# Load weights
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_vocoder_weights(raw_weights)
if sanitized:
vocoder.load_weights(list(sanitized.items()), strict=False)
return vocoder
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
"""Save audio to WAV file."""
import wave
# Ensure audio is in correct format (channels, samples) or (samples,)
if audio.ndim == 2:
# (channels, samples) -> (samples, channels)
audio = audio.T
# Normalize and convert to int16
audio = np.clip(audio, -1.0, 1.0)
audio_int16 = (audio * 32767).astype(np.int16)
with wave.open(str(path), 'wb') as wf:
wf.setnchannels(2 if audio_int16.ndim == 2 else 1)
wf.setsampwidth(2) # 16-bit
wf.setframerate(sample_rate)
wf.writeframes(audio_int16.tobytes())
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
"""Combine video and audio into final output using ffmpeg."""
import subprocess
cmd = [
"ffmpeg", "-y",
"-i", str(video_path),
"-i", str(audio_path),
"-c:v", "copy",
"-c:a", "aac",
"-shortest",
str(output_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
return False
except FileNotFoundError:
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
return False
def generate_video_with_audio(
model_repo: str,
text_encoder_repo: Optional[str],
prompt: str,
height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output_av.mp4",
output_audio_path: Optional[str] = None,
verbose: bool = True,
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
tiling: str = "auto",
):
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
Args:
model_repo: Model repository ID
text_encoder_repo: Text encoder repository ID
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames
seed: Random seed
fps: Frames per second
output_path: Output video path
output_audio_path: Output audio path
verbose: Whether to print progress
enhance_prompt: Whether to enhance prompt using Gemma
max_tokens: Max tokens for prompt enhancement
temperature: Temperature for prompt enhancement
image: Path to conditioning image for I2V
image_strength: Conditioning strength (1.0 = full denoise)
image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal)
"""
start_time = time.time()
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Adjusted frames to {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames
# Calculate audio frames
audio_frames = compute_audio_frames(num_frames, fps)
is_i2v = image is not None
mode_str = "I2V+Audio" if is_i2v else "T2V+Audio"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
mx.random.seed(seed)
# Load text encoder with audio embeddings
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder()
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
# Optionally enhance prompt
if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
# Get both video and audio embeddings
video_embeddings, audio_embeddings = text_encoder(prompt)
model_dtype = video_embeddings.dtype # bfloat16 from text encoder
mx.eval(video_embeddings, audio_embeddings)
del text_encoder
mx.clear_cache()
# Load transformer with AudioVideo config
print(f"{Colors.BLUE}🤖 Loading transformer (A/V mode)...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig(
model_type=LTXModelType.AudioVideo,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
# Audio config
audio_num_attention_heads=32,
audio_attention_head_dim=64,
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_cross_attention_dim=2048,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
audio_positional_embedding_max_pos=[20],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
# Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
del vae_encoder
mx.clear_cache()
# Initialize latents
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
# Create position grids - MUST stay float32 for RoPE precision
# bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32
audio_positions = create_audio_position_grid(1, audio_frames) # float32
mx.eval(video_positions, audio_positions)
# Apply I2V conditioning for stage 1 if provided
video_state1 = None
video_latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
if is_i2v and stage1_image_latent is not None:
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
video_state1 = LatentState(
latent=mx.zeros(video_latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
video_state1 = apply_conditioning(video_state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
noise = mx.random.normal(video_latent_shape).astype(model_dtype)
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = video_state1.denoise_mask * noise_scale
video_state1 = LatentState(
latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state1.clean_latent,
denoise_mask=video_state1.denoise_mask,
)
video_latents = video_state1.latent
mx.eval(video_latents)
else:
# T2V: just use random noise
video_latents = mx.random.normal(video_latent_shape).astype(model_dtype)
mx.eval(video_latents)
# Audio always uses pure noise (no I2V for audio)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_latents)
# Stage 1 denoising
video_latents, audio_latents = denoise_av(
video_latents, audio_latents,
video_positions, audio_positions,
video_embeddings, audio_embeddings,
transformer, STAGE_1_SIGMAS, verbose=verbose,
video_state=video_state1
)
# Upsample video latents
print(f"{Colors.MAGENTA}🔍 Upsampling video latents 2x...{Colors.RESET}")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=None # Auto-detect from model metadata
)
video_latents = upsample_latents(video_latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(video_latents)
del upsampler
mx.clear_cache()
# Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
# Position grids stay float32 for RoPE precision
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32
mx.eval(video_positions)
# Apply I2V conditioning for stage 2 if provided
video_state2 = None
if is_i2v and stage2_image_latent is not None:
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
video_state2 = LatentState(
latent=video_latents, # Start with upscaled latent
clean_latent=mx.zeros_like(video_latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
video_state2 = apply_conditioning(video_state2, [conditioning])
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = video_state2.denoise_mask * noise_scale
video_state2 = LatentState(
latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state2.clean_latent,
denoise_mask=video_state2.denoise_mask,
)
video_latents = video_state2.latent
mx.eval(video_latents)
# Audio still gets noise (no I2V for audio)
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
else:
# T2V: add noise to all frames for refinement
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
video_latents = video_noise * noise_scale + video_latents * one_minus_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(video_latents, audio_latents)
video_latents, audio_latents = denoise_av(
video_latents, audio_latents,
video_positions, audio_positions,
video_embeddings, audio_embeddings,
transformer, STAGE_2_SIGMAS, verbose=verbose,
video_state=video_state2
)
del transformer
mx.clear_cache()
# Decode video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
# Select tiling configuration
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(video_latents)
mx.eval(video)
# Convert video to uint8 frames
video = mx.squeeze(video, axis=0)
video = mx.transpose(video, (1, 2, 3, 0))
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Decode audio
print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}")
audio_decoder = load_audio_decoder(model_path)
vocoder = load_vocoder(model_path)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
# Audio decoder output is already in vocoder format (B, C, T, F)
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
audio_np = np.array(audio_waveform)
if audio_np.ndim == 3:
audio_np = audio_np[0] # Remove batch dim
del audio_decoder, vocoder, vae_decoder
mx.clear_cache()
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Save video (temporary without audio)
temp_video_path = output_path.with_suffix('.temp.mp4')
try:
import cv2
h, w = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(temp_video_path), fourcc, fps, (w, h))
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"{Colors.GREEN}✅ Video encoded{Colors.RESET}")
except Exception as e:
print(f"{Colors.RED}❌ Video encoding failed: {e}{Colors.RESET}")
return None, None
# Save audio
audio_path = output_path.with_suffix('.wav') if output_audio_path is None else Path(output_audio_path)
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}")
# Mux video and audio
print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}")
if mux_video_audio(temp_video_path, audio_path, output_path):
print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}")
temp_video_path.unlink() # Remove temp file
else:
# Fallback: keep video without audio
temp_video_path.rename(output_path)
print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}")
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np, audio_np
def main():
parser = argparse.ArgumentParser(
description="Generate videos with synchronized audio using MLX LTX-2 (T2V and I2V)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video with Audio (T2V+Audio)
python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach"
python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt
python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav
# Image-to-Video with Audio (I2V+Audio)
python -m mlx_video.generate_av --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate_av --prompt "Waves crashing" --image beach.png --image-strength 0.8
"""
)
parser.add_argument("--prompt", "-p", type=str, required=True,
help="Text description of the video/audio to generate")
parser.add_argument("--height", "-H", type=int, default=512,
help="Output video height (default: 512)")
parser.add_argument("--width", "-W", type=int, default=512,
help="Output video width (default: 512)")
parser.add_argument("--num-frames", "-n", type=int, default=65,
help="Number of frames (default: 65)")
parser.add_argument("--seed", "-s", type=int, default=42,
help="Random seed (default: 42)")
parser.add_argument("--fps", type=int, default=24,
help="Frames per second (default: 24)")
parser.add_argument("--output-path", type=str, default="output_av.mp4",
help="Output video path (default: output_av.mp4)")
parser.add_argument("--output-audio", type=str, default=None,
help="Output audio path (default: same as video with .wav)")
parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2",
help="Model repository (default: Lightricks/LTX-2)")
parser.add_argument("--text-encoder-repo", type=str, default=None,
help="Text encoder repository")
parser.add_argument("--verbose", action="store_true",
help="Verbose output")
parser.add_argument("--enhance-prompt", action="store_true",
help="Enhance prompt using Gemma")
parser.add_argument("--max-tokens", type=int, default=512,
help="Max tokens for prompt enhancement")
parser.add_argument("--temperature", type=float, default=0.7,
help="Temperature for prompt enhancement")
parser.add_argument("--image", "-i", type=str, default=None,
help="Path to conditioning image for I2V (Image-to-Video) generation")
parser.add_argument("--image-strength", type=float, default=1.0,
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
parser.add_argument("--image-frame-idx", type=int, default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)")
parser.add_argument("--tiling", type=str, default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: auto). "
"auto=based on size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f")
args = parser.parse_args()
generate_video_with_audio(
model_repo=args.model_repo,
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output_path,
output_audio_path=args.output_audio,
verbose=args.verbose,
enhance_prompt=args.enhance_prompt,
max_tokens=args.max_tokens,
temperature=args.temperature,
image=args.image,
image_strength=args.image_strength,
image_frame_idx=args.image_frame_idx,
tiling=args.tiling,
)
if __name__ == "__main__":
main()