From a658911f985e435cc66846a5b853368e723a52ae Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 16 Jan 2026 01:15:22 +0100 Subject: [PATCH 1/4] add audio --- mlx_video/convert.py | 155 +++++ mlx_video/generate.py | 2 +- mlx_video/generate_av.py | 612 ++++++++++++++++++ mlx_video/models/ltx/__init__.py | 1 + mlx_video/models/ltx/audio_vae/__init__.py | 41 ++ mlx_video/models/ltx/audio_vae/attention.py | 108 ++++ mlx_video/models/ltx/audio_vae/audio_vae.py | 326 ++++++++++ .../models/ltx/audio_vae/causal_conv_2d.py | 146 +++++ .../models/ltx/audio_vae/causality_axis.py | 12 + mlx_video/models/ltx/audio_vae/downsample.py | 127 ++++ .../models/ltx/audio_vae/normalization.py | 59 ++ mlx_video/models/ltx/audio_vae/ops.py | 98 +++ mlx_video/models/ltx/audio_vae/resnet.py | 185 ++++++ mlx_video/models/ltx/audio_vae/upsample.py | 135 ++++ mlx_video/models/ltx/audio_vae/vocoder.py | 142 ++++ mlx_video/models/ltx/config.py | 1 + mlx_video/models/ltx/ltx.py | 7 +- mlx_video/models/ltx/text_encoder.py | 225 +++++-- mlx_video/models/ltx/video_vae/ops.py | 7 +- 19 files changed, 2335 insertions(+), 54 deletions(-) create mode 100644 mlx_video/generate_av.py create mode 100644 mlx_video/models/ltx/audio_vae/__init__.py create mode 100644 mlx_video/models/ltx/audio_vae/attention.py create mode 100644 mlx_video/models/ltx/audio_vae/audio_vae.py create mode 100644 mlx_video/models/ltx/audio_vae/causal_conv_2d.py create mode 100644 mlx_video/models/ltx/audio_vae/causality_axis.py create mode 100644 mlx_video/models/ltx/audio_vae/downsample.py create mode 100644 mlx_video/models/ltx/audio_vae/normalization.py create mode 100644 mlx_video/models/ltx/audio_vae/ops.py create mode 100644 mlx_video/models/ltx/audio_vae/resnet.py create mode 100644 mlx_video/models/ltx/audio_vae/upsample.py create mode 100644 mlx_video/models/ltx/audio_vae/vocoder.py diff --git a/mlx_video/convert.py b/mlx_video/convert.py index 0204d76..1e08e5c 100644 --- a/mlx_video/convert.py +++ b/mlx_video/convert.py @@ -109,6 +109,84 @@ def load_vae_weights(model_path: Path) -> Dict[str, mx.array]: raise FileNotFoundError(f"VAE weights not found at {vae_path}") +def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]: + """Load audio VAE weights from LTX-2 model. + + Args: + model_path: Path to LTX-2 model directory + + Returns: + Dictionary of audio VAE weights + """ + # Try different possible paths for audio VAE weights + audio_vae_paths = [ + model_path / "audio_vae" / "diffusion_pytorch_model.safetensors", + model_path / "audio_vae.safetensors", + ] + + # Also check in main model weights + main_paths = [ + model_path / "ltx-2-19b-distilled.safetensors", + model_path / "ltx-2-19b-dev.safetensors", + ] + + for audio_path in audio_vae_paths: + if audio_path.exists(): + print(f"Loading audio VAE weights from {audio_path}...") + return mx.load(str(audio_path)) + + # Check main model weights for audio_vae keys + for main_path in main_paths: + if main_path.exists(): + print(f"Loading audio VAE weights from {main_path.name}...") + all_weights = mx.load(str(main_path)) + # Filter to only audio_vae keys + audio_weights = {k: v for k, v in all_weights.items() if "audio_vae" in k} + if audio_weights: + return audio_weights + + raise FileNotFoundError(f"Audio VAE weights not found in {model_path}") + + +def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]: + """Load vocoder weights from LTX-2 model. + + Args: + model_path: Path to LTX-2 model directory + + Returns: + Dictionary of vocoder weights + """ + # Try different possible paths for vocoder weights + vocoder_paths = [ + model_path / "vocoder" / "diffusion_pytorch_model.safetensors", + model_path / "vocoder.safetensors", + ] + + # Also check in main model weights + main_paths = [ + model_path / "ltx-2-19b-distilled.safetensors", + model_path / "ltx-2-19b-dev.safetensors", + ] + + for vocoder_path in vocoder_paths: + if vocoder_path.exists(): + print(f"Loading vocoder weights from {vocoder_path}...") + return mx.load(str(vocoder_path)) + + # Check main model weights for vocoder keys + for main_path in main_paths: + if main_path.exists(): + print(f"Loading vocoder weights from {main_path.name}...") + all_weights = mx.load(str(main_path)) + # Filter to only vocoder keys + vocoder_weights = {k: v for k, v in all_weights.items() if "vocoder" in k} + if vocoder_weights: + return vocoder_weights + + raise FileNotFoundError(f"Vocoder weights not found in {model_path}") + + def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize transformer weight names from PyTorch LTX-2 format to MLX format. @@ -213,6 +291,83 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: return sanitized +def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Sanitize audio VAE weight names from PyTorch format to MLX format. + + Args: + weights: Dictionary of weights with PyTorch naming + + Returns: + Dictionary with MLX-compatible naming for audio VAE decoder + """ + sanitized = {} + + for key, value in weights.items(): + new_key = key + + # Handle audio_vae.decoder weights + if key.startswith("audio_vae.decoder."): + new_key = key.replace("audio_vae.decoder.", "") + elif key.startswith("audio_vae.per_channel_statistics."): + # Map per-channel statistics + if "mean-of-means" in key: + new_key = "per_channel_statistics._mean_of_means" + elif "std-of-means" in key: + new_key = "per_channel_statistics._std_of_means" + else: + continue # Skip other statistics keys + else: + continue # Skip non-decoder keys + + # Handle Conv2d weight shape conversion + # PyTorch: (out_channels, in_channels, H, W) + # MLX: (out_channels, H, W, in_channels) + if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: + value = mx.transpose(value, (0, 2, 3, 1)) + + sanitized[new_key] = value + + return sanitized + + +def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Sanitize vocoder weight names from PyTorch format to MLX format. + + Args: + weights: Dictionary of weights with PyTorch naming + + Returns: + Dictionary with MLX-compatible naming for vocoder + """ + sanitized = {} + + for key, value in weights.items(): + new_key = key + + # Handle vocoder weights + if key.startswith("vocoder."): + new_key = key.replace("vocoder.", "") + + # Handle ModuleList indices -> dict keys + # PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ... + # PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ... + + # Handle Conv1d weight shape conversion + # PyTorch: (out_channels, in_channels, kernel) + # MLX: (out_channels, kernel, in_channels) + if "weight" in new_key and value.ndim == 3: + if "ups" in new_key: + # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch) + value = mx.transpose(value, (1, 2, 0)) + else: + # Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch) + value = mx.transpose(value, (0, 2, 1)) + + sanitized[new_key] = value + + return sanitized + + def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize weight names from PyTorch format to MLX format. diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 4c78bb3..e39dfeb 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -215,7 +215,7 @@ def generate_video( prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") - text_embeddings, _ = text_encoder(prompt) + text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) mx.eval(text_embeddings) del text_encoder diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py new file mode 100644 index 0000000..b4e488c --- /dev/null +++ b/mlx_video/generate_av.py @@ -0,0 +1,612 @@ +"""Audio-Video generation pipeline for LTX-2.""" + +import argparse +import time +from pathlib import Path +from typing import Optional + +import mlx.core as mx +import numpy as np +from tqdm import tqdm + + +# ANSI color codes +class Colors: + CYAN = "\033[96m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + MAGENTA = "\033[95m" + BOLD = "\033[1m" + DIM = "\033[2m" + RESET = "\033[0m" + + +from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType +from mlx_video.models.ltx.ltx import LTXModel +from mlx_video.models.ltx.transformer import Modality +from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights +from mlx_video.utils import to_denoised, get_model_path +from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder +from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents + + +# Distilled sigma schedules +STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] +STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0] + +# Audio constants +AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate +AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate +AUDIO_HOP_LENGTH = 160 +AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4 +AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying +AUDIO_MEL_BINS = 16 +AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25 + + +def create_video_position_grid( + batch_size: int, + num_frames: int, + height: int, + width: int, + temporal_scale: int = 8, + spatial_scale: int = 32, + fps: float = 24.0, + causal_fix: bool = True, +) -> mx.array: + """Create position grid for video RoPE in pixel space.""" + patch_size_t, patch_size_h, patch_size_w = 1, 1, 1 + + t_coords = np.arange(0, num_frames, patch_size_t) + h_coords = np.arange(0, height, patch_size_h) + w_coords = np.arange(0, width, patch_size_w) + + t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') + patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) + + patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1) + patch_ends = patch_starts + patch_size_delta + + latent_coords = np.stack([patch_starts, patch_ends], axis=-1) + num_patches = num_frames * height * width + latent_coords = latent_coords.reshape(3, num_patches, 2) + latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) + + scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1) + pixel_coords = (latent_coords * scale_factors).astype(np.float32) + + if causal_fix: + pixel_coords[:, 0, :, :] = np.clip( + pixel_coords[:, 0, :, :] + 1 - temporal_scale, + a_min=0, + a_max=None + ) + + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps + + return mx.array(pixel_coords, dtype=mx.float32) + + +def create_audio_position_grid( + batch_size: int, + audio_frames: int, + sample_rate: int = AUDIO_LATENT_SAMPLE_RATE, + hop_length: int = AUDIO_HOP_LENGTH, + downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR, + is_causal: bool = True, +) -> mx.array: + """Create temporal position grid for audio RoPE. + + Audio positions are timestamps in seconds, shape (B, 1, T, 2). + Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly. + """ + def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray: + """Convert latent indices to seconds (matching PyTorch's _get_audio_latent_time_in_sec).""" + latent_frame = np.arange(start_idx, end_idx, dtype=np.float32) + mel_frame = latent_frame * downsample_factor + if is_causal: + # Frame offset for causal alignment (PyTorch uses +1 - downsample_factor) + mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None) + return mel_frame * hop_length / sample_rate + + # Start times: latent indices 0 to audio_frames + start_times = get_audio_latent_time_in_sec(0, audio_frames) + + # End times: latent indices 1 to audio_frames+1 (shifted by 1) + end_times = get_audio_latent_time_in_sec(1, audio_frames + 1) + + # Shape: (B, 1, T, 2) + positions = np.stack([start_times, end_times], axis=-1) + positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2) + positions = np.tile(positions, (batch_size, 1, 1, 1)) + + return mx.array(positions, dtype=mx.float32) + + +def compute_audio_frames(num_video_frames: int, fps: float) -> int: + """Compute number of audio latent frames given video duration.""" + duration = num_video_frames / fps + return round(duration * AUDIO_LATENTS_PER_SECOND) + + +def denoise_av( + video_latents: mx.array, + audio_latents: mx.array, + video_positions: mx.array, + audio_positions: mx.array, + video_embeddings: mx.array, + audio_embeddings: mx.array, + transformer: LTXModel, + sigmas: list, + verbose: bool = True, +) -> tuple[mx.array, mx.array]: + """Run denoising loop for audio-video generation.""" + for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose): + sigma, sigma_next = sigmas[i], sigmas[i + 1] + + # Flatten video latents + b, c, f, h, w = video_latents.shape + video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) + + # Flatten audio latents: (B, C, T, F) -> (B, T, C*F) + ab, ac, at, af = audio_latents.shape + audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + + video_modality = Modality( + latent=video_flat, + timesteps=mx.full((1,), sigma), + positions=video_positions, + context=video_embeddings, + context_mask=None, + enabled=True, + ) + + audio_modality = Modality( + latent=audio_flat, + timesteps=mx.full((1,), sigma), + positions=audio_positions, + context=audio_embeddings, + context_mask=None, + enabled=True, + ) + + video_velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality) + mx.eval(video_velocity, audio_velocity) + + # Reshape velocities back + video_velocity = mx.reshape(mx.transpose(video_velocity, (0, 2, 1)), (b, c, f, h, w)) + audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) + audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F) + + # Compute denoised + video_denoised = to_denoised(video_latents, video_velocity, sigma) + audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + mx.eval(video_denoised, audio_denoised) + + # Euler step + if sigma_next > 0: + video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma + audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma + else: + video_latents = video_denoised + audio_latents = audio_denoised + mx.eval(video_latents, audio_latents) + + return video_latents, audio_latents + + +def load_audio_decoder(model_path: Path): + """Load audio VAE decoder.""" + from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType + + decoder = AudioDecoder( + ch=128, + out_ch=2, # stereo + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_resolutions={8, 16, 32}, + resolution=256, + z_channels=AUDIO_LATENT_CHANNELS, + norm_type=NormType.PIXEL, + causality_axis=CausalityAxis.HEIGHT, + mel_bins=64, # Output mel bins + ) + + # Load weights from main model file + weight_file = model_path / "ltx-2-19b-distilled.safetensors" + if weight_file.exists(): + raw_weights = mx.load(str(weight_file)) + sanitized = sanitize_audio_vae_weights(raw_weights) + if sanitized: + decoder.load_weights(list(sanitized.items()), strict=False) + + # Manually load per-channel statistics (they're plain mx.array, not tracked by load_weights) + if "per_channel_statistics._mean_of_means" in sanitized: + decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"] + if "per_channel_statistics._std_of_means" in sanitized: + decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"] + + return decoder + + +def load_vocoder(model_path: Path): + """Load vocoder for mel to waveform conversion.""" + from mlx_video.models.ltx.audio_vae import Vocoder + + vocoder = Vocoder( + resblock_kernel_sizes=[3, 7, 11], + upsample_rates=[6, 5, 2, 2, 2], + upsample_kernel_sizes=[16, 15, 8, 4, 4], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel=1024, + stereo=True, + output_sample_rate=AUDIO_SAMPLE_RATE, + ) + + # Load weights + weight_file = model_path / "ltx-2-19b-distilled.safetensors" + if weight_file.exists(): + raw_weights = mx.load(str(weight_file)) + sanitized = sanitize_vocoder_weights(raw_weights) + if sanitized: + vocoder.load_weights(list(sanitized.items()), strict=False) + + return vocoder + + +def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE): + """Save audio to WAV file.""" + import wave + + # Ensure audio is in correct format (channels, samples) or (samples,) + if audio.ndim == 2: + # (channels, samples) -> (samples, channels) + audio = audio.T + + # Normalize and convert to int16 + audio = np.clip(audio, -1.0, 1.0) + audio_int16 = (audio * 32767).astype(np.int16) + + with wave.open(str(path), 'wb') as wf: + wf.setnchannels(2 if audio_int16.ndim == 2 else 1) + wf.setsampwidth(2) # 16-bit + wf.setframerate(sample_rate) + wf.writeframes(audio_int16.tobytes()) + + +def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path): + """Combine video and audio into final output using ffmpeg.""" + import subprocess + + cmd = [ + "ffmpeg", "-y", + "-i", str(video_path), + "-i", str(audio_path), + "-c:v", "copy", + "-c:a", "aac", + "-shortest", + str(output_path) + ] + + try: + subprocess.run(cmd, check=True, capture_output=True) + return True + except subprocess.CalledProcessError as e: + print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}") + return False + except FileNotFoundError: + print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}") + return False + + +def generate_video_with_audio( + model_repo: str, + text_encoder_repo: Optional[str], + prompt: str, + height: int = 512, + width: int = 512, + num_frames: int = 33, + seed: int = 42, + fps: int = 24, + output_path: str = "output_av.mp4", + output_audio_path: Optional[str] = None, + verbose: bool = True, + enhance_prompt: bool = False, + max_tokens: int = 512, + temperature: float = 0.7, +): + """Generate video with synchronized audio from text prompt.""" + start_time = time.time() + + # Validate dimensions + assert height % 64 == 0, f"Height must be divisible by 64, got {height}" + assert width % 64 == 0, f"Width must be divisible by 64, got {width}" + + if num_frames % 8 != 1: + adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1 + print(f"{Colors.YELLOW}⚠️ Adjusted frames to {adjusted_num_frames}{Colors.RESET}") + num_frames = adjusted_num_frames + + # Calculate audio frames + audio_frames = compute_audio_frames(num_frames, fps) + + print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}") + print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") + print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") + + model_path = get_model_path(model_repo) + text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) + + # Calculate latent dimensions + stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 + stage2_h, stage2_w = height // 32, width // 32 + latent_frames = 1 + (num_frames - 1) // 8 + + mx.random.seed(seed) + + # Load text encoder with audio embeddings + print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") + from mlx_video.models.ltx.text_encoder import LTX2TextEncoder + text_encoder = LTX2TextEncoder() + text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path) + mx.eval(text_encoder.parameters()) + + # Optionally enhance prompt + if enhance_prompt: + print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}") + prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose) + print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") + + # Get both video and audio embeddings + video_embeddings, audio_embeddings = text_encoder(prompt) + mx.eval(video_embeddings, audio_embeddings) + + del text_encoder + mx.clear_cache() + + # Load transformer with AudioVideo config + print(f"{Colors.BLUE}🤖 Loading transformer (A/V mode)...{Colors.RESET}") + raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) + sanitized = sanitize_transformer_weights(raw_weights) + + config = LTXModelConfig( + model_type=LTXModelType.AudioVideo, + num_attention_heads=32, + attention_head_dim=128, + in_channels=128, + out_channels=128, + num_layers=48, + cross_attention_dim=4096, + caption_channels=3840, + # Audio config + audio_num_attention_heads=32, + audio_attention_head_dim=64, + audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128 + audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, + audio_cross_attention_dim=2048, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=True, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + use_middle_indices_grid=True, + timestep_scale_multiplier=1000, + ) + + transformer = LTXModel(config) + transformer.load_weights(list(sanitized.items()), strict=False) + mx.eval(transformer.parameters()) + + # Initialize latents + print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") + mx.random.seed(seed) + video_latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)) + mx.eval(video_latents, audio_latents) + + # Create position grids + video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) + audio_positions = create_audio_position_grid(1, audio_frames) + mx.eval(video_positions, audio_positions) + + # Stage 1 denoising + video_latents, audio_latents = denoise_av( + video_latents, audio_latents, + video_positions, audio_positions, + video_embeddings, audio_embeddings, + transformer, STAGE_1_SIGMAS, verbose=verbose + ) + + # Upsample video latents + print(f"{Colors.MAGENTA}🔍 Upsampling video latents 2x...{Colors.RESET}") + upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) + mx.eval(upsampler.parameters()) + + vae_decoder = load_vae_decoder( + str(model_path / 'ltx-2-19b-distilled.safetensors'), + timestep_conditioning=True + ) + + video_latents = upsample_latents(video_latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) + mx.eval(video_latents) + + del upsampler + mx.clear_cache() + + # Stage 2: Refine at full resolution + print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") + video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) + mx.eval(video_positions) + + # Add noise for refinement + noise_scale = STAGE_2_SIGMAS[0] + video_noise = mx.random.normal(video_latents.shape) + audio_noise = mx.random.normal(audio_latents.shape) + video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale) + audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) + mx.eval(video_latents, audio_latents) + + video_latents, audio_latents = denoise_av( + video_latents, audio_latents, + video_positions, audio_positions, + video_embeddings, audio_embeddings, + transformer, STAGE_2_SIGMAS, verbose=verbose + ) + + del transformer + mx.clear_cache() + + # Decode video + print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") + video = vae_decoder(video_latents) + mx.eval(video) + + # Convert video to uint8 frames + video = mx.squeeze(video, axis=0) + video = mx.transpose(video, (1, 2, 3, 0)) + video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) + video = (video * 255).astype(mx.uint8) + video_np = np.array(video) + + # Decode audio + print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}") + audio_decoder = load_audio_decoder(model_path) + vocoder = load_vocoder(model_path) + mx.eval(audio_decoder.parameters(), vocoder.parameters()) + + # Debug: check per-channel statistics are loaded + pcs = audio_decoder.per_channel_statistics + print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]") + + # Debug: check audio latent statistics + print(f"Audio latents shape: {audio_latents.shape}") + print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}") + + mel_spectrogram = audio_decoder(audio_latents) + mx.eval(mel_spectrogram) + + print(f"Mel spectrogram shape: {mel_spectrogram.shape}") + print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}") + + # Audio decoder output is already in vocoder format (B, C, T, F) + audio_waveform = vocoder(mel_spectrogram) + mx.eval(audio_waveform) + + print(f"Audio waveform shape: {audio_waveform.shape}") + print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}") + + audio_np = np.array(audio_waveform) + if audio_np.ndim == 3: + audio_np = audio_np[0] # Remove batch dim + + del audio_decoder, vocoder, vae_decoder + mx.clear_cache() + + # Save outputs + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Save video (temporary without audio) + temp_video_path = output_path.with_suffix('.temp.mp4') + + try: + import cv2 + h, w = video_np.shape[1], video_np.shape[2] + fourcc = cv2.VideoWriter_fourcc(*'avc1') + out = cv2.VideoWriter(str(temp_video_path), fourcc, fps, (w, h)) + for frame in video_np: + out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + out.release() + print(f"{Colors.GREEN}✅ Video encoded{Colors.RESET}") + except Exception as e: + print(f"{Colors.RED}❌ Video encoding failed: {e}{Colors.RESET}") + return None, None + + # Save audio + audio_path = output_path.with_suffix('.wav') if output_audio_path is None else Path(output_audio_path) + save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE) + print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}") + + # Mux video and audio + print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}") + if mux_video_audio(temp_video_path, audio_path, output_path): + print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}") + temp_video_path.unlink() # Remove temp file + else: + # Fallback: keep video without audio + temp_video_path.rename(output_path) + print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}") + + elapsed = time.time() - start_time + print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s{Colors.RESET}") + print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}") + + return video_np, audio_np + + +def main(): + parser = argparse.ArgumentParser( + description="Generate videos with synchronized audio using MLX LTX-2", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach" + python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt + python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav + """ + ) + + parser.add_argument("--prompt", "-p", type=str, required=True, + help="Text description of the video/audio to generate") + parser.add_argument("--height", "-H", type=int, default=512, + help="Output video height (default: 512)") + parser.add_argument("--width", "-W", type=int, default=512, + help="Output video width (default: 512)") + parser.add_argument("--num-frames", "-n", type=int, default=65, + help="Number of frames (default: 65)") + parser.add_argument("--seed", "-s", type=int, default=42, + help="Random seed (default: 42)") + parser.add_argument("--fps", type=int, default=24, + help="Frames per second (default: 24)") + parser.add_argument("--output-path", type=str, default="output_av.mp4", + help="Output video path (default: output_av.mp4)") + parser.add_argument("--output-audio", type=str, default=None, + help="Output audio path (default: same as video with .wav)") + parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2", + help="Model repository (default: Lightricks/LTX-2)") + parser.add_argument("--text-encoder-repo", type=str, default=None, + help="Text encoder repository") + parser.add_argument("--verbose", action="store_true", + help="Verbose output") + parser.add_argument("--enhance-prompt", action="store_true", + help="Enhance prompt using Gemma") + parser.add_argument("--max-tokens", type=int, default=512, + help="Max tokens for prompt enhancement") + parser.add_argument("--temperature", type=float, default=0.7, + help="Temperature for prompt enhancement") + + args = parser.parse_args() + + generate_video_with_audio( + model_repo=args.model_repo, + text_encoder_repo=args.text_encoder_repo, + prompt=args.prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + seed=args.seed, + fps=args.fps, + output_path=args.output_path, + output_audio_path=args.output_audio, + verbose=args.verbose, + enhance_prompt=args.enhance_prompt, + max_tokens=args.max_tokens, + temperature=args.temperature, + ) + + +if __name__ == "__main__": + main() diff --git a/mlx_video/models/ltx/__init__.py b/mlx_video/models/ltx/__init__.py index 13c9c65..6a817e3 100644 --- a/mlx_video/models/ltx/__init__.py +++ b/mlx_video/models/ltx/__init__.py @@ -5,3 +5,4 @@ from mlx_video.models.ltx.config import ( LTXModelType, ) from mlx_video.models.ltx.ltx import LTXModel, X0Model +from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio diff --git a/mlx_video/models/ltx/audio_vae/__init__.py b/mlx_video/models/ltx/audio_vae/__init__.py new file mode 100644 index 0000000..5907e2d --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/__init__.py @@ -0,0 +1,41 @@ +"""Audio VAE module for LTX-2 audio generation.""" + +from .attention import AttentionType, AttnBlock, make_attn +from .audio_vae import AudioDecoder, decode_audio +from .causal_conv_2d import CausalConv2d, make_conv2d +from .causality_axis import CausalityAxis +from .downsample import Downsample, build_downsampling_path +from .normalization import NormType, PixelNorm, build_normalization_layer +from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics +from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock +from .upsample import Upsample, build_upsampling_path +from .vocoder import Vocoder + +__all__ = [ + # Main components + "AudioDecoder", + "Vocoder", + "decode_audio", + # Ops + "AudioLatentShape", + "AudioPatchifier", + "PerChannelStatistics", + # Building blocks + "AttentionType", + "AttnBlock", + "make_attn", + "CausalConv2d", + "make_conv2d", + "CausalityAxis", + "Downsample", + "build_downsampling_path", + "NormType", + "PixelNorm", + "build_normalization_layer", + "ResBlock1", + "ResBlock2", + "ResnetBlock", + "LRELU_SLOPE", + "Upsample", + "build_upsampling_path", +] diff --git a/mlx_video/models/ltx/audio_vae/attention.py b/mlx_video/models/ltx/audio_vae/attention.py new file mode 100644 index 0000000..38c5744 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/attention.py @@ -0,0 +1,108 @@ +"""Attention blocks for audio VAE.""" + +from enum import Enum + +import mlx.core as mx +import mlx.nn as nn + +from .normalization import NormType, build_normalization_layer + + +class AttentionType(Enum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class AttnBlock(nn.Module): + """Self-attention block for audio VAE.""" + + def __init__( + self, + in_channels: int, + norm_type: NormType = NormType.GROUP, + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + # Using Conv2d with kernel_size=1 for Q, K, V projections + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def __call__(self, x: mx.array) -> mx.array: + """ + Forward pass through attention block. + Args: + x: Input tensor of shape (B, H, W, C) in MLX channels-last format + Returns: + Output tensor with attention applied (residual connection) + """ + h_ = x + h_ = self.norm(h_) + + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # Compute attention + # x shape: (B, H, W, C) + b, h, w, c = q.shape + + # Reshape for attention: (B, H*W, C) + q = q.reshape(b, h * w, c) + k = k.reshape(b, h * w, c) + v = v.reshape(b, h * w, c) + + # Attention: Q @ K^T / sqrt(d) + # q: (B, HW, C), k: (B, HW, C) -> k^T: (B, C, HW) + # w_: (B, HW, HW) + scale = float(c) ** (-0.5) + w_ = mx.matmul(q, k.transpose(0, 2, 1)) * scale + w_ = mx.softmax(w_, axis=-1) + + # Attend to values + # w_: (B, HW, HW), v: (B, HW, C) -> h_: (B, HW, C) + h_ = mx.matmul(w_, v) + + # Reshape back to spatial dims + h_ = h_.reshape(b, h, w, c) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Identity(nn.Module): + """Identity module that returns input unchanged.""" + + def __call__(self, x: mx.array) -> mx.array: + return x + + +def make_attn( + in_channels: int, + attn_type: AttentionType = AttentionType.VANILLA, + norm_type: NormType = NormType.GROUP, +) -> nn.Module: + """ + Create an attention module based on type. + Args: + in_channels: Number of input channels + attn_type: Type of attention mechanism + norm_type: Type of normalization + Returns: + Attention module + """ + if attn_type == AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + elif attn_type == AttentionType.NONE: + return Identity() + elif attn_type == AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + else: + raise ValueError(f"Unknown attention type: {attn_type}") diff --git a/mlx_video/models/ltx/audio_vae/audio_vae.py b/mlx_video/models/ltx/audio_vae/audio_vae.py new file mode 100644 index 0000000..08caec5 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/audio_vae.py @@ -0,0 +1,326 @@ +"""Audio VAE encoder and decoder for LTX-2.""" + +from typing import Set, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .attention import AttentionType, make_attn +from .causal_conv_2d import make_conv2d +from .causality_axis import CausalityAxis +from .downsample import build_downsampling_path +from .normalization import NormType, build_normalization_layer +from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics +from .resnet import ResnetBlock +from .upsample import build_upsampling_path + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +def build_mid_block( + channels: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + add_attention: bool, +) -> dict: + """Build the middle block with two ResNet blocks and optional attention.""" + mid = {} + mid["block_1"] = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + mid["attn_1"] = ( + make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None + ) + mid["block_2"] = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + return mid + + +def run_mid_block(mid: dict, features: mx.array) -> mx.array: + """Run features through the middle block.""" + features = mid["block_1"](features, temb=None) + if mid["attn_1"] is not None: + features = mid["attn_1"](features) + return mid["block_2"](features, temb=None) + + +class AudioDecoder(nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + The decoder mirrors the encoder structure with configurable channel multipliers, + attention resolutions, and causal convolutions. + """ + + def __init__( + self, + *, + ch: int = 128, + out_ch: int = 2, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Set[int] = None, + resolution: int = 256, + z_channels: int = 8, + norm_type: NormType = NormType.PIXEL, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = None, + ) -> None: + """ + Initialize the AudioDecoder. + Args: + ch: Base number of feature channels + out_ch: Number of output channels (2 for stereo) + ch_mult: Multiplicative factors for channels at each resolution + num_res_blocks: Number of residual blocks per resolution + attn_resolutions: Resolutions at which to apply attention + resolution: Input spatial resolution + z_channels: Number of latent channels + norm_type: Normalization type + causality_axis: Axis for causal convolutions + dropout: Dropout probability + mid_block_add_attention: Whether to add attention in middle block + sample_rate: Audio sample rate + mel_hop_length: Hop length for mel spectrogram + is_causal: Whether to use causal convolutions + mel_bins: Number of mel frequency bins + """ + super().__init__() + + if attn_resolutions is None: + attn_resolutions = {8, 16, 32} + + # Internal behavioral defaults + resamp_with_conv = True + attn_type = AttentionType.VANILLA + + # Per-channel statistics for denormalizing latents + # Uses ch (base channel count) to match the patchified latent dimension + # Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16) + # After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128) + # ch=128 matches this dimension, so use ch for per_channel_statistics + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.out_ch = out_ch + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.z_channels = z_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + self.attn_type = attn_type + + base_block_channels = ch * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, z_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + self.mid = build_mid_block( + channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + + self.up, final_block_channels = build_upsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + initial_block_channels=base_block_channels, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + def __call__(self, sample: mx.array) -> mx.array: + """ + Decode latent features back to audio spectrograms. + Args: + sample: Encoded latent representation of shape (B, H, W, C) in MLX format + or (B, C, H, W) in PyTorch format (will be transposed) + Returns: + Reconstructed audio spectrogram + """ + # Handle input format - if channels are in dim 1, transpose to channels-last + if sample.shape[1] == self.z_channels and sample.ndim == 4: + # PyTorch format (B, C, H, W) -> MLX format (B, H, W, C) + sample = mx.transpose(sample, (0, 2, 3, 1)) + + sample, target_shape = self._denormalize_latents(sample) + + h = self.conv_in(sample) + h = run_mid_block(self.mid, h) + h = self._run_upsampling_path(h) + h = self._finalize_output(h) + + return self._adjust_output_shape(h, target_shape) + + def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]: + """Denormalize latents using per-channel statistics.""" + # sample shape: (B, H, W, C) in MLX format + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[3], # channels last + frames=sample.shape[1], # height = frames + mel_bins=sample.shape[2], # width = mel_bins + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + if self.causality_axis != CausalityAxis.NONE: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + return sample, target_shape + + def _adjust_output_shape( + self, + decoded_output: mx.array, + target_shape: AudioLatentShape, + ) -> mx.array: + """ + Adjust output shape to match target dimensions for variable-length audio. + Args: + decoded_output: Tensor of shape (B, H, W, C) in MLX format + target_shape: AudioLatentShape describing target dimensions + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, frames, mel_bins, channels) in MLX format + _, current_time, current_freq, _ = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[1] + freq_padding_needed = target_freq - decoded_output.shape[2] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # MLX pad: [(before_0, after_0), ...] + # For (B, H, W, C): H=time, W=freq + padding = [ + (0, 0), # batch + (0, max(time_padding_needed, 0)), # time + (0, max(freq_padding_needed, 0)), # freq + (0, 0), # channels + ] + decoded_output = mx.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels] + + # Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility + decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2)) + + return decoded_output + + def _run_upsampling_path(self, h: mx.array) -> mx.array: + """Run through upsampling path.""" + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx in range(len(stage["block"])): + h = stage["block"][block_idx](h, temb=None) + if block_idx in stage["attn"]: + h = stage["attn"][block_idx](h) + + if level != 0 and "upsample" in stage: + h = stage["upsample"](h) + + return h + + def _finalize_output(self, h: mx.array) -> mx.array: + """Apply final normalization and convolution.""" + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + return mx.tanh(h) if self.tanh_out else h + + +def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array: + """ + Decode an audio latent representation using the provided audio decoder and vocoder. + Args: + latent: Input audio latent tensor + audio_decoder: Model to decode the latent to spectrogram + vocoder: Model to convert spectrogram to audio waveform + Returns: + Decoded audio as a float tensor + """ + decoded_audio = audio_decoder(latent) + decoded_audio = vocoder(decoded_audio) + # Remove batch dimension if present + if decoded_audio.shape[0] == 1: + decoded_audio = decoded_audio[0] + return decoded_audio.astype(mx.float32) diff --git a/mlx_video/models/ltx/audio_vae/causal_conv_2d.py b/mlx_video/models/ltx/audio_vae/causal_conv_2d.py new file mode 100644 index 0000000..2a38448 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/causal_conv_2d.py @@ -0,0 +1,146 @@ +"""Causal 2D convolutions for audio VAE.""" + +from typing import Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .causality_axis import CausalityAxis + + +def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]: + """Convert int or tuple to tuple pair.""" + if isinstance(x, int): + return (x, x) + return x + + +class CausalConv2d(nn.Module): + """ + A causal 2D convolution. + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension before the convolution. + + Note: MLX Conv2d expects input shape (N, H, W, C) - channels last. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = _pair(kernel_size) + dilation = _pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # Store padding for manual application + # MLX pad order: [(before_axis0, after_axis0), (before_axis1, after_axis1), ...] + # For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width) + if self.causality_axis == CausalityAxis.NONE: + # Non-causal: symmetric padding + self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2) + elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY): + # Causal on width: pad left (before width axis) + self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0) + elif self.causality_axis == CausalityAxis.HEIGHT: + # Causal on height: pad top (before height axis) + self.padding = (pad_h, 0, pad_w // 2, pad_w - pad_w // 2) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def __call__(self, x: mx.array) -> mx.array: + """ + Forward pass with causal padding. + Args: + x: Input tensor of shape (N, H, W, C) in MLX channels-last format + Returns: + Output tensor after causal convolution + """ + # Apply causal padding before convolution + # padding format: (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right) + pad_h_top, pad_h_bottom, pad_w_left, pad_w_right = self.padding + + if any(p > 0 for p in self.padding): + # MLX pad expects: [(before_0, after_0), (before_1, after_1), ...] + # For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C + x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)]) + + return self.conv(x) + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + padding: Union[int, Tuple[int, int], None] = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis | None = None, +) -> nn.Module: + """ + Create a 2D convolution layer that can be either causal or non-causal. + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d( + in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis + ) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + if isinstance(kernel_size, int): + padding = kernel_size // 2 + else: + padding = tuple(k // 2 for k in kernel_size) + + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) diff --git a/mlx_video/models/ltx/audio_vae/causality_axis.py b/mlx_video/models/ltx/audio_vae/causality_axis.py new file mode 100644 index 0000000..15545b3 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/causality_axis.py @@ -0,0 +1,12 @@ +"""Causality axis enum for specifying causal convolution dimensions.""" + +from enum import Enum + + +class CausalityAxis(Enum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" diff --git a/mlx_video/models/ltx/audio_vae/downsample.py b/mlx_video/models/ltx/audio_vae/downsample.py new file mode 100644 index 0000000..2f553c8 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/downsample.py @@ -0,0 +1,127 @@ +"""Downsampling layers for audio VAE.""" + +from typing import Set, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .attention import AttentionType, make_attn +from .causality_axis import CausalityAxis +from .normalization import NormType +from .resnet import ResnetBlock + + +class Downsample(nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in MLX conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def __call__(self, x: mx.array) -> mx.array: + """ + Forward pass with downsampling. + Args: + x: Input tensor of shape (N, H, W, C) in MLX channels-last format + Returns: + Downsampled tensor + """ + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom) for PyTorch + # For MLX pad: [(before_axis0, after_axis0), ...] + # x shape: (N, H, W, C) -> pad on H and W axes + if self.causality_axis == CausalityAxis.NONE: + # pad: (left=0, right=1, top=0, bottom=1) + pad = [(0, 0), (0, 1), (0, 1), (0, 0)] + elif self.causality_axis == CausalityAxis.WIDTH: + # pad: (left=2, right=0, top=0, bottom=1) + pad = [(0, 0), (0, 1), (2, 0), (0, 0)] + elif self.causality_axis == CausalityAxis.HEIGHT: + # pad: (left=0, right=1, top=2, bottom=0) + pad = [(0, 0), (2, 0), (0, 1), (0, 0)] + elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY: + # pad: (left=1, right=0, top=0, bottom=1) + pad = [(0, 0), (0, 1), (1, 0), (0, 0)] + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = mx.pad(x, pad, constant_values=0) + x = self.conv(x) + else: + # Average pooling with 2x2 kernel and stride 2 + # MLX doesn't have built-in avg_pool2d, implement manually + # x shape: (N, H, W, C) + n, h, w, c = x.shape + # Reshape to (N, H//2, 2, W//2, 2, C) and mean over pooling dims + x = x.reshape(n, h // 2, 2, w // 2, 2, c) + x = mx.mean(x, axis=(2, 4)) + + return x + + +def build_downsampling_path( + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, +) -> tuple[dict, int]: + """Build the downsampling path with residual blocks, attention, and downsampling layers.""" + down_modules = {} + curr_res = resolution + in_ch_mult = (1, *tuple(ch_mult)) + block_in = ch + + for i_level in range(num_resolutions): + stage = {} + stage["block"] = {} + stage["attn"] = {} + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for i_block in range(num_res_blocks): + stage["block"][i_block] = ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + block_in = block_out + if curr_res in attn_resolutions: + stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type) + + if i_level != num_resolutions - 1: + stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + + down_modules[i_level] = stage + + return down_modules, block_in diff --git a/mlx_video/models/ltx/audio_vae/normalization.py b/mlx_video/models/ltx/audio_vae/normalization.py new file mode 100644 index 0000000..361c6b4 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/normalization.py @@ -0,0 +1,59 @@ +"""Normalization layers for audio VAE.""" + +from enum import Enum + +import mlx.core as mx +import mlx.nn as nn + + +class NormType(Enum): + """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm).""" + + GROUP = "group" + PIXEL = "pixel" + + +class PixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.dim = dim + self.eps = eps + + def __call__(self, x: mx.array) -> mx.array: + """Apply RMS normalization along the configured dimension.""" + mean_sq = mx.mean(x**2, axis=self.dim, keepdims=True) + rms = mx.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer( + in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP +) -> nn.Module: + """ + Create a normalization layer based on the normalization type. + Args: + in_channels: Number of input channels + num_groups: Number of groups for group normalization + normtype: Type of normalization: "group" or "pixel" + Returns: + A normalization layer + """ + if normtype == NormType.GROUP: + return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True) + if normtype == NormType.PIXEL: + # For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1) + # PyTorch uses dim=1 for channels-first format (B, C, H, W) + return PixelNorm(dim=-1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") diff --git a/mlx_video/models/ltx/audio_vae/ops.py b/mlx_video/models/ltx/audio_vae/ops.py new file mode 100644 index 0000000..bf2d111 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/ops.py @@ -0,0 +1,98 @@ +"""Audio processing utilities for audio VAE.""" + +from dataclasses import dataclass + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class AudioLatentShape: + """Shape descriptor for audio latent representations.""" + + batch: int + channels: int + frames: int + mel_bins: int + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statistics is computed over the entire dataset and stored in model's checkpoint. + """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.latent_channels = latent_channels + # Initialize buffers - will be loaded from weights + # Using underscores for MLX compatibility with weight loading + self._std_of_means = mx.ones((latent_channels,)) + self._mean_of_means = mx.zeros((latent_channels,)) + + def un_normalize(self, x: mx.array) -> mx.array: + """Denormalize latent representation.""" + # Broadcast statistics to match x shape + # x shape: (B, C, ...) or (B, ..., C) + std = self._std_of_means.astype(x.dtype) + mean = self._mean_of_means.astype(x.dtype) + return (x * std) + mean + + def normalize(self, x: mx.array) -> mx.array: + """Normalize latent representation.""" + std = self._std_of_means.astype(x.dtype) + mean = self._mean_of_means.astype(x.dtype) + return (x - mean) / std + + +class AudioPatchifier: + """ + Audio patchifier for converting between audio latents and patches. + Combines channels and mel_bins dimensions for per-channel statistics. + """ + + def __init__( + self, + patch_size: int = 1, + audio_latent_downsample_factor: int = 4, + sample_rate: int = 16000, + hop_length: int = 160, + is_causal: bool = True, + ): + self.patch_size = patch_size + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.sample_rate = sample_rate + self.hop_length = hop_length + self.is_causal = is_causal + + def patchify(self, x: mx.array) -> mx.array: + """Convert audio latents to patches. + + Input shape: (B, T, F, C) in MLX format (channels last) + Output shape: (B, T, C*F) - flattened for per-channel statistics + + The output order is (c f) to match PyTorch's "b c t f -> b t (c f)". + """ + # x shape: (B, T, F, C) e.g., (1, 68, 16, 8) + b, t, f, c = x.shape + # Transpose to (B, T, C, F) for correct (c f) ordering + x = mx.transpose(x, (0, 1, 3, 2)) + # Reshape to (B, T, C*F) e.g., (1, 68, 128) + return x.reshape(b, t, c * f) + + def unpatchify(self, x: mx.array, latent_shape: AudioLatentShape) -> mx.array: + """Convert patches back to audio latents. + + Input shape: (B, T, C*F) + Output shape: (B, T, F, C) in MLX format + + Reverses patchify's "b t (c f) -> b c t f" then transposes to MLX format. + """ + # x shape: (B, T, C*F) e.g., (1, 68, 128) + b, t, cf = x.shape + c = latent_shape.channels + f = latent_shape.mel_bins + # Reshape to (B, T, C, F) + x = x.reshape(b, t, c, f) + # Transpose to MLX format (B, T, F, C) + return mx.transpose(x, (0, 1, 3, 2)) diff --git a/mlx_video/models/ltx/audio_vae/resnet.py b/mlx_video/models/ltx/audio_vae/resnet.py new file mode 100644 index 0000000..c80d938 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/resnet.py @@ -0,0 +1,185 @@ +"""ResNet blocks for audio VAE and vocoder.""" + +from typing import List, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .causal_conv_2d import make_conv2d +from .causality_axis import CausalityAxis +from .normalization import NormType, build_normalization_layer + +LRELU_SLOPE = 0.1 + + +def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array: + """Leaky ReLU activation.""" + return mx.maximum(x, x * negative_slope) + + +class ResBlock1(nn.Module): + """1D ResNet block for vocoder with dilated convolutions.""" + + def __init__( + self, + channels: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + ): + super().__init__() + + # First set of convolutions with different dilations + self.convs1 = { + i: nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=(kernel_size - 1) * d // 2, + ) + for i, d in enumerate(dilation) + } + + # Second set of convolutions with dilation=1 + self.convs2 = { + i: nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=(kernel_size - 1) // 2, + ) + for i in range(len(dilation)) + } + + def __call__(self, x: mx.array) -> mx.array: + """Forward pass through residual blocks.""" + for i in range(len(self.convs1)): + xt = leaky_relu(x, LRELU_SLOPE) + xt = self.convs1[i](xt) + xt = leaky_relu(xt, LRELU_SLOPE) + xt = self.convs2[i](xt) + x = xt + x + return x + + +class ResBlock2(nn.Module): + """1D ResNet block for vocoder (alternative version).""" + + def __init__( + self, + channels: int, + kernel_size: int = 3, + dilation: Tuple[int, int] = (1, 3), + ): + super().__init__() + + self.convs = { + i: nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=(kernel_size - 1) * d // 2, + ) + for i, d in enumerate(dilation) + } + + def __call__(self, x: mx.array) -> mx.array: + """Forward pass through residual blocks.""" + for i in range(len(self.convs)): + xt = leaky_relu(x, LRELU_SLOPE) + xt = self.convs[i](xt) + x = xt + x + return x + + +class ResnetBlock(nn.Module): + """2D ResNet block for audio VAE encoder/decoder.""" + + def __init__( + self, + *, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.temb_channels = temb_channels + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.conv1 = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + self.dropout_rate = dropout + self.conv2 = make_conv2d( + out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def __call__( + self, + x: mx.array, + temb: mx.array | None = None, + ) -> mx.array: + """ + Forward pass through ResNet block. + Args: + x: Input tensor of shape (N, H, W, C) in MLX channels-last format + temb: Optional time embedding tensor + Returns: + Output tensor + """ + h = x + h = self.norm1(h) + h = nn.silu(h) + h = self.conv1(h) + + if temb is not None and self.temb_channels > 0: + # temb: (B, temb_channels) -> (B, out_channels) + # Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting + h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1) + + h = self.norm2(h) + h = nn.silu(h) + if self.dropout_rate > 0: + h = nn.Dropout(self.dropout_rate)(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h diff --git a/mlx_video/models/ltx/audio_vae/upsample.py b/mlx_video/models/ltx/audio_vae/upsample.py new file mode 100644 index 0000000..731ac85 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/upsample.py @@ -0,0 +1,135 @@ +"""Upsampling layers for audio VAE.""" + +from typing import Set, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .attention import AttentionType, make_attn +from .causal_conv_2d import make_conv2d +from .causality_axis import CausalityAxis +from .normalization import NormType +from .resnet import ResnetBlock + + +def nearest_neighbor_upsample(x: mx.array, scale_factor: int = 2) -> mx.array: + """ + Nearest neighbor upsampling for 4D tensors. + Args: + x: Input tensor of shape (N, H, W, C) + scale_factor: Upsampling factor + Returns: + Upsampled tensor of shape (N, H*scale_factor, W*scale_factor, C) + """ + n, h, w, c = x.shape + # Repeat along height and width + x = mx.repeat(x, scale_factor, axis=1) + x = mx.repeat(x, scale_factor, axis=2) + return x + + +class Upsample(nn.Module): + """Upsampling layer with optional convolution.""" + + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d( + in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + + def __call__(self, x: mx.array) -> mx.array: + """ + Forward pass with upsampling. + Args: + x: Input tensor of shape (N, H, W, C) in MLX channels-last format + Returns: + Upsampled tensor + """ + # Nearest neighbor 2x upsampling + x = nearest_neighbor_upsample(x, scale_factor=2) + + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + if self.causality_axis == CausalityAxis.NONE: + pass # x remains unchanged + elif self.causality_axis == CausalityAxis.HEIGHT: + x = x[:, 1:, :, :] + elif self.causality_axis == CausalityAxis.WIDTH: + x = x[:, :, 1:, :] + elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +def build_upsampling_path( + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, + initial_block_channels: int, +) -> tuple[dict, int]: + """Build the upsampling path with residual blocks, attention, and upsampling layers.""" + up_modules = {} + block_in = initial_block_channels + curr_res = resolution // (2 ** (num_resolutions - 1)) + + for level in reversed(range(num_resolutions)): + stage = {} + stage["block"] = {} + stage["attn"] = {} + block_out = ch * ch_mult[level] + + for i_block in range(num_res_blocks + 1): + stage["block"][i_block] = ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + block_in = block_out + if curr_res in attn_resolutions: + stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type) + + if level != 0: + stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res *= 2 + + up_modules[level] = stage + + return up_modules, block_in diff --git a/mlx_video/models/ltx/audio_vae/vocoder.py b/mlx_video/models/ltx/audio_vae/vocoder.py new file mode 100644 index 0000000..02b5393 --- /dev/null +++ b/mlx_video/models/ltx/audio_vae/vocoder.py @@ -0,0 +1,142 @@ +"""Vocoder for converting mel spectrograms to audio waveforms.""" + +import math +from typing import List + +import mlx.core as mx +import mlx.nn as nn + +from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu + + +class Vocoder(nn.Module): + """ + Vocoder model for synthesizing audio from Mel spectrograms. + Based on HiFi-GAN architecture. + + Args: + resblock_kernel_sizes: List of kernel sizes for the residual blocks + upsample_rates: List of upsampling rates + upsample_kernel_sizes: List of kernel sizes for the upsampling layers + resblock_dilation_sizes: List of dilation sizes for the residual blocks + upsample_initial_channel: Initial number of channels for upsampling + stereo: Whether to use stereo output + resblock: Type of residual block to use ("1" or "2") + output_sample_rate: Waveform sample rate + """ + + def __init__( + self, + resblock_kernel_sizes: List[int] | None = None, + upsample_rates: List[int] | None = None, + upsample_kernel_sizes: List[int] | None = None, + resblock_dilation_sizes: List[List[int]] | None = None, + upsample_initial_channel: int = 1024, + stereo: bool = True, + resblock: str = "1", + output_sample_rate: int = 24000, + ): + super().__init__() + + # Initialize default values if not provided + if resblock_kernel_sizes is None: + resblock_kernel_sizes = [3, 7, 11] + if upsample_rates is None: + upsample_rates = [6, 5, 2, 2, 2] + if upsample_kernel_sizes is None: + upsample_kernel_sizes = [16, 15, 8, 4, 4] + if resblock_dilation_sizes is None: + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + + self.output_sample_rate = output_sample_rate + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.upsample_initial_channel = upsample_initial_channel + + in_channels = 128 if stereo else 64 + self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3) + + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + # Upsampling layers using ConvTranspose1d + self.ups = {} + for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + in_ch = upsample_initial_channel // (2**i) + out_ch = upsample_initial_channel // (2 ** (i + 1)) + self.ups[i] = nn.ConvTranspose1d( + in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + + # Residual blocks + self.resblocks = {} + block_idx = 0 + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes): + self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) + block_idx += 1 + + out_channels = 2 if stereo else 1 + final_channels = upsample_initial_channel // (2**self.num_upsamples) + self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3) + + self.upsample_factor = math.prod(upsample_rates) + + def __call__(self, x: mx.array) -> mx.array: + """ + Forward pass of the vocoder. + Args: + x: Input Mel spectrogram tensor. Can be either: + - 3D: (batch_size, time, mel_bins) for mono - MLX format (N, L, C) + - 4D: (batch_size, 2, time, mel_bins) for stereo - PyTorch format (N, C, H, W) + Returns: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + # Input: (batch, channels, time, mel_bins) from audio decoder + # Transpose to (batch, channels, mel_bins, time) + x = mx.transpose(x, (0, 1, 3, 2)) + + if x.ndim == 4: # stereo + # x shape: (batch, 2, mel_bins, time) + # Rearrange to (batch, 2*mel_bins, time) + b, s, c, t = x.shape + x = x.reshape(b, s * c, t) + + # MLX Conv1d expects (N, L, C), so transpose + # Current: (batch, channels, time) -> (batch, time, channels) + x = mx.transpose(x, (0, 2, 1)) + + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + + start = i * self.num_kernels + end = start + self.num_kernels + + # Apply residual blocks and average their outputs + block_outputs = [] + for idx in range(start, end): + block_outputs.append(self.resblocks[idx](x)) + + # Stack and mean + x = mx.stack(block_outputs, axis=0) + x = mx.mean(x, axis=0) + + # IMPORTANT: Use default leaky_relu slope (0.01), NOT LRELU_SLOPE (0.1) + # PyTorch uses F.leaky_relu(x) which defaults to 0.01 + x = nn.leaky_relu(x) # Default negative_slope=0.01 + x = self.conv_post(x) + x = mx.tanh(x) + + # Transpose back to (batch, channels, time) + x = mx.transpose(x, (0, 2, 1)) + + return x diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index 6b8fc57..1f45fb1 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -111,6 +111,7 @@ class LTXModelConfig(BaseModelConfig): audio_in_channels: int = 128 audio_out_channels: int = 128 audio_cross_attention_dim: int = 2048 + audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video) # Positional embedding config positional_embedding_theta: float = 10000.0 diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index dea4089..75987bc 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -70,6 +70,9 @@ class TransformerArgsPreprocessor: attention_mask: Optional[mx.array] = None, ) -> Tuple[mx.array, Optional[mx.array]]: batch_size = x.shape[0] + + # Context is already processed through embeddings connector in text encoder + # Here we just apply the caption projection context = self.caption_projection(context) context = mx.reshape(context, (batch_size, -1, x.shape[-1])) return context, attention_mask @@ -282,8 +285,10 @@ class LTXModel(nn.Module): def _init_audio(self, config: LTXModelConfig) -> None: self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim) + + # Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=config.caption_channels, + in_features=config.audio_caption_channels, hidden_size=self.audio_inner_dim, ) diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 8d8da41..b38d076 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -216,7 +216,8 @@ class ConnectorAttention(nn.Module): self.to_q = nn.Linear(dim, inner_dim, bias=True) self.to_k = nn.Linear(dim, inner_dim, bias=True) self.to_v = nn.Linear(dim, inner_dim, bias=True) - self.to_out = [nn.Linear(inner_dim, dim, bias=True)] + # Direct attribute for MLX parameter tracking (not a list) + self.to_out = nn.Linear(inner_dim, dim, bias=True) # Standard RMSNorm (not Gemma-style) on full inner_dim self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6) @@ -239,21 +240,51 @@ class ConnectorAttention(nn.Module): q = self.q_norm(q) k = self.k_norm(k) - - if pe is not None: - # pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim) - q = apply_interleaved_rotary_emb(q, pe[0], pe[1]) - k = apply_interleaved_rotary_emb(k, pe[0], pe[1]) - + # Reshape to (B, H, T, D) for SPLIT RoPE q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) + if pe is not None: + # pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2) + # Apply SPLIT RoPE: operates on first half of head dimensions + q = self._apply_split_rope(q, pe[0], pe[1]) + k = self._apply_split_rope(k, pe[0], pe[1]) + # No mask needed for connector - after register replacement, all positions are valid out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None) out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) - return self.to_out[0](out) + return self.to_out(out) + + def _apply_split_rope( + self, + x: mx.array, + cos_freq: mx.array, + sin_freq: mx.array, + ) -> mx.array: + """Apply SPLIT RoPE to input tensor. + + Args: + x: Input tensor of shape (B, H, T, D) + cos_freq: Cosine frequencies of shape (1, H, T, D//2) + sin_freq: Sine frequencies of shape (1, H, T, D//2) + + Returns: + Tensor with SPLIT rotary embeddings applied + """ + # Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2) + half_dim = x.shape[-1] // 2 + x1 = x[..., :half_dim] + x2 = x[..., half_dim:] + + # Apply rotation: SPLIT pattern + # out1 = x1 * cos - x2 * sin + # out2 = x2 * cos + x1 * sin + out1 = x1 * cos_freq - x2 * sin_freq + out2 = x2 * cos_freq + x1 * sin_freq + + return mx.concatenate([out1, out2], axis=-1) class GEGLU(nn.Module): @@ -272,15 +303,15 @@ class ConnectorFeedForward(nn.Module): def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0): super().__init__() inner_dim = dim * mult - self.net = [ - GEGLU(dim, inner_dim), - nn.Dropout(dropout), - nn.Linear(inner_dim, dim, bias=True), - ] + # Use explicit named attributes to match weight key structure (proj_in, proj_out) + self.proj_in = nn.Linear(dim, inner_dim, bias=True) + self.dropout = nn.Dropout(dropout) + self.proj_out = nn.Linear(inner_dim, dim, bias=True) def __call__(self, x: mx.array) -> mx.array: - for layer in self.net: - x = layer(x) + x = nn.gelu(self.proj_in(x)) + x = self.dropout(x) + x = self.proj_out(x) return x @@ -326,6 +357,7 @@ class Embeddings1DConnector(nn.Module): num_layers: int = 2, num_learnable_registers: int = 128, positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list = None, ): super().__init__() self.dim = dim @@ -333,60 +365,69 @@ class Embeddings1DConnector(nn.Module): self.head_dim = head_dim self.num_learnable_registers = num_learnable_registers self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos or [4096] - self.transformer_1d_blocks = [ - ConnectorTransformerBlock(dim, num_heads, head_dim) - for _ in range(num_layers) - ] + # Use dict with int keys for MLX to track parameters (lists are not tracked) + self.transformer_1d_blocks = { + i: ConnectorTransformerBlock(dim, num_heads, head_dim) + for i in range(num_layers) + } if num_learnable_registers > 0: self.learnable_registers = mx.zeros((num_learnable_registers, dim)) def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]: - """Compute RoPE frequencies for connector (INTERLEAVED type). + """Compute RoPE frequencies for connector (SPLIT type matching PyTorch). - Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim). + Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2). """ import numpy as np dim = self.num_heads * self.head_dim # inner_dim = 3840 theta = self.positional_embedding_theta - max_pos = [1] # Default for connector + max_pos = self.positional_embedding_max_pos # [4096] from PyTorch n_elem = 2 * len(max_pos) # = 2 start = 1.0 end = theta num_indices = dim // n_elem # 1920 - # Use numpy float64 for precision + # Use numpy float64 for precision (double_precision_rope=True in PyTorch) log_start = np.log(start) / np.log(theta) # = 0 log_end = np.log(end) / np.log(theta) # = 1 lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64) - indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32) + indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float64) # Generate positions and compute freqs (matches generate_freqs) positions = np.arange(seq_len, dtype=np.float64) - # fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1) - # scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1 - scaled_positions = positions * 2 - 1 # Shape: (seq_len,) + # Scale positions by max_pos (PyTorch uses max_pos=[4096]) + fractional_positions = positions / max_pos[0] + scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,) # freqs = indices * scaled_positions (outer product) # Shape: (seq_len, num_indices) freqs = scaled_positions[:, None] * indices[None, :] - # Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis) - cos_freq = np.cos(freqs) + # Compute cos/sin + cos_freq = np.cos(freqs) # (seq_len, 1920) sin_freq = np.sin(freqs) - # repeat_interleave: (seq_len, num_indices) -> (seq_len, dim) - # Pattern: [c0, c0, c1, c1, c2, c2, ...] - cos_full = np.repeat(cos_freq, 2, axis=-1) - sin_full = np.repeat(sin_freq, 2, axis=-1) + # For SPLIT RoPE: pad to head_dim//2 = 64 per head, then reshape to (1, H, T, D//2) + # Current: (T, 1920) -> need (1, 30, T, 64) + # 30 heads * 64 = 1920, so no padding needed - # Add batch dimension and convert to MLX: (1, seq_len, dim) - cos_full = mx.array(cos_full[None, :, :].astype(np.float32)) - sin_full = mx.array(sin_full[None, :, :].astype(np.float32)) + # Reshape: (T, 1920) -> (T, 30, 64) -> (1, 30, T, 64) + cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2) + sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2) + + # Transpose to (1, H, T, D//2) + cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...] + sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...] + + # Convert to MLX + cos_full = mx.array(cos_freq.astype(np.float32)) + sin_full = mx.array(sin_freq.astype(np.float32)) return cos_full.astype(dtype), sin_full.astype(dtype) @@ -462,8 +503,8 @@ class Embeddings1DConnector(nn.Module): freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype) # Process through transformer blocks - for block in self.transformer_1d_blocks: - hidden_states = block(hidden_states, attention_mask, freqs_cis) + for i in range(len(self.transformer_1d_blocks)): + hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis) # Final RMS norm hidden_states = rms_norm(hidden_states) @@ -535,15 +576,28 @@ class GemmaFeaturesExtractor(nn.Module): +class AudioEmbeddingsConnector(nn.Module): + """Projects video embeddings to audio cross-attention dimension.""" + + def __init__(self, input_dim: int = 3840, output_dim: int = 2048): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + return self.linear(x) + + class LTX2TextEncoder(nn.Module): def __init__( self, hidden_dim: int = 3840, + audio_dim: int = 2048, num_layers: int = 49, # 48 transformer layers + 1 embedding ): super().__init__() self.hidden_dim = hidden_dim + self.audio_dim = audio_dim self.num_layers = num_layers self.language_model = None @@ -560,14 +614,26 @@ class LTX2TextEncoder(nn.Module): head_dim=128, num_layers=2, num_learnable_registers=128, + positional_embedding_max_pos=[4096], # Match PyTorch + ) + + # Audio embeddings connector: separate 2-layer transformer (same architecture as video) + # Both connectors process the feature extractor output independently + self.audio_embeddings_connector = Embeddings1DConnector( + dim=hidden_dim, + num_heads=30, + head_dim=128, + num_layers=2, + num_learnable_registers=128, + positional_embedding_max_pos=[4096], # Match PyTorch ) self.processor = None def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"): - if Path(text_encoder_path / "text_encoder").is_dir(): - text_encoder_path = str(text_encoder_path / "text_encoder") + if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir(): + text_encoder_path = str(Path(text_encoder_path) / "text_encoder") self.language_model = LanguageModel.from_pretrained(text_encoder_path) @@ -594,10 +660,14 @@ class LTX2TextEncoder(nn.Module): # Map weight names to our structure mapped_weights = {} for key, value in connector_weights.items(): - # transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.* - # transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.* - # transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.* - mapped_weights[key] = value + new_key = key + # Map ff.net.0.proj -> ff.proj_in (GEGLU projection) + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + # Map ff.net.2 -> ff.proj_out (output Linear) + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + # Map to_out.0 -> to_out (Sequential -> direct) + new_key = new_key.replace(".to_out.0.", ".to_out.") + mapped_weights[new_key] = value self.video_embeddings_connector.load_weights( list(mapped_weights.items()), strict=False @@ -607,6 +677,34 @@ class LTX2TextEncoder(nn.Module): if "learnable_registers" in connector_weights: self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"] + # Load audio_embeddings_connector weights (same structure as video connector) + audio_connector_weights = {} + for key, value in transformer_weights.items(): + if key.startswith("model.diffusion_model.audio_embeddings_connector."): + new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "") + audio_connector_weights[new_key] = value + + if audio_connector_weights: + # Map weight names to our structure (same as video connector) + mapped_audio_weights = {} + for key, value in audio_connector_weights.items(): + new_key = key + # Map ff.net.0.proj -> ff.proj_in (GEGLU projection) + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + # Map ff.net.2 -> ff.proj_out (output Linear) + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + # Map to_out.0 -> to_out (Sequential -> direct) + new_key = new_key.replace(".to_out.0.", ".to_out.") + mapped_audio_weights[new_key] = value + + self.audio_embeddings_connector.load_weights( + list(mapped_audio_weights.items()), strict=False + ) + + # Manually load learnable_registers (it's a plain mx.array, not a parameter) + if "learnable_registers" in audio_connector_weights: + self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"] + # Load tokenizer from transformers import AutoTokenizer tokenizer_path = model_path / "tokenizer" @@ -623,8 +721,20 @@ class LTX2TextEncoder(nn.Module): self, prompt: str, max_length: int = 1024, + return_audio_embeddings: bool = True, ) -> Tuple[mx.array, mx.array]: + """Encode text prompt to video and audio embeddings. + Args: + prompt: Text prompt to encode + max_length: Maximum token length (default 1024 to match official PyTorch) + return_audio_embeddings: If True, returns (video_emb, audio_emb). + If False, returns (video_emb, attention_mask). + + Returns: + Tuple of (video_embeddings, audio_embeddings) if return_audio_embeddings=True + Tuple of (video_embeddings, attention_mask) otherwise + """ if self.processor is None: raise RuntimeError("Model not loaded. Call load() first.") @@ -649,16 +759,33 @@ class LTX2TextEncoder(nn.Module): additive_mask = (attention_mask - 1).astype(features.dtype) additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 - embeddings, _ = self.video_embeddings_connector(features, additive_mask) + video_embeddings, _ = self.video_embeddings_connector(features, additive_mask) - return embeddings, attention_mask + if return_audio_embeddings: + # Process features through audio connector independently (same input as video) + audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask) + return video_embeddings, audio_embeddings + else: + return video_embeddings, attention_mask def __call__( self, prompt: str, max_length: int = 1024, + return_audio_embeddings: bool = True, ) -> Tuple[mx.array, mx.array]: - return self.encode(prompt, max_length) + """Encode text prompt. + + Args: + prompt: Text prompt to encode + max_length: Maximum token length (default 1024 to match official PyTorch) + return_audio_embeddings: If True, returns (video_emb, audio_emb). + If False, returns (video_emb, attention_mask). + + Returns: + Tuple of embeddings based on return_audio_embeddings flag + """ + return self.encode(prompt, max_length, return_audio_embeddings) @functools.cached_property def default_t2v_system_prompt(self) -> str: @@ -833,7 +960,7 @@ class LTX2TextEncoder(nn.Module): def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder: - encoder = LTX2TextEncoder(model_path=model_path) - encoder.load() + encoder = LTX2TextEncoder() + encoder.load(model_path=model_path) return encoder diff --git a/mlx_video/models/ltx/video_vae/ops.py b/mlx_video/models/ltx/video_vae/ops.py index 4521ba5..ca0457e 100644 --- a/mlx_video/models/ltx/video_vae/ops.py +++ b/mlx_video/models/ltx/video_vae/ops.py @@ -34,10 +34,11 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a # Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw) x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)) - # Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W') - x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6)) + # Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W') + # PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph + x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6)) - # Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W') + # Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W') x = mx.reshape(x, (b, new_c, new_f, new_h, new_w)) return x From e1bff927dfd296c937f64e49959c47eb7103e3c4 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 16 Jan 2026 14:55:50 +0100 Subject: [PATCH 2/4] Auto-detect timestep_cond from model metadata () --- mlx_video/generate.py | 2 +- mlx_video/generate_av.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index e39dfeb..0c9e562 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -265,7 +265,7 @@ def generate_video( vae_decoder = load_vae_decoder( str(model_path / 'ltx-2-19b-distilled.safetensors'), - timestep_conditioning=True + timestep_conditioning=None # Auto-detect from model metadata ) latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py index b4e488c..7551c51 100644 --- a/mlx_video/generate_av.py +++ b/mlx_video/generate_av.py @@ -427,7 +427,7 @@ def generate_video_with_audio( vae_decoder = load_vae_decoder( str(model_path / 'ltx-2-19b-distilled.safetensors'), - timestep_conditioning=True + timestep_conditioning=None # Auto-detect from model metadata ) video_latents = upsample_latents(video_latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) From f6e0e5d5a4312e5d913f3a91507605435efd0fbd Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 16 Jan 2026 15:59:22 +0100 Subject: [PATCH 3/4] Update av_ca_timestep_scale_multiplier to 1000 in model configuration for consistency across modules --- mlx_video/convert.py | 2 +- mlx_video/models/ltx/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx_video/convert.py b/mlx_video/convert.py index 1e08e5c..e251fc5 100644 --- a/mlx_video/convert.py +++ b/mlx_video/convert.py @@ -463,7 +463,7 @@ def create_model_from_config(config: Dict[str, Any]) -> LTXModel: positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]), audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]), timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000), - av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1), + av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1000), norm_eps=config.get("norm_eps", 1e-6), ) diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index 1f45fb1..6ac9de2 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -123,7 +123,7 @@ class LTXModelConfig(BaseModelConfig): # Timestep config timestep_scale_multiplier: int = 1000 - av_ca_timestep_scale_multiplier: int = 1 + av_ca_timestep_scale_multiplier: int = 1000 # Normalization norm_eps: float = 1e-6 From 5f86e881d7928f3553c4b97a9915b4990c38663d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 16 Jan 2026 21:08:14 +0100 Subject: [PATCH 4/4] Update top_p parameter in sampler function to 1.0 for enhanced sampling control in LTX2TextEncoder --- mlx_video/models/ltx/text_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index b38d076..bcd7cf4 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -883,7 +883,7 @@ class LTX2TextEncoder(nn.Module): ) input_ids = mx.array(inputs["input_ids"]) - sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 0.95), top_k=kwargs.get("top_k", -1)) + sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1)) logits_processors = make_logits_processors( kwargs.get("logit_bias", None), kwargs.get("repetition_penalty", 1.3),