Add audio support
Add audio support
This commit is contained in:
@@ -109,6 +109,84 @@ def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
|||||||
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
|
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||||
|
"""Load audio VAE weights from LTX-2 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to LTX-2 model directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of audio VAE weights
|
||||||
|
"""
|
||||||
|
# Try different possible paths for audio VAE weights
|
||||||
|
audio_vae_paths = [
|
||||||
|
model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
|
||||||
|
model_path / "audio_vae.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Also check in main model weights
|
||||||
|
main_paths = [
|
||||||
|
model_path / "ltx-2-19b-distilled.safetensors",
|
||||||
|
model_path / "ltx-2-19b-dev.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
for audio_path in audio_vae_paths:
|
||||||
|
if audio_path.exists():
|
||||||
|
print(f"Loading audio VAE weights from {audio_path}...")
|
||||||
|
return mx.load(str(audio_path))
|
||||||
|
|
||||||
|
# Check main model weights for audio_vae keys
|
||||||
|
for main_path in main_paths:
|
||||||
|
if main_path.exists():
|
||||||
|
print(f"Loading audio VAE weights from {main_path.name}...")
|
||||||
|
all_weights = mx.load(str(main_path))
|
||||||
|
# Filter to only audio_vae keys
|
||||||
|
audio_weights = {k: v for k, v in all_weights.items() if "audio_vae" in k}
|
||||||
|
if audio_weights:
|
||||||
|
return audio_weights
|
||||||
|
|
||||||
|
raise FileNotFoundError(f"Audio VAE weights not found in {model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||||
|
"""Load vocoder weights from LTX-2 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to LTX-2 model directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of vocoder weights
|
||||||
|
"""
|
||||||
|
# Try different possible paths for vocoder weights
|
||||||
|
vocoder_paths = [
|
||||||
|
model_path / "vocoder" / "diffusion_pytorch_model.safetensors",
|
||||||
|
model_path / "vocoder.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Also check in main model weights
|
||||||
|
main_paths = [
|
||||||
|
model_path / "ltx-2-19b-distilled.safetensors",
|
||||||
|
model_path / "ltx-2-19b-dev.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
for vocoder_path in vocoder_paths:
|
||||||
|
if vocoder_path.exists():
|
||||||
|
print(f"Loading vocoder weights from {vocoder_path}...")
|
||||||
|
return mx.load(str(vocoder_path))
|
||||||
|
|
||||||
|
# Check main model weights for vocoder keys
|
||||||
|
for main_path in main_paths:
|
||||||
|
if main_path.exists():
|
||||||
|
print(f"Loading vocoder weights from {main_path.name}...")
|
||||||
|
all_weights = mx.load(str(main_path))
|
||||||
|
# Filter to only vocoder keys
|
||||||
|
vocoder_weights = {k: v for k, v in all_weights.items() if "vocoder" in k}
|
||||||
|
if vocoder_weights:
|
||||||
|
return vocoder_weights
|
||||||
|
|
||||||
|
raise FileNotFoundError(f"Vocoder weights not found in {model_path}")
|
||||||
|
|
||||||
|
|
||||||
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
|
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
|
||||||
|
|
||||||
@@ -213,6 +291,83 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with MLX-compatible naming for audio VAE decoder
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Handle audio_vae.decoder weights
|
||||||
|
if key.startswith("audio_vae.decoder."):
|
||||||
|
new_key = key.replace("audio_vae.decoder.", "")
|
||||||
|
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||||
|
# Map per-channel statistics
|
||||||
|
if "mean-of-means" in key:
|
||||||
|
new_key = "per_channel_statistics._mean_of_means"
|
||||||
|
elif "std-of-means" in key:
|
||||||
|
new_key = "per_channel_statistics._std_of_means"
|
||||||
|
else:
|
||||||
|
continue # Skip other statistics keys
|
||||||
|
else:
|
||||||
|
continue # Skip non-decoder keys
|
||||||
|
|
||||||
|
# Handle Conv2d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, H, W)
|
||||||
|
# MLX: (out_channels, H, W, in_channels)
|
||||||
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize vocoder weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with MLX-compatible naming for vocoder
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Handle vocoder weights
|
||||||
|
if key.startswith("vocoder."):
|
||||||
|
new_key = key.replace("vocoder.", "")
|
||||||
|
|
||||||
|
# Handle ModuleList indices -> dict keys
|
||||||
|
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
|
||||||
|
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
|
||||||
|
|
||||||
|
# Handle Conv1d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, kernel)
|
||||||
|
# MLX: (out_channels, kernel, in_channels)
|
||||||
|
if "weight" in new_key and value.ndim == 3:
|
||||||
|
if "ups" in new_key:
|
||||||
|
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||||
|
value = mx.transpose(value, (1, 2, 0))
|
||||||
|
else:
|
||||||
|
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||||
|
value = mx.transpose(value, (0, 2, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
"""Sanitize weight names from PyTorch format to MLX format.
|
"""Sanitize weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
@@ -308,7 +463,7 @@ def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
|
|||||||
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
|
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
|
||||||
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
|
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
|
||||||
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
|
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
|
||||||
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
|
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1000),
|
||||||
norm_eps=config.get("norm_eps", 1e-6),
|
norm_eps=config.get("norm_eps", 1e-6),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -215,7 +215,7 @@ def generate_video(
|
|||||||
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||||
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
text_embeddings, _ = text_encoder(prompt)
|
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||||
mx.eval(text_embeddings)
|
mx.eval(text_embeddings)
|
||||||
|
|
||||||
del text_encoder
|
del text_encoder
|
||||||
@@ -265,7 +265,7 @@ def generate_video(
|
|||||||
|
|
||||||
vae_decoder = load_vae_decoder(
|
vae_decoder = load_vae_decoder(
|
||||||
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||||
timestep_conditioning=True
|
timestep_conditioning=None # Auto-detect from model metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||||
|
|||||||
612
mlx_video/generate_av.py
Normal file
612
mlx_video/generate_av.py
Normal file
@@ -0,0 +1,612 @@
|
|||||||
|
"""Audio-Video generation pipeline for LTX-2."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
# ANSI color codes
|
||||||
|
class Colors:
|
||||||
|
CYAN = "\033[96m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RED = "\033[91m"
|
||||||
|
MAGENTA = "\033[95m"
|
||||||
|
BOLD = "\033[1m"
|
||||||
|
DIM = "\033[2m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||||
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
|
from mlx_video.models.ltx.transformer import Modality
|
||||||
|
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
|
||||||
|
from mlx_video.utils import to_denoised, get_model_path
|
||||||
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
|
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||||
|
|
||||||
|
|
||||||
|
# Distilled sigma schedules
|
||||||
|
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
|
||||||
|
# Audio constants
|
||||||
|
AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate
|
||||||
|
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate
|
||||||
|
AUDIO_HOP_LENGTH = 160
|
||||||
|
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
|
||||||
|
AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying
|
||||||
|
AUDIO_MEL_BINS = 16
|
||||||
|
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
|
||||||
|
|
||||||
|
|
||||||
|
def create_video_position_grid(
|
||||||
|
batch_size: int,
|
||||||
|
num_frames: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
temporal_scale: int = 8,
|
||||||
|
spatial_scale: int = 32,
|
||||||
|
fps: float = 24.0,
|
||||||
|
causal_fix: bool = True,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Create position grid for video RoPE in pixel space."""
|
||||||
|
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
|
||||||
|
|
||||||
|
t_coords = np.arange(0, num_frames, patch_size_t)
|
||||||
|
h_coords = np.arange(0, height, patch_size_h)
|
||||||
|
w_coords = np.arange(0, width, patch_size_w)
|
||||||
|
|
||||||
|
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||||
|
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||||
|
|
||||||
|
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
|
||||||
|
patch_ends = patch_starts + patch_size_delta
|
||||||
|
|
||||||
|
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
|
||||||
|
num_patches = num_frames * height * width
|
||||||
|
latent_coords = latent_coords.reshape(3, num_patches, 2)
|
||||||
|
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
|
||||||
|
|
||||||
|
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
|
||||||
|
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
|
||||||
|
|
||||||
|
if causal_fix:
|
||||||
|
pixel_coords[:, 0, :, :] = np.clip(
|
||||||
|
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
|
||||||
|
a_min=0,
|
||||||
|
a_max=None
|
||||||
|
)
|
||||||
|
|
||||||
|
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
||||||
|
|
||||||
|
return mx.array(pixel_coords, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def create_audio_position_grid(
|
||||||
|
batch_size: int,
|
||||||
|
audio_frames: int,
|
||||||
|
sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
|
||||||
|
hop_length: int = AUDIO_HOP_LENGTH,
|
||||||
|
downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
|
||||||
|
is_causal: bool = True,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Create temporal position grid for audio RoPE.
|
||||||
|
|
||||||
|
Audio positions are timestamps in seconds, shape (B, 1, T, 2).
|
||||||
|
Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly.
|
||||||
|
"""
|
||||||
|
def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray:
|
||||||
|
"""Convert latent indices to seconds (matching PyTorch's _get_audio_latent_time_in_sec)."""
|
||||||
|
latent_frame = np.arange(start_idx, end_idx, dtype=np.float32)
|
||||||
|
mel_frame = latent_frame * downsample_factor
|
||||||
|
if is_causal:
|
||||||
|
# Frame offset for causal alignment (PyTorch uses +1 - downsample_factor)
|
||||||
|
mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None)
|
||||||
|
return mel_frame * hop_length / sample_rate
|
||||||
|
|
||||||
|
# Start times: latent indices 0 to audio_frames
|
||||||
|
start_times = get_audio_latent_time_in_sec(0, audio_frames)
|
||||||
|
|
||||||
|
# End times: latent indices 1 to audio_frames+1 (shifted by 1)
|
||||||
|
end_times = get_audio_latent_time_in_sec(1, audio_frames + 1)
|
||||||
|
|
||||||
|
# Shape: (B, 1, T, 2)
|
||||||
|
positions = np.stack([start_times, end_times], axis=-1)
|
||||||
|
positions = positions[np.newaxis, np.newaxis, :, :] # (1, 1, T, 2)
|
||||||
|
positions = np.tile(positions, (batch_size, 1, 1, 1))
|
||||||
|
|
||||||
|
return mx.array(positions, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
|
||||||
|
"""Compute number of audio latent frames given video duration."""
|
||||||
|
duration = num_video_frames / fps
|
||||||
|
return round(duration * AUDIO_LATENTS_PER_SECOND)
|
||||||
|
|
||||||
|
|
||||||
|
def denoise_av(
|
||||||
|
video_latents: mx.array,
|
||||||
|
audio_latents: mx.array,
|
||||||
|
video_positions: mx.array,
|
||||||
|
audio_positions: mx.array,
|
||||||
|
video_embeddings: mx.array,
|
||||||
|
audio_embeddings: mx.array,
|
||||||
|
transformer: LTXModel,
|
||||||
|
sigmas: list,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> tuple[mx.array, mx.array]:
|
||||||
|
"""Run denoising loop for audio-video generation."""
|
||||||
|
for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose):
|
||||||
|
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||||
|
|
||||||
|
# Flatten video latents
|
||||||
|
b, c, f, h, w = video_latents.shape
|
||||||
|
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
||||||
|
|
||||||
|
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
|
||||||
|
ab, ac, at, af = audio_latents.shape
|
||||||
|
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
|
||||||
|
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
||||||
|
|
||||||
|
video_modality = Modality(
|
||||||
|
latent=video_flat,
|
||||||
|
timesteps=mx.full((1,), sigma),
|
||||||
|
positions=video_positions,
|
||||||
|
context=video_embeddings,
|
||||||
|
context_mask=None,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_modality = Modality(
|
||||||
|
latent=audio_flat,
|
||||||
|
timesteps=mx.full((1,), sigma),
|
||||||
|
positions=audio_positions,
|
||||||
|
context=audio_embeddings,
|
||||||
|
context_mask=None,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
|
||||||
|
mx.eval(video_velocity, audio_velocity)
|
||||||
|
|
||||||
|
# Reshape velocities back
|
||||||
|
video_velocity = mx.reshape(mx.transpose(video_velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||||
|
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
||||||
|
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
|
||||||
|
|
||||||
|
# Compute denoised
|
||||||
|
video_denoised = to_denoised(video_latents, video_velocity, sigma)
|
||||||
|
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
||||||
|
mx.eval(video_denoised, audio_denoised)
|
||||||
|
|
||||||
|
# Euler step
|
||||||
|
if sigma_next > 0:
|
||||||
|
video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma
|
||||||
|
audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma
|
||||||
|
else:
|
||||||
|
video_latents = video_denoised
|
||||||
|
audio_latents = audio_denoised
|
||||||
|
mx.eval(video_latents, audio_latents)
|
||||||
|
|
||||||
|
return video_latents, audio_latents
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio_decoder(model_path: Path):
|
||||||
|
"""Load audio VAE decoder."""
|
||||||
|
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
|
||||||
|
|
||||||
|
decoder = AudioDecoder(
|
||||||
|
ch=128,
|
||||||
|
out_ch=2, # stereo
|
||||||
|
ch_mult=(1, 2, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_resolutions={8, 16, 32},
|
||||||
|
resolution=256,
|
||||||
|
z_channels=AUDIO_LATENT_CHANNELS,
|
||||||
|
norm_type=NormType.PIXEL,
|
||||||
|
causality_axis=CausalityAxis.HEIGHT,
|
||||||
|
mel_bins=64, # Output mel bins
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load weights from main model file
|
||||||
|
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
|
||||||
|
if weight_file.exists():
|
||||||
|
raw_weights = mx.load(str(weight_file))
|
||||||
|
sanitized = sanitize_audio_vae_weights(raw_weights)
|
||||||
|
if sanitized:
|
||||||
|
decoder.load_weights(list(sanitized.items()), strict=False)
|
||||||
|
|
||||||
|
# Manually load per-channel statistics (they're plain mx.array, not tracked by load_weights)
|
||||||
|
if "per_channel_statistics._mean_of_means" in sanitized:
|
||||||
|
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
|
||||||
|
if "per_channel_statistics._std_of_means" in sanitized:
|
||||||
|
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
|
||||||
|
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocoder(model_path: Path):
|
||||||
|
"""Load vocoder for mel to waveform conversion."""
|
||||||
|
from mlx_video.models.ltx.audio_vae import Vocoder
|
||||||
|
|
||||||
|
vocoder = Vocoder(
|
||||||
|
resblock_kernel_sizes=[3, 7, 11],
|
||||||
|
upsample_rates=[6, 5, 2, 2, 2],
|
||||||
|
upsample_kernel_sizes=[16, 15, 8, 4, 4],
|
||||||
|
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
|
upsample_initial_channel=1024,
|
||||||
|
stereo=True,
|
||||||
|
output_sample_rate=AUDIO_SAMPLE_RATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
|
||||||
|
if weight_file.exists():
|
||||||
|
raw_weights = mx.load(str(weight_file))
|
||||||
|
sanitized = sanitize_vocoder_weights(raw_weights)
|
||||||
|
if sanitized:
|
||||||
|
vocoder.load_weights(list(sanitized.items()), strict=False)
|
||||||
|
|
||||||
|
return vocoder
|
||||||
|
|
||||||
|
|
||||||
|
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
|
||||||
|
"""Save audio to WAV file."""
|
||||||
|
import wave
|
||||||
|
|
||||||
|
# Ensure audio is in correct format (channels, samples) or (samples,)
|
||||||
|
if audio.ndim == 2:
|
||||||
|
# (channels, samples) -> (samples, channels)
|
||||||
|
audio = audio.T
|
||||||
|
|
||||||
|
# Normalize and convert to int16
|
||||||
|
audio = np.clip(audio, -1.0, 1.0)
|
||||||
|
audio_int16 = (audio * 32767).astype(np.int16)
|
||||||
|
|
||||||
|
with wave.open(str(path), 'wb') as wf:
|
||||||
|
wf.setnchannels(2 if audio_int16.ndim == 2 else 1)
|
||||||
|
wf.setsampwidth(2) # 16-bit
|
||||||
|
wf.setframerate(sample_rate)
|
||||||
|
wf.writeframes(audio_int16.tobytes())
|
||||||
|
|
||||||
|
|
||||||
|
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
|
||||||
|
"""Combine video and audio into final output using ffmpeg."""
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg", "-y",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-i", str(audio_path),
|
||||||
|
"-c:v", "copy",
|
||||||
|
"-c:a", "aac",
|
||||||
|
"-shortest",
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(cmd, check=True, capture_output=True)
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
|
||||||
|
return False
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def generate_video_with_audio(
|
||||||
|
model_repo: str,
|
||||||
|
text_encoder_repo: Optional[str],
|
||||||
|
prompt: str,
|
||||||
|
height: int = 512,
|
||||||
|
width: int = 512,
|
||||||
|
num_frames: int = 33,
|
||||||
|
seed: int = 42,
|
||||||
|
fps: int = 24,
|
||||||
|
output_path: str = "output_av.mp4",
|
||||||
|
output_audio_path: Optional[str] = None,
|
||||||
|
verbose: bool = True,
|
||||||
|
enhance_prompt: bool = False,
|
||||||
|
max_tokens: int = 512,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
):
|
||||||
|
"""Generate video with synchronized audio from text prompt."""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Validate dimensions
|
||||||
|
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||||
|
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||||
|
|
||||||
|
if num_frames % 8 != 1:
|
||||||
|
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
|
||||||
|
print(f"{Colors.YELLOW}⚠️ Adjusted frames to {adjusted_num_frames}{Colors.RESET}")
|
||||||
|
num_frames = adjusted_num_frames
|
||||||
|
|
||||||
|
# Calculate audio frames
|
||||||
|
audio_frames = compute_audio_frames(num_frames, fps)
|
||||||
|
|
||||||
|
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
|
||||||
|
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
|
||||||
|
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
|
model_path = get_model_path(model_repo)
|
||||||
|
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||||
|
|
||||||
|
# Calculate latent dimensions
|
||||||
|
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||||
|
stage2_h, stage2_w = height // 32, width // 32
|
||||||
|
latent_frames = 1 + (num_frames - 1) // 8
|
||||||
|
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
# Load text encoder with audio embeddings
|
||||||
|
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
||||||
|
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||||
|
text_encoder = LTX2TextEncoder()
|
||||||
|
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||||
|
mx.eval(text_encoder.parameters())
|
||||||
|
|
||||||
|
# Optionally enhance prompt
|
||||||
|
if enhance_prompt:
|
||||||
|
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
|
||||||
|
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||||
|
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
|
# Get both video and audio embeddings
|
||||||
|
video_embeddings, audio_embeddings = text_encoder(prompt)
|
||||||
|
mx.eval(video_embeddings, audio_embeddings)
|
||||||
|
|
||||||
|
del text_encoder
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Load transformer with AudioVideo config
|
||||||
|
print(f"{Colors.BLUE}🤖 Loading transformer (A/V mode)...{Colors.RESET}")
|
||||||
|
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||||
|
sanitized = sanitize_transformer_weights(raw_weights)
|
||||||
|
|
||||||
|
config = LTXModelConfig(
|
||||||
|
model_type=LTXModelType.AudioVideo,
|
||||||
|
num_attention_heads=32,
|
||||||
|
attention_head_dim=128,
|
||||||
|
in_channels=128,
|
||||||
|
out_channels=128,
|
||||||
|
num_layers=48,
|
||||||
|
cross_attention_dim=4096,
|
||||||
|
caption_channels=3840,
|
||||||
|
# Audio config
|
||||||
|
audio_num_attention_heads=32,
|
||||||
|
audio_attention_head_dim=64,
|
||||||
|
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
|
||||||
|
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
|
||||||
|
audio_cross_attention_dim=2048,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision_rope=True,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
audio_positional_embedding_max_pos=[20],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
timestep_scale_multiplier=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
transformer = LTXModel(config)
|
||||||
|
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||||
|
mx.eval(transformer.parameters())
|
||||||
|
|
||||||
|
# Initialize latents
|
||||||
|
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||||
|
mx.random.seed(seed)
|
||||||
|
video_latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||||
|
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
|
||||||
|
mx.eval(video_latents, audio_latents)
|
||||||
|
|
||||||
|
# Create position grids
|
||||||
|
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
|
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||||
|
mx.eval(video_positions, audio_positions)
|
||||||
|
|
||||||
|
# Stage 1 denoising
|
||||||
|
video_latents, audio_latents = denoise_av(
|
||||||
|
video_latents, audio_latents,
|
||||||
|
video_positions, audio_positions,
|
||||||
|
video_embeddings, audio_embeddings,
|
||||||
|
transformer, STAGE_1_SIGMAS, verbose=verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upsample video latents
|
||||||
|
print(f"{Colors.MAGENTA}🔍 Upsampling video latents 2x...{Colors.RESET}")
|
||||||
|
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||||
|
mx.eval(upsampler.parameters())
|
||||||
|
|
||||||
|
vae_decoder = load_vae_decoder(
|
||||||
|
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||||
|
timestep_conditioning=None # Auto-detect from model metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
video_latents = upsample_latents(video_latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||||
|
mx.eval(video_latents)
|
||||||
|
|
||||||
|
del upsampler
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Stage 2: Refine at full resolution
|
||||||
|
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
||||||
|
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
|
mx.eval(video_positions)
|
||||||
|
|
||||||
|
# Add noise for refinement
|
||||||
|
noise_scale = STAGE_2_SIGMAS[0]
|
||||||
|
video_noise = mx.random.normal(video_latents.shape)
|
||||||
|
audio_noise = mx.random.normal(audio_latents.shape)
|
||||||
|
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
|
||||||
|
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||||
|
mx.eval(video_latents, audio_latents)
|
||||||
|
|
||||||
|
video_latents, audio_latents = denoise_av(
|
||||||
|
video_latents, audio_latents,
|
||||||
|
video_positions, audio_positions,
|
||||||
|
video_embeddings, audio_embeddings,
|
||||||
|
transformer, STAGE_2_SIGMAS, verbose=verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
del transformer
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Decode video
|
||||||
|
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
|
||||||
|
video = vae_decoder(video_latents)
|
||||||
|
mx.eval(video)
|
||||||
|
|
||||||
|
# Convert video to uint8 frames
|
||||||
|
video = mx.squeeze(video, axis=0)
|
||||||
|
video = mx.transpose(video, (1, 2, 3, 0))
|
||||||
|
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||||
|
video = (video * 255).astype(mx.uint8)
|
||||||
|
video_np = np.array(video)
|
||||||
|
|
||||||
|
# Decode audio
|
||||||
|
print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}")
|
||||||
|
audio_decoder = load_audio_decoder(model_path)
|
||||||
|
vocoder = load_vocoder(model_path)
|
||||||
|
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||||
|
|
||||||
|
# Debug: check per-channel statistics are loaded
|
||||||
|
pcs = audio_decoder.per_channel_statistics
|
||||||
|
print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]")
|
||||||
|
|
||||||
|
# Debug: check audio latent statistics
|
||||||
|
print(f"Audio latents shape: {audio_latents.shape}")
|
||||||
|
print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}")
|
||||||
|
|
||||||
|
mel_spectrogram = audio_decoder(audio_latents)
|
||||||
|
mx.eval(mel_spectrogram)
|
||||||
|
|
||||||
|
print(f"Mel spectrogram shape: {mel_spectrogram.shape}")
|
||||||
|
print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}")
|
||||||
|
|
||||||
|
# Audio decoder output is already in vocoder format (B, C, T, F)
|
||||||
|
audio_waveform = vocoder(mel_spectrogram)
|
||||||
|
mx.eval(audio_waveform)
|
||||||
|
|
||||||
|
print(f"Audio waveform shape: {audio_waveform.shape}")
|
||||||
|
print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}")
|
||||||
|
|
||||||
|
audio_np = np.array(audio_waveform)
|
||||||
|
if audio_np.ndim == 3:
|
||||||
|
audio_np = audio_np[0] # Remove batch dim
|
||||||
|
|
||||||
|
del audio_decoder, vocoder, vae_decoder
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Save outputs
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save video (temporary without audio)
|
||||||
|
temp_video_path = output_path.with_suffix('.temp.mp4')
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
h, w = video_np.shape[1], video_np.shape[2]
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||||
|
out = cv2.VideoWriter(str(temp_video_path), fourcc, fps, (w, h))
|
||||||
|
for frame in video_np:
|
||||||
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
|
out.release()
|
||||||
|
print(f"{Colors.GREEN}✅ Video encoded{Colors.RESET}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"{Colors.RED}❌ Video encoding failed: {e}{Colors.RESET}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Save audio
|
||||||
|
audio_path = output_path.with_suffix('.wav') if output_audio_path is None else Path(output_audio_path)
|
||||||
|
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
|
||||||
|
print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}")
|
||||||
|
|
||||||
|
# Mux video and audio
|
||||||
|
print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}")
|
||||||
|
if mux_video_audio(temp_video_path, audio_path, output_path):
|
||||||
|
print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}")
|
||||||
|
temp_video_path.unlink() # Remove temp file
|
||||||
|
else:
|
||||||
|
# Fallback: keep video without audio
|
||||||
|
temp_video_path.rename(output_path)
|
||||||
|
print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}")
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s{Colors.RESET}")
|
||||||
|
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
||||||
|
|
||||||
|
return video_np, audio_np
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Generate videos with synchronized audio using MLX LTX-2",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach"
|
||||||
|
python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt
|
||||||
|
python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--prompt", "-p", type=str, required=True,
|
||||||
|
help="Text description of the video/audio to generate")
|
||||||
|
parser.add_argument("--height", "-H", type=int, default=512,
|
||||||
|
help="Output video height (default: 512)")
|
||||||
|
parser.add_argument("--width", "-W", type=int, default=512,
|
||||||
|
help="Output video width (default: 512)")
|
||||||
|
parser.add_argument("--num-frames", "-n", type=int, default=65,
|
||||||
|
help="Number of frames (default: 65)")
|
||||||
|
parser.add_argument("--seed", "-s", type=int, default=42,
|
||||||
|
help="Random seed (default: 42)")
|
||||||
|
parser.add_argument("--fps", type=int, default=24,
|
||||||
|
help="Frames per second (default: 24)")
|
||||||
|
parser.add_argument("--output-path", type=str, default="output_av.mp4",
|
||||||
|
help="Output video path (default: output_av.mp4)")
|
||||||
|
parser.add_argument("--output-audio", type=str, default=None,
|
||||||
|
help="Output audio path (default: same as video with .wav)")
|
||||||
|
parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2",
|
||||||
|
help="Model repository (default: Lightricks/LTX-2)")
|
||||||
|
parser.add_argument("--text-encoder-repo", type=str, default=None,
|
||||||
|
help="Text encoder repository")
|
||||||
|
parser.add_argument("--verbose", action="store_true",
|
||||||
|
help="Verbose output")
|
||||||
|
parser.add_argument("--enhance-prompt", action="store_true",
|
||||||
|
help="Enhance prompt using Gemma")
|
||||||
|
parser.add_argument("--max-tokens", type=int, default=512,
|
||||||
|
help="Max tokens for prompt enhancement")
|
||||||
|
parser.add_argument("--temperature", type=float, default=0.7,
|
||||||
|
help="Temperature for prompt enhancement")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
generate_video_with_audio(
|
||||||
|
model_repo=args.model_repo,
|
||||||
|
text_encoder_repo=args.text_encoder_repo,
|
||||||
|
prompt=args.prompt,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
seed=args.seed,
|
||||||
|
fps=args.fps,
|
||||||
|
output_path=args.output_path,
|
||||||
|
output_audio_path=args.output_audio,
|
||||||
|
verbose=args.verbose,
|
||||||
|
enhance_prompt=args.enhance_prompt,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
temperature=args.temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -5,3 +5,4 @@ from mlx_video.models.ltx.config import (
|
|||||||
LTXModelType,
|
LTXModelType,
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||||
|
from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||||
|
|||||||
41
mlx_video/models/ltx/audio_vae/__init__.py
Normal file
41
mlx_video/models/ltx/audio_vae/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Audio VAE module for LTX-2 audio generation."""
|
||||||
|
|
||||||
|
from .attention import AttentionType, AttnBlock, make_attn
|
||||||
|
from .audio_vae import AudioDecoder, decode_audio
|
||||||
|
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||||
|
from .causality_axis import CausalityAxis
|
||||||
|
from .downsample import Downsample, build_downsampling_path
|
||||||
|
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||||
|
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||||
|
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock
|
||||||
|
from .upsample import Upsample, build_upsampling_path
|
||||||
|
from .vocoder import Vocoder
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Main components
|
||||||
|
"AudioDecoder",
|
||||||
|
"Vocoder",
|
||||||
|
"decode_audio",
|
||||||
|
# Ops
|
||||||
|
"AudioLatentShape",
|
||||||
|
"AudioPatchifier",
|
||||||
|
"PerChannelStatistics",
|
||||||
|
# Building blocks
|
||||||
|
"AttentionType",
|
||||||
|
"AttnBlock",
|
||||||
|
"make_attn",
|
||||||
|
"CausalConv2d",
|
||||||
|
"make_conv2d",
|
||||||
|
"CausalityAxis",
|
||||||
|
"Downsample",
|
||||||
|
"build_downsampling_path",
|
||||||
|
"NormType",
|
||||||
|
"PixelNorm",
|
||||||
|
"build_normalization_layer",
|
||||||
|
"ResBlock1",
|
||||||
|
"ResBlock2",
|
||||||
|
"ResnetBlock",
|
||||||
|
"LRELU_SLOPE",
|
||||||
|
"Upsample",
|
||||||
|
"build_upsampling_path",
|
||||||
|
]
|
||||||
108
mlx_video/models/ltx/audio_vae/attention.py
Normal file
108
mlx_video/models/ltx/audio_vae/attention.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Attention blocks for audio VAE."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .normalization import NormType, build_normalization_layer
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionType(Enum):
|
||||||
|
"""Enum for specifying the attention mechanism type."""
|
||||||
|
|
||||||
|
VANILLA = "vanilla"
|
||||||
|
LINEAR = "linear"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
|
||||||
|
"""Self-attention block for audio VAE."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
norm_type: NormType = NormType.GROUP,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = build_normalization_layer(in_channels, normtype=norm_type)
|
||||||
|
# Using Conv2d with kernel_size=1 for Q, K, V projections
|
||||||
|
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""
|
||||||
|
Forward pass through attention block.
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (B, H, W, C) in MLX channels-last format
|
||||||
|
Returns:
|
||||||
|
Output tensor with attention applied (residual connection)
|
||||||
|
"""
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
|
||||||
|
q = self.q(h_)
|
||||||
|
k = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# Compute attention
|
||||||
|
# x shape: (B, H, W, C)
|
||||||
|
b, h, w, c = q.shape
|
||||||
|
|
||||||
|
# Reshape for attention: (B, H*W, C)
|
||||||
|
q = q.reshape(b, h * w, c)
|
||||||
|
k = k.reshape(b, h * w, c)
|
||||||
|
v = v.reshape(b, h * w, c)
|
||||||
|
|
||||||
|
# Attention: Q @ K^T / sqrt(d)
|
||||||
|
# q: (B, HW, C), k: (B, HW, C) -> k^T: (B, C, HW)
|
||||||
|
# w_: (B, HW, HW)
|
||||||
|
scale = float(c) ** (-0.5)
|
||||||
|
w_ = mx.matmul(q, k.transpose(0, 2, 1)) * scale
|
||||||
|
w_ = mx.softmax(w_, axis=-1)
|
||||||
|
|
||||||
|
# Attend to values
|
||||||
|
# w_: (B, HW, HW), v: (B, HW, C) -> h_: (B, HW, C)
|
||||||
|
h_ = mx.matmul(w_, v)
|
||||||
|
|
||||||
|
# Reshape back to spatial dims
|
||||||
|
h_ = h_.reshape(b, h, w, c)
|
||||||
|
|
||||||
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
|
return x + h_
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(nn.Module):
|
||||||
|
"""Identity module that returns input unchanged."""
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def make_attn(
|
||||||
|
in_channels: int,
|
||||||
|
attn_type: AttentionType = AttentionType.VANILLA,
|
||||||
|
norm_type: NormType = NormType.GROUP,
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Create an attention module based on type.
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
attn_type: Type of attention mechanism
|
||||||
|
norm_type: Type of normalization
|
||||||
|
Returns:
|
||||||
|
Attention module
|
||||||
|
"""
|
||||||
|
if attn_type == AttentionType.VANILLA:
|
||||||
|
return AttnBlock(in_channels, norm_type=norm_type)
|
||||||
|
elif attn_type == AttentionType.NONE:
|
||||||
|
return Identity()
|
||||||
|
elif attn_type == AttentionType.LINEAR:
|
||||||
|
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||||
326
mlx_video/models/ltx/audio_vae/audio_vae.py
Normal file
326
mlx_video/models/ltx/audio_vae/audio_vae.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
"""Audio VAE encoder and decoder for LTX-2."""
|
||||||
|
|
||||||
|
from typing import Set, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .attention import AttentionType, make_attn
|
||||||
|
from .causal_conv_2d import make_conv2d
|
||||||
|
from .causality_axis import CausalityAxis
|
||||||
|
from .downsample import build_downsampling_path
|
||||||
|
from .normalization import NormType, build_normalization_layer
|
||||||
|
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||||
|
from .resnet import ResnetBlock
|
||||||
|
from .upsample import build_upsampling_path
|
||||||
|
|
||||||
|
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||||
|
|
||||||
|
|
||||||
|
def build_mid_block(
|
||||||
|
channels: int,
|
||||||
|
temb_channels: int,
|
||||||
|
dropout: float,
|
||||||
|
norm_type: NormType,
|
||||||
|
causality_axis: CausalityAxis,
|
||||||
|
attn_type: AttentionType,
|
||||||
|
add_attention: bool,
|
||||||
|
) -> dict:
|
||||||
|
"""Build the middle block with two ResNet blocks and optional attention."""
|
||||||
|
mid = {}
|
||||||
|
mid["block_1"] = ResnetBlock(
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
mid["attn_1"] = (
|
||||||
|
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
||||||
|
)
|
||||||
|
mid["block_2"] = ResnetBlock(
|
||||||
|
in_channels=channels,
|
||||||
|
out_channels=channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
return mid
|
||||||
|
|
||||||
|
|
||||||
|
def run_mid_block(mid: dict, features: mx.array) -> mx.array:
|
||||||
|
"""Run features through the middle block."""
|
||||||
|
features = mid["block_1"](features, temb=None)
|
||||||
|
if mid["attn_1"] is not None:
|
||||||
|
features = mid["attn_1"](features)
|
||||||
|
return mid["block_2"](features, temb=None)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||||
|
The decoder mirrors the encoder structure with configurable channel multipliers,
|
||||||
|
attention resolutions, and causal convolutions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ch: int = 128,
|
||||||
|
out_ch: int = 2,
|
||||||
|
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||||
|
num_res_blocks: int = 2,
|
||||||
|
attn_resolutions: Set[int] = None,
|
||||||
|
resolution: int = 256,
|
||||||
|
z_channels: int = 8,
|
||||||
|
norm_type: NormType = NormType.PIXEL,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
mid_block_add_attention: bool = True,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
mel_hop_length: int = 160,
|
||||||
|
is_causal: bool = True,
|
||||||
|
mel_bins: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the AudioDecoder.
|
||||||
|
Args:
|
||||||
|
ch: Base number of feature channels
|
||||||
|
out_ch: Number of output channels (2 for stereo)
|
||||||
|
ch_mult: Multiplicative factors for channels at each resolution
|
||||||
|
num_res_blocks: Number of residual blocks per resolution
|
||||||
|
attn_resolutions: Resolutions at which to apply attention
|
||||||
|
resolution: Input spatial resolution
|
||||||
|
z_channels: Number of latent channels
|
||||||
|
norm_type: Normalization type
|
||||||
|
causality_axis: Axis for causal convolutions
|
||||||
|
dropout: Dropout probability
|
||||||
|
mid_block_add_attention: Whether to add attention in middle block
|
||||||
|
sample_rate: Audio sample rate
|
||||||
|
mel_hop_length: Hop length for mel spectrogram
|
||||||
|
is_causal: Whether to use causal convolutions
|
||||||
|
mel_bins: Number of mel frequency bins
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if attn_resolutions is None:
|
||||||
|
attn_resolutions = {8, 16, 32}
|
||||||
|
|
||||||
|
# Internal behavioral defaults
|
||||||
|
resamp_with_conv = True
|
||||||
|
attn_type = AttentionType.VANILLA
|
||||||
|
|
||||||
|
# Per-channel statistics for denormalizing latents
|
||||||
|
# Uses ch (base channel count) to match the patchified latent dimension
|
||||||
|
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||||
|
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
|
||||||
|
# ch=128 matches this dimension, so use ch for per_channel_statistics
|
||||||
|
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.mel_hop_length = mel_hop_length
|
||||||
|
self.is_causal = is_causal
|
||||||
|
self.mel_bins = mel_bins
|
||||||
|
|
||||||
|
self.patchifier = AudioPatchifier(
|
||||||
|
patch_size=1,
|
||||||
|
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
hop_length=mel_hop_length,
|
||||||
|
is_causal=is_causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ch = ch
|
||||||
|
self.temb_ch = 0
|
||||||
|
self.num_resolutions = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.resolution = resolution
|
||||||
|
self.out_ch = out_ch
|
||||||
|
self.give_pre_end = False
|
||||||
|
self.tanh_out = False
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.z_channels = z_channels
|
||||||
|
self.channel_multipliers = ch_mult
|
||||||
|
self.attn_resolutions = attn_resolutions
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
self.attn_type = attn_type
|
||||||
|
|
||||||
|
base_block_channels = ch * self.channel_multipliers[-1]
|
||||||
|
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||||
|
self.z_shape = (1, z_channels, base_resolution, base_resolution)
|
||||||
|
|
||||||
|
self.conv_in = make_conv2d(
|
||||||
|
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mid = build_mid_block(
|
||||||
|
channels=base_block_channels,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=self.causality_axis,
|
||||||
|
attn_type=self.attn_type,
|
||||||
|
add_attention=mid_block_add_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up, final_block_channels = build_upsampling_path(
|
||||||
|
ch=ch,
|
||||||
|
ch_mult=ch_mult,
|
||||||
|
num_resolutions=self.num_resolutions,
|
||||||
|
num_res_blocks=num_res_blocks,
|
||||||
|
resolution=resolution,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
causality_axis=self.causality_axis,
|
||||||
|
attn_type=self.attn_type,
|
||||||
|
attn_resolutions=attn_resolutions,
|
||||||
|
resamp_with_conv=resamp_with_conv,
|
||||||
|
initial_block_channels=base_block_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||||
|
self.conv_out = make_conv2d(
|
||||||
|
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, sample: mx.array) -> mx.array:
|
||||||
|
"""
|
||||||
|
Decode latent features back to audio spectrograms.
|
||||||
|
Args:
|
||||||
|
sample: Encoded latent representation of shape (B, H, W, C) in MLX format
|
||||||
|
or (B, C, H, W) in PyTorch format (will be transposed)
|
||||||
|
Returns:
|
||||||
|
Reconstructed audio spectrogram
|
||||||
|
"""
|
||||||
|
# Handle input format - if channels are in dim 1, transpose to channels-last
|
||||||
|
if sample.shape[1] == self.z_channels and sample.ndim == 4:
|
||||||
|
# PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
|
||||||
|
sample = mx.transpose(sample, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
sample, target_shape = self._denormalize_latents(sample)
|
||||||
|
|
||||||
|
h = self.conv_in(sample)
|
||||||
|
h = run_mid_block(self.mid, h)
|
||||||
|
h = self._run_upsampling_path(h)
|
||||||
|
h = self._finalize_output(h)
|
||||||
|
|
||||||
|
return self._adjust_output_shape(h, target_shape)
|
||||||
|
|
||||||
|
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
||||||
|
"""Denormalize latents using per-channel statistics."""
|
||||||
|
# sample shape: (B, H, W, C) in MLX format
|
||||||
|
latent_shape = AudioLatentShape(
|
||||||
|
batch=sample.shape[0],
|
||||||
|
channels=sample.shape[3], # channels last
|
||||||
|
frames=sample.shape[1], # height = frames
|
||||||
|
mel_bins=sample.shape[2], # width = mel_bins
|
||||||
|
)
|
||||||
|
|
||||||
|
sample_patched = self.patchifier.patchify(sample)
|
||||||
|
sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
|
||||||
|
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
|
||||||
|
|
||||||
|
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
|
||||||
|
if self.causality_axis != CausalityAxis.NONE:
|
||||||
|
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||||
|
|
||||||
|
target_shape = AudioLatentShape(
|
||||||
|
batch=latent_shape.batch,
|
||||||
|
channels=self.out_ch,
|
||||||
|
frames=target_frames,
|
||||||
|
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
||||||
|
)
|
||||||
|
|
||||||
|
return sample, target_shape
|
||||||
|
|
||||||
|
def _adjust_output_shape(
|
||||||
|
self,
|
||||||
|
decoded_output: mx.array,
|
||||||
|
target_shape: AudioLatentShape,
|
||||||
|
) -> mx.array:
|
||||||
|
"""
|
||||||
|
Adjust output shape to match target dimensions for variable-length audio.
|
||||||
|
Args:
|
||||||
|
decoded_output: Tensor of shape (B, H, W, C) in MLX format
|
||||||
|
target_shape: AudioLatentShape describing target dimensions
|
||||||
|
Returns:
|
||||||
|
Tensor adjusted to match target_shape exactly
|
||||||
|
"""
|
||||||
|
# Current output shape: (batch, frames, mel_bins, channels) in MLX format
|
||||||
|
_, current_time, current_freq, _ = decoded_output.shape
|
||||||
|
target_channels = target_shape.channels
|
||||||
|
target_time = target_shape.frames
|
||||||
|
target_freq = target_shape.mel_bins
|
||||||
|
|
||||||
|
# Step 1: Crop first to avoid exceeding target dimensions
|
||||||
|
decoded_output = decoded_output[
|
||||||
|
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
||||||
|
]
|
||||||
|
|
||||||
|
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||||
|
time_padding_needed = target_time - decoded_output.shape[1]
|
||||||
|
freq_padding_needed = target_freq - decoded_output.shape[2]
|
||||||
|
|
||||||
|
# Step 3: Apply padding if needed
|
||||||
|
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||||
|
# MLX pad: [(before_0, after_0), ...]
|
||||||
|
# For (B, H, W, C): H=time, W=freq
|
||||||
|
padding = [
|
||||||
|
(0, 0), # batch
|
||||||
|
(0, max(time_padding_needed, 0)), # time
|
||||||
|
(0, max(freq_padding_needed, 0)), # freq
|
||||||
|
(0, 0), # channels
|
||||||
|
]
|
||||||
|
decoded_output = mx.pad(decoded_output, padding)
|
||||||
|
|
||||||
|
# Step 4: Final safety crop to ensure exact target shape
|
||||||
|
decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]
|
||||||
|
|
||||||
|
# Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
|
||||||
|
decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2))
|
||||||
|
|
||||||
|
return decoded_output
|
||||||
|
|
||||||
|
def _run_upsampling_path(self, h: mx.array) -> mx.array:
|
||||||
|
"""Run through upsampling path."""
|
||||||
|
for level in reversed(range(self.num_resolutions)):
|
||||||
|
stage = self.up[level]
|
||||||
|
for block_idx in range(len(stage["block"])):
|
||||||
|
h = stage["block"][block_idx](h, temb=None)
|
||||||
|
if block_idx in stage["attn"]:
|
||||||
|
h = stage["attn"][block_idx](h)
|
||||||
|
|
||||||
|
if level != 0 and "upsample" in stage:
|
||||||
|
h = stage["upsample"](h)
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
def _finalize_output(self, h: mx.array) -> mx.array:
|
||||||
|
"""Apply final normalization and convolution."""
|
||||||
|
if self.give_pre_end:
|
||||||
|
return h
|
||||||
|
|
||||||
|
h = self.norm_out(h)
|
||||||
|
h = nn.silu(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
|
return mx.tanh(h) if self.tanh_out else h
|
||||||
|
|
||||||
|
|
||||||
|
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
||||||
|
"""
|
||||||
|
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||||
|
Args:
|
||||||
|
latent: Input audio latent tensor
|
||||||
|
audio_decoder: Model to decode the latent to spectrogram
|
||||||
|
vocoder: Model to convert spectrogram to audio waveform
|
||||||
|
Returns:
|
||||||
|
Decoded audio as a float tensor
|
||||||
|
"""
|
||||||
|
decoded_audio = audio_decoder(latent)
|
||||||
|
decoded_audio = vocoder(decoded_audio)
|
||||||
|
# Remove batch dimension if present
|
||||||
|
if decoded_audio.shape[0] == 1:
|
||||||
|
decoded_audio = decoded_audio[0]
|
||||||
|
return decoded_audio.astype(mx.float32)
|
||||||
146
mlx_video/models/ltx/audio_vae/causal_conv_2d.py
Normal file
146
mlx_video/models/ltx/audio_vae/causal_conv_2d.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Causal 2D convolutions for audio VAE."""
|
||||||
|
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .causality_axis import CausalityAxis
|
||||||
|
|
||||||
|
|
||||||
|
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
||||||
|
"""Convert int or tuple to tuple pair."""
|
||||||
|
if isinstance(x, int):
|
||||||
|
return (x, x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv2d(nn.Module):
|
||||||
|
"""
|
||||||
|
A causal 2D convolution.
|
||||||
|
This layer ensures that the output at time `t` only depends on inputs
|
||||||
|
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||||
|
to the time dimension before the convolution.
|
||||||
|
|
||||||
|
Note: MLX Conv2d expects input shape (N, H, W, C) - channels last.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, Tuple[int, int]],
|
||||||
|
stride: int = 1,
|
||||||
|
dilation: Union[int, Tuple[int, int]] = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
# Ensure kernel_size and dilation are tuples
|
||||||
|
kernel_size = _pair(kernel_size)
|
||||||
|
dilation = _pair(dilation)
|
||||||
|
|
||||||
|
# Calculate padding dimensions
|
||||||
|
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||||
|
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||||
|
|
||||||
|
# Store padding for manual application
|
||||||
|
# MLX pad order: [(before_axis0, after_axis0), (before_axis1, after_axis1), ...]
|
||||||
|
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
||||||
|
if self.causality_axis == CausalityAxis.NONE:
|
||||||
|
# Non-causal: symmetric padding
|
||||||
|
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
|
||||||
|
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
|
||||||
|
# Causal on width: pad left (before width axis)
|
||||||
|
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
||||||
|
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||||
|
# Causal on height: pad top (before height axis)
|
||||||
|
self.padding = (pad_h, 0, pad_w // 2, pad_w - pad_w // 2)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||||
|
|
||||||
|
# The internal convolution layer uses no padding, as we handle it manually
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
dilation=dilation,
|
||||||
|
groups=groups,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""
|
||||||
|
Forward pass with causal padding.
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||||
|
Returns:
|
||||||
|
Output tensor after causal convolution
|
||||||
|
"""
|
||||||
|
# Apply causal padding before convolution
|
||||||
|
# padding format: (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
|
||||||
|
pad_h_top, pad_h_bottom, pad_w_left, pad_w_right = self.padding
|
||||||
|
|
||||||
|
if any(p > 0 for p in self.padding):
|
||||||
|
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
||||||
|
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
||||||
|
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
|
||||||
|
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
def make_conv2d(
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, Tuple[int, int]],
|
||||||
|
stride: int = 1,
|
||||||
|
padding: Union[int, Tuple[int, int], None] = None,
|
||||||
|
dilation: int = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
causality_axis: CausalityAxis | None = None,
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Create a 2D convolution layer that can be either causal or non-causal.
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
out_channels: Number of output channels
|
||||||
|
kernel_size: Size of the convolution kernel
|
||||||
|
stride: Convolution stride
|
||||||
|
padding: Padding (if None, will be calculated based on causal flag)
|
||||||
|
dilation: Dilation rate
|
||||||
|
groups: Number of groups for grouped convolution
|
||||||
|
bias: Whether to use bias
|
||||||
|
causality_axis: Dimension along which to apply causality.
|
||||||
|
Returns:
|
||||||
|
Either a regular Conv2d or CausalConv2d layer
|
||||||
|
"""
|
||||||
|
if causality_axis is not None:
|
||||||
|
# For causal convolution, padding is handled internally by CausalConv2d
|
||||||
|
return CausalConv2d(
|
||||||
|
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For non-causal convolution, use symmetric padding if not specified
|
||||||
|
if padding is None:
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
padding = kernel_size // 2
|
||||||
|
else:
|
||||||
|
padding = tuple(k // 2 for k in kernel_size)
|
||||||
|
|
||||||
|
return nn.Conv2d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
groups,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
12
mlx_video/models/ltx/audio_vae/causality_axis.py
Normal file
12
mlx_video/models/ltx/audio_vae/causality_axis.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Causality axis enum for specifying causal convolution dimensions."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class CausalityAxis(Enum):
|
||||||
|
"""Enum for specifying the causality axis in causal convolutions."""
|
||||||
|
|
||||||
|
NONE = None
|
||||||
|
WIDTH = "width"
|
||||||
|
HEIGHT = "height"
|
||||||
|
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||||
127
mlx_video/models/ltx/audio_vae/downsample.py
Normal file
127
mlx_video/models/ltx/audio_vae/downsample.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Downsampling layers for audio VAE."""
|
||||||
|
|
||||||
|
from typing import Set, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .attention import AttentionType, make_attn
|
||||||
|
from .causality_axis import CausalityAxis
|
||||||
|
from .normalization import NormType
|
||||||
|
from .resnet import ResnetBlock
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
A downsampling layer that can use either a strided convolution
|
||||||
|
or average pooling. Supports standard and causal padding for the
|
||||||
|
convolutional mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
with_conv: bool,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.WIDTH,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||||
|
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
# Do time downsampling here
|
||||||
|
# no asymmetric padding in MLX conv, must do it ourselves
|
||||||
|
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""
|
||||||
|
Forward pass with downsampling.
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||||
|
Returns:
|
||||||
|
Downsampled tensor
|
||||||
|
"""
|
||||||
|
if self.with_conv:
|
||||||
|
# Padding tuple is in the order: (left, right, top, bottom) for PyTorch
|
||||||
|
# For MLX pad: [(before_axis0, after_axis0), ...]
|
||||||
|
# x shape: (N, H, W, C) -> pad on H and W axes
|
||||||
|
if self.causality_axis == CausalityAxis.NONE:
|
||||||
|
# pad: (left=0, right=1, top=0, bottom=1)
|
||||||
|
pad = [(0, 0), (0, 1), (0, 1), (0, 0)]
|
||||||
|
elif self.causality_axis == CausalityAxis.WIDTH:
|
||||||
|
# pad: (left=2, right=0, top=0, bottom=1)
|
||||||
|
pad = [(0, 0), (0, 1), (2, 0), (0, 0)]
|
||||||
|
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||||
|
# pad: (left=0, right=1, top=2, bottom=0)
|
||||||
|
pad = [(0, 0), (2, 0), (0, 1), (0, 0)]
|
||||||
|
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
# pad: (left=1, right=0, top=0, bottom=1)
|
||||||
|
pad = [(0, 0), (0, 1), (1, 0), (0, 0)]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
x = mx.pad(x, pad, constant_values=0)
|
||||||
|
x = self.conv(x)
|
||||||
|
else:
|
||||||
|
# Average pooling with 2x2 kernel and stride 2
|
||||||
|
# MLX doesn't have built-in avg_pool2d, implement manually
|
||||||
|
# x shape: (N, H, W, C)
|
||||||
|
n, h, w, c = x.shape
|
||||||
|
# Reshape to (N, H//2, 2, W//2, 2, C) and mean over pooling dims
|
||||||
|
x = x.reshape(n, h // 2, 2, w // 2, 2, c)
|
||||||
|
x = mx.mean(x, axis=(2, 4))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def build_downsampling_path(
|
||||||
|
*,
|
||||||
|
ch: int,
|
||||||
|
ch_mult: Tuple[int, ...],
|
||||||
|
num_resolutions: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
resolution: int,
|
||||||
|
temb_channels: int,
|
||||||
|
dropout: float,
|
||||||
|
norm_type: NormType,
|
||||||
|
causality_axis: CausalityAxis,
|
||||||
|
attn_type: AttentionType,
|
||||||
|
attn_resolutions: Set[int],
|
||||||
|
resamp_with_conv: bool,
|
||||||
|
) -> tuple[dict, int]:
|
||||||
|
"""Build the downsampling path with residual blocks, attention, and downsampling layers."""
|
||||||
|
down_modules = {}
|
||||||
|
curr_res = resolution
|
||||||
|
in_ch_mult = (1, *tuple(ch_mult))
|
||||||
|
block_in = ch
|
||||||
|
|
||||||
|
for i_level in range(num_resolutions):
|
||||||
|
stage = {}
|
||||||
|
stage["block"] = {}
|
||||||
|
stage["attn"] = {}
|
||||||
|
block_in = ch * in_ch_mult[i_level]
|
||||||
|
block_out = ch * ch_mult[i_level]
|
||||||
|
|
||||||
|
for i_block in range(num_res_blocks):
|
||||||
|
stage["block"][i_block] = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||||
|
|
||||||
|
if i_level != num_resolutions - 1:
|
||||||
|
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
|
||||||
|
down_modules[i_level] = stage
|
||||||
|
|
||||||
|
return down_modules, block_in
|
||||||
59
mlx_video/models/ltx/audio_vae/normalization.py
Normal file
59
mlx_video/models/ltx/audio_vae/normalization.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Normalization layers for audio VAE."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class NormType(Enum):
|
||||||
|
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||||
|
|
||||||
|
GROUP = "group"
|
||||||
|
PIXEL = "pixel"
|
||||||
|
|
||||||
|
|
||||||
|
class PixelNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
Per-pixel (per-location) RMS normalization layer.
|
||||||
|
For each element along the chosen dimension, this layer normalizes the tensor
|
||||||
|
by the root-mean-square of its values across that dimension:
|
||||||
|
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: Dimension along which to compute the RMS (typically channels).
|
||||||
|
eps: Small constant added for numerical stability.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""Apply RMS normalization along the configured dimension."""
|
||||||
|
mean_sq = mx.mean(x**2, axis=self.dim, keepdims=True)
|
||||||
|
rms = mx.sqrt(mean_sq + self.eps)
|
||||||
|
return x / rms
|
||||||
|
|
||||||
|
|
||||||
|
def build_normalization_layer(
|
||||||
|
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Create a normalization layer based on the normalization type.
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
num_groups: Number of groups for group normalization
|
||||||
|
normtype: Type of normalization: "group" or "pixel"
|
||||||
|
Returns:
|
||||||
|
A normalization layer
|
||||||
|
"""
|
||||||
|
if normtype == NormType.GROUP:
|
||||||
|
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
||||||
|
if normtype == NormType.PIXEL:
|
||||||
|
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||||
|
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||||
|
return PixelNorm(dim=-1, eps=1e-6)
|
||||||
|
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||||
98
mlx_video/models/ltx/audio_vae/ops.py
Normal file
98
mlx_video/models/ltx/audio_vae/ops.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""Audio processing utilities for audio VAE."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AudioLatentShape:
|
||||||
|
"""Shape descriptor for audio latent representations."""
|
||||||
|
|
||||||
|
batch: int
|
||||||
|
channels: int
|
||||||
|
frames: int
|
||||||
|
mel_bins: int
|
||||||
|
|
||||||
|
|
||||||
|
class PerChannelStatistics(nn.Module):
|
||||||
|
"""
|
||||||
|
Per-channel statistics for normalizing and denormalizing the latent representation.
|
||||||
|
This statistics is computed over the entire dataset and stored in model's checkpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, latent_channels: int = 128) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.latent_channels = latent_channels
|
||||||
|
# Initialize buffers - will be loaded from weights
|
||||||
|
# Using underscores for MLX compatibility with weight loading
|
||||||
|
self._std_of_means = mx.ones((latent_channels,))
|
||||||
|
self._mean_of_means = mx.zeros((latent_channels,))
|
||||||
|
|
||||||
|
def un_normalize(self, x: mx.array) -> mx.array:
|
||||||
|
"""Denormalize latent representation."""
|
||||||
|
# Broadcast statistics to match x shape
|
||||||
|
# x shape: (B, C, ...) or (B, ..., C)
|
||||||
|
std = self._std_of_means.astype(x.dtype)
|
||||||
|
mean = self._mean_of_means.astype(x.dtype)
|
||||||
|
return (x * std) + mean
|
||||||
|
|
||||||
|
def normalize(self, x: mx.array) -> mx.array:
|
||||||
|
"""Normalize latent representation."""
|
||||||
|
std = self._std_of_means.astype(x.dtype)
|
||||||
|
mean = self._mean_of_means.astype(x.dtype)
|
||||||
|
return (x - mean) / std
|
||||||
|
|
||||||
|
|
||||||
|
class AudioPatchifier:
|
||||||
|
"""
|
||||||
|
Audio patchifier for converting between audio latents and patches.
|
||||||
|
Combines channels and mel_bins dimensions for per-channel statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patch_size: int = 1,
|
||||||
|
audio_latent_downsample_factor: int = 4,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
hop_length: int = 160,
|
||||||
|
is_causal: bool = True,
|
||||||
|
):
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.is_causal = is_causal
|
||||||
|
|
||||||
|
def patchify(self, x: mx.array) -> mx.array:
|
||||||
|
"""Convert audio latents to patches.
|
||||||
|
|
||||||
|
Input shape: (B, T, F, C) in MLX format (channels last)
|
||||||
|
Output shape: (B, T, C*F) - flattened for per-channel statistics
|
||||||
|
|
||||||
|
The output order is (c f) to match PyTorch's "b c t f -> b t (c f)".
|
||||||
|
"""
|
||||||
|
# x shape: (B, T, F, C) e.g., (1, 68, 16, 8)
|
||||||
|
b, t, f, c = x.shape
|
||||||
|
# Transpose to (B, T, C, F) for correct (c f) ordering
|
||||||
|
x = mx.transpose(x, (0, 1, 3, 2))
|
||||||
|
# Reshape to (B, T, C*F) e.g., (1, 68, 128)
|
||||||
|
return x.reshape(b, t, c * f)
|
||||||
|
|
||||||
|
def unpatchify(self, x: mx.array, latent_shape: AudioLatentShape) -> mx.array:
|
||||||
|
"""Convert patches back to audio latents.
|
||||||
|
|
||||||
|
Input shape: (B, T, C*F)
|
||||||
|
Output shape: (B, T, F, C) in MLX format
|
||||||
|
|
||||||
|
Reverses patchify's "b t (c f) -> b c t f" then transposes to MLX format.
|
||||||
|
"""
|
||||||
|
# x shape: (B, T, C*F) e.g., (1, 68, 128)
|
||||||
|
b, t, cf = x.shape
|
||||||
|
c = latent_shape.channels
|
||||||
|
f = latent_shape.mel_bins
|
||||||
|
# Reshape to (B, T, C, F)
|
||||||
|
x = x.reshape(b, t, c, f)
|
||||||
|
# Transpose to MLX format (B, T, F, C)
|
||||||
|
return mx.transpose(x, (0, 1, 3, 2))
|
||||||
185
mlx_video/models/ltx/audio_vae/resnet.py
Normal file
185
mlx_video/models/ltx/audio_vae/resnet.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""ResNet blocks for audio VAE and vocoder."""
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .causal_conv_2d import make_conv2d
|
||||||
|
from .causality_axis import CausalityAxis
|
||||||
|
from .normalization import NormType, build_normalization_layer
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array:
|
||||||
|
"""Leaky ReLU activation."""
|
||||||
|
return mx.maximum(x, x * negative_slope)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock1(nn.Module):
|
||||||
|
"""1D ResNet block for vocoder with dilated convolutions."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
dilation: Tuple[int, int, int] = (1, 3, 5),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# First set of convolutions with different dilations
|
||||||
|
self.convs1 = {
|
||||||
|
i: nn.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=d,
|
||||||
|
padding=(kernel_size - 1) * d // 2,
|
||||||
|
)
|
||||||
|
for i, d in enumerate(dilation)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Second set of convolutions with dilation=1
|
||||||
|
self.convs2 = {
|
||||||
|
i: nn.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
)
|
||||||
|
for i in range(len(dilation))
|
||||||
|
}
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""Forward pass through residual blocks."""
|
||||||
|
for i in range(len(self.convs1)):
|
||||||
|
xt = leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = self.convs1[i](xt)
|
||||||
|
xt = leaky_relu(xt, LRELU_SLOPE)
|
||||||
|
xt = self.convs2[i](xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock2(nn.Module):
|
||||||
|
"""1D ResNet block for vocoder (alternative version)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
dilation: Tuple[int, int] = (1, 3),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.convs = {
|
||||||
|
i: nn.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=d,
|
||||||
|
padding=(kernel_size - 1) * d // 2,
|
||||||
|
)
|
||||||
|
for i, d in enumerate(dilation)
|
||||||
|
}
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""Forward pass through residual blocks."""
|
||||||
|
for i in range(len(self.convs)):
|
||||||
|
xt = leaky_relu(x, LRELU_SLOPE)
|
||||||
|
xt = self.convs[i](xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock(nn.Module):
|
||||||
|
"""2D ResNet block for audio VAE encoder/decoder."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int | None = None,
|
||||||
|
conv_shortcut: bool = False,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
temb_channels: int = 512,
|
||||||
|
norm_type: NormType = NormType.GROUP,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
|
||||||
|
if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
|
||||||
|
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
self.temb_channels = temb_channels
|
||||||
|
|
||||||
|
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||||
|
self.conv1 = make_conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
if temb_channels > 0:
|
||||||
|
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||||
|
|
||||||
|
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||||
|
self.dropout_rate = dropout
|
||||||
|
self.conv2 = make_conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = make_conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = make_conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
temb: mx.array | None = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""
|
||||||
|
Forward pass through ResNet block.
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||||
|
temb: Optional time embedding tensor
|
||||||
|
Returns:
|
||||||
|
Output tensor
|
||||||
|
"""
|
||||||
|
h = x
|
||||||
|
h = self.norm1(h)
|
||||||
|
h = nn.silu(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
if temb is not None and self.temb_channels > 0:
|
||||||
|
# temb: (B, temb_channels) -> (B, out_channels)
|
||||||
|
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||||
|
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
||||||
|
|
||||||
|
h = self.norm2(h)
|
||||||
|
h = nn.silu(h)
|
||||||
|
if self.dropout_rate > 0:
|
||||||
|
h = nn.Dropout(self.dropout_rate)(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
x = self.conv_shortcut(x)
|
||||||
|
else:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return x + h
|
||||||
135
mlx_video/models/ltx/audio_vae/upsample.py
Normal file
135
mlx_video/models/ltx/audio_vae/upsample.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Upsampling layers for audio VAE."""
|
||||||
|
|
||||||
|
from typing import Set, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .attention import AttentionType, make_attn
|
||||||
|
from .causal_conv_2d import make_conv2d
|
||||||
|
from .causality_axis import CausalityAxis
|
||||||
|
from .normalization import NormType
|
||||||
|
from .resnet import ResnetBlock
|
||||||
|
|
||||||
|
|
||||||
|
def nearest_neighbor_upsample(x: mx.array, scale_factor: int = 2) -> mx.array:
|
||||||
|
"""
|
||||||
|
Nearest neighbor upsampling for 4D tensors.
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (N, H, W, C)
|
||||||
|
scale_factor: Upsampling factor
|
||||||
|
Returns:
|
||||||
|
Upsampled tensor of shape (N, H*scale_factor, W*scale_factor, C)
|
||||||
|
"""
|
||||||
|
n, h, w, c = x.shape
|
||||||
|
# Repeat along height and width
|
||||||
|
x = mx.repeat(x, scale_factor, axis=1)
|
||||||
|
x = mx.repeat(x, scale_factor, axis=2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
"""Upsampling layer with optional convolution."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
with_conv: bool,
|
||||||
|
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
self.causality_axis = causality_axis
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = make_conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""
|
||||||
|
Forward pass with upsampling.
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||||
|
Returns:
|
||||||
|
Upsampled tensor
|
||||||
|
"""
|
||||||
|
# Nearest neighbor 2x upsampling
|
||||||
|
x = nearest_neighbor_upsample(x, scale_factor=2)
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||||
|
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||||
|
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||||
|
# So the output elements rely on the following windows:
|
||||||
|
# 0: [-,-,0]
|
||||||
|
# 1: [-,0,0]
|
||||||
|
# 2: [0,0,1]
|
||||||
|
# 3: [0,1,1]
|
||||||
|
# 4: [1,1,2]
|
||||||
|
# 5: [1,2,2]
|
||||||
|
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||||
|
# while all other elements rely on two elements in the input.
|
||||||
|
# So we can drop the first element to undo the padding (rather than the last element).
|
||||||
|
# This is a no-op for non-causal convolutions.
|
||||||
|
if self.causality_axis == CausalityAxis.NONE:
|
||||||
|
pass # x remains unchanged
|
||||||
|
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||||
|
x = x[:, 1:, :, :]
|
||||||
|
elif self.causality_axis == CausalityAxis.WIDTH:
|
||||||
|
x = x[:, :, 1:, :]
|
||||||
|
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
|
||||||
|
pass # x remains unchanged
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def build_upsampling_path(
|
||||||
|
*,
|
||||||
|
ch: int,
|
||||||
|
ch_mult: Tuple[int, ...],
|
||||||
|
num_resolutions: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
resolution: int,
|
||||||
|
temb_channels: int,
|
||||||
|
dropout: float,
|
||||||
|
norm_type: NormType,
|
||||||
|
causality_axis: CausalityAxis,
|
||||||
|
attn_type: AttentionType,
|
||||||
|
attn_resolutions: Set[int],
|
||||||
|
resamp_with_conv: bool,
|
||||||
|
initial_block_channels: int,
|
||||||
|
) -> tuple[dict, int]:
|
||||||
|
"""Build the upsampling path with residual blocks, attention, and upsampling layers."""
|
||||||
|
up_modules = {}
|
||||||
|
block_in = initial_block_channels
|
||||||
|
curr_res = resolution // (2 ** (num_resolutions - 1))
|
||||||
|
|
||||||
|
for level in reversed(range(num_resolutions)):
|
||||||
|
stage = {}
|
||||||
|
stage["block"] = {}
|
||||||
|
stage["attn"] = {}
|
||||||
|
block_out = ch * ch_mult[level]
|
||||||
|
|
||||||
|
for i_block in range(num_res_blocks + 1):
|
||||||
|
stage["block"][i_block] = ResnetBlock(
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
dropout=dropout,
|
||||||
|
norm_type=norm_type,
|
||||||
|
causality_axis=causality_axis,
|
||||||
|
)
|
||||||
|
block_in = block_out
|
||||||
|
if curr_res in attn_resolutions:
|
||||||
|
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||||
|
|
||||||
|
if level != 0:
|
||||||
|
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||||
|
curr_res *= 2
|
||||||
|
|
||||||
|
up_modules[level] = stage
|
||||||
|
|
||||||
|
return up_modules, block_in
|
||||||
142
mlx_video/models/ltx/audio_vae/vocoder.py
Normal file
142
mlx_video/models/ltx/audio_vae/vocoder.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Vocoder for converting mel spectrograms to audio waveforms."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
|
||||||
|
|
||||||
|
|
||||||
|
class Vocoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Vocoder model for synthesizing audio from Mel spectrograms.
|
||||||
|
Based on HiFi-GAN architecture.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resblock_kernel_sizes: List of kernel sizes for the residual blocks
|
||||||
|
upsample_rates: List of upsampling rates
|
||||||
|
upsample_kernel_sizes: List of kernel sizes for the upsampling layers
|
||||||
|
resblock_dilation_sizes: List of dilation sizes for the residual blocks
|
||||||
|
upsample_initial_channel: Initial number of channels for upsampling
|
||||||
|
stereo: Whether to use stereo output
|
||||||
|
resblock: Type of residual block to use ("1" or "2")
|
||||||
|
output_sample_rate: Waveform sample rate
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
resblock_kernel_sizes: List[int] | None = None,
|
||||||
|
upsample_rates: List[int] | None = None,
|
||||||
|
upsample_kernel_sizes: List[int] | None = None,
|
||||||
|
resblock_dilation_sizes: List[List[int]] | None = None,
|
||||||
|
upsample_initial_channel: int = 1024,
|
||||||
|
stereo: bool = True,
|
||||||
|
resblock: str = "1",
|
||||||
|
output_sample_rate: int = 24000,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Initialize default values if not provided
|
||||||
|
if resblock_kernel_sizes is None:
|
||||||
|
resblock_kernel_sizes = [3, 7, 11]
|
||||||
|
if upsample_rates is None:
|
||||||
|
upsample_rates = [6, 5, 2, 2, 2]
|
||||||
|
if upsample_kernel_sizes is None:
|
||||||
|
upsample_kernel_sizes = [16, 15, 8, 4, 4]
|
||||||
|
if resblock_dilation_sizes is None:
|
||||||
|
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||||
|
|
||||||
|
self.output_sample_rate = output_sample_rate
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
self.upsample_rates = upsample_rates
|
||||||
|
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||||
|
self.upsample_initial_channel = upsample_initial_channel
|
||||||
|
|
||||||
|
in_channels = 128 if stereo else 64
|
||||||
|
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
|
||||||
|
|
||||||
|
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||||
|
|
||||||
|
# Upsampling layers using ConvTranspose1d
|
||||||
|
self.ups = {}
|
||||||
|
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
in_ch = upsample_initial_channel // (2**i)
|
||||||
|
out_ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
self.ups[i] = nn.ConvTranspose1d(
|
||||||
|
in_ch,
|
||||||
|
out_ch,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - stride) // 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Residual blocks
|
||||||
|
self.resblocks = {}
|
||||||
|
block_idx = 0
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||||
|
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||||
|
block_idx += 1
|
||||||
|
|
||||||
|
out_channels = 2 if stereo else 1
|
||||||
|
final_channels = upsample_initial_channel // (2**self.num_upsamples)
|
||||||
|
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
|
||||||
|
|
||||||
|
self.upsample_factor = math.prod(upsample_rates)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""
|
||||||
|
Forward pass of the vocoder.
|
||||||
|
Args:
|
||||||
|
x: Input Mel spectrogram tensor. Can be either:
|
||||||
|
- 3D: (batch_size, time, mel_bins) for mono - MLX format (N, L, C)
|
||||||
|
- 4D: (batch_size, 2, time, mel_bins) for stereo - PyTorch format (N, C, H, W)
|
||||||
|
Returns:
|
||||||
|
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
|
||||||
|
"""
|
||||||
|
# Input: (batch, channels, time, mel_bins) from audio decoder
|
||||||
|
# Transpose to (batch, channels, mel_bins, time)
|
||||||
|
x = mx.transpose(x, (0, 1, 3, 2))
|
||||||
|
|
||||||
|
if x.ndim == 4: # stereo
|
||||||
|
# x shape: (batch, 2, mel_bins, time)
|
||||||
|
# Rearrange to (batch, 2*mel_bins, time)
|
||||||
|
b, s, c, t = x.shape
|
||||||
|
x = x.reshape(b, s * c, t)
|
||||||
|
|
||||||
|
# MLX Conv1d expects (N, L, C), so transpose
|
||||||
|
# Current: (batch, channels, time) -> (batch, time, channels)
|
||||||
|
x = mx.transpose(x, (0, 2, 1))
|
||||||
|
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = leaky_relu(x, LRELU_SLOPE)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
|
||||||
|
start = i * self.num_kernels
|
||||||
|
end = start + self.num_kernels
|
||||||
|
|
||||||
|
# Apply residual blocks and average their outputs
|
||||||
|
block_outputs = []
|
||||||
|
for idx in range(start, end):
|
||||||
|
block_outputs.append(self.resblocks[idx](x))
|
||||||
|
|
||||||
|
# Stack and mean
|
||||||
|
x = mx.stack(block_outputs, axis=0)
|
||||||
|
x = mx.mean(x, axis=0)
|
||||||
|
|
||||||
|
# IMPORTANT: Use default leaky_relu slope (0.01), NOT LRELU_SLOPE (0.1)
|
||||||
|
# PyTorch uses F.leaky_relu(x) which defaults to 0.01
|
||||||
|
x = nn.leaky_relu(x) # Default negative_slope=0.01
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = mx.tanh(x)
|
||||||
|
|
||||||
|
# Transpose back to (batch, channels, time)
|
||||||
|
x = mx.transpose(x, (0, 2, 1))
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -111,6 +111,7 @@ class LTXModelConfig(BaseModelConfig):
|
|||||||
audio_in_channels: int = 128
|
audio_in_channels: int = 128
|
||||||
audio_out_channels: int = 128
|
audio_out_channels: int = 128
|
||||||
audio_cross_attention_dim: int = 2048
|
audio_cross_attention_dim: int = 2048
|
||||||
|
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
|
||||||
|
|
||||||
# Positional embedding config
|
# Positional embedding config
|
||||||
positional_embedding_theta: float = 10000.0
|
positional_embedding_theta: float = 10000.0
|
||||||
@@ -122,7 +123,7 @@ class LTXModelConfig(BaseModelConfig):
|
|||||||
|
|
||||||
# Timestep config
|
# Timestep config
|
||||||
timestep_scale_multiplier: int = 1000
|
timestep_scale_multiplier: int = 1000
|
||||||
av_ca_timestep_scale_multiplier: int = 1
|
av_ca_timestep_scale_multiplier: int = 1000
|
||||||
|
|
||||||
# Normalization
|
# Normalization
|
||||||
norm_eps: float = 1e-6
|
norm_eps: float = 1e-6
|
||||||
|
|||||||
@@ -70,6 +70,9 @@ class TransformerArgsPreprocessor:
|
|||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# Context is already processed through embeddings connector in text encoder
|
||||||
|
# Here we just apply the caption projection
|
||||||
context = self.caption_projection(context)
|
context = self.caption_projection(context)
|
||||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||||
return context, attention_mask
|
return context, attention_mask
|
||||||
@@ -282,8 +285,10 @@ class LTXModel(nn.Module):
|
|||||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
||||||
|
|
||||||
|
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
|
||||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=config.caption_channels,
|
in_features=config.audio_caption_channels,
|
||||||
hidden_size=self.audio_inner_dim,
|
hidden_size=self.audio_inner_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -216,7 +216,8 @@ class ConnectorAttention(nn.Module):
|
|||||||
self.to_q = nn.Linear(dim, inner_dim, bias=True)
|
self.to_q = nn.Linear(dim, inner_dim, bias=True)
|
||||||
self.to_k = nn.Linear(dim, inner_dim, bias=True)
|
self.to_k = nn.Linear(dim, inner_dim, bias=True)
|
||||||
self.to_v = nn.Linear(dim, inner_dim, bias=True)
|
self.to_v = nn.Linear(dim, inner_dim, bias=True)
|
||||||
self.to_out = [nn.Linear(inner_dim, dim, bias=True)]
|
# Direct attribute for MLX parameter tracking (not a list)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias=True)
|
||||||
|
|
||||||
# Standard RMSNorm (not Gemma-style) on full inner_dim
|
# Standard RMSNorm (not Gemma-style) on full inner_dim
|
||||||
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||||
@@ -239,21 +240,51 @@ class ConnectorAttention(nn.Module):
|
|||||||
q = self.q_norm(q)
|
q = self.q_norm(q)
|
||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
# Reshape to (B, H, T, D) for SPLIT RoPE
|
||||||
if pe is not None:
|
|
||||||
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
|
|
||||||
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
|
|
||||||
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
|
|
||||||
|
|
||||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if pe is not None:
|
||||||
|
# pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2)
|
||||||
|
# Apply SPLIT RoPE: operates on first half of head dimensions
|
||||||
|
q = self._apply_split_rope(q, pe[0], pe[1])
|
||||||
|
k = self._apply_split_rope(k, pe[0], pe[1])
|
||||||
|
|
||||||
# No mask needed for connector - after register replacement, all positions are valid
|
# No mask needed for connector - after register replacement, all positions are valid
|
||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
||||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||||
|
|
||||||
return self.to_out[0](out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
def _apply_split_rope(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
cos_freq: mx.array,
|
||||||
|
sin_freq: mx.array,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Apply SPLIT RoPE to input tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (B, H, T, D)
|
||||||
|
cos_freq: Cosine frequencies of shape (1, H, T, D//2)
|
||||||
|
sin_freq: Sine frequencies of shape (1, H, T, D//2)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with SPLIT rotary embeddings applied
|
||||||
|
"""
|
||||||
|
# Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2)
|
||||||
|
half_dim = x.shape[-1] // 2
|
||||||
|
x1 = x[..., :half_dim]
|
||||||
|
x2 = x[..., half_dim:]
|
||||||
|
|
||||||
|
# Apply rotation: SPLIT pattern
|
||||||
|
# out1 = x1 * cos - x2 * sin
|
||||||
|
# out2 = x2 * cos + x1 * sin
|
||||||
|
out1 = x1 * cos_freq - x2 * sin_freq
|
||||||
|
out2 = x2 * cos_freq + x1 * sin_freq
|
||||||
|
|
||||||
|
return mx.concatenate([out1, out2], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
@@ -272,15 +303,15 @@ class ConnectorFeedForward(nn.Module):
|
|||||||
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
|
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim * mult
|
inner_dim = dim * mult
|
||||||
self.net = [
|
# Use explicit named attributes to match weight key structure (proj_in, proj_out)
|
||||||
GEGLU(dim, inner_dim),
|
self.proj_in = nn.Linear(dim, inner_dim, bias=True)
|
||||||
nn.Dropout(dropout),
|
self.dropout = nn.Dropout(dropout)
|
||||||
nn.Linear(inner_dim, dim, bias=True),
|
self.proj_out = nn.Linear(inner_dim, dim, bias=True)
|
||||||
]
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
for layer in self.net:
|
x = nn.gelu(self.proj_in(x))
|
||||||
x = layer(x)
|
x = self.dropout(x)
|
||||||
|
x = self.proj_out(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -326,6 +357,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
num_layers: int = 2,
|
num_layers: int = 2,
|
||||||
num_learnable_registers: int = 128,
|
num_learnable_registers: int = 128,
|
||||||
positional_embedding_theta: float = 10000.0,
|
positional_embedding_theta: float = 10000.0,
|
||||||
|
positional_embedding_max_pos: list = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@@ -333,60 +365,69 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.num_learnable_registers = num_learnable_registers
|
self.num_learnable_registers = num_learnable_registers
|
||||||
self.positional_embedding_theta = positional_embedding_theta
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = positional_embedding_max_pos or [4096]
|
||||||
|
|
||||||
self.transformer_1d_blocks = [
|
# Use dict with int keys for MLX to track parameters (lists are not tracked)
|
||||||
ConnectorTransformerBlock(dim, num_heads, head_dim)
|
self.transformer_1d_blocks = {
|
||||||
for _ in range(num_layers)
|
i: ConnectorTransformerBlock(dim, num_heads, head_dim)
|
||||||
]
|
for i in range(num_layers)
|
||||||
|
}
|
||||||
|
|
||||||
if num_learnable_registers > 0:
|
if num_learnable_registers > 0:
|
||||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||||
|
|
||||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||||
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
|
||||||
|
|
||||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
||||||
theta = self.positional_embedding_theta
|
theta = self.positional_embedding_theta
|
||||||
max_pos = [1] # Default for connector
|
max_pos = self.positional_embedding_max_pos # [4096] from PyTorch
|
||||||
n_elem = 2 * len(max_pos) # = 2
|
n_elem = 2 * len(max_pos) # = 2
|
||||||
|
|
||||||
start = 1.0
|
start = 1.0
|
||||||
end = theta
|
end = theta
|
||||||
num_indices = dim // n_elem # 1920
|
num_indices = dim // n_elem # 1920
|
||||||
|
|
||||||
# Use numpy float64 for precision
|
# Use numpy float64 for precision (double_precision_rope=True in PyTorch)
|
||||||
log_start = np.log(start) / np.log(theta) # = 0
|
log_start = np.log(start) / np.log(theta) # = 0
|
||||||
log_end = np.log(end) / np.log(theta) # = 1
|
log_end = np.log(end) / np.log(theta) # = 1
|
||||||
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
||||||
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
|
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float64)
|
||||||
|
|
||||||
# Generate positions and compute freqs (matches generate_freqs)
|
# Generate positions and compute freqs (matches generate_freqs)
|
||||||
positions = np.arange(seq_len, dtype=np.float64)
|
positions = np.arange(seq_len, dtype=np.float64)
|
||||||
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
# Scale positions by max_pos (PyTorch uses max_pos=[4096])
|
||||||
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
fractional_positions = positions / max_pos[0]
|
||||||
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,)
|
||||||
|
|
||||||
# freqs = indices * scaled_positions (outer product)
|
# freqs = indices * scaled_positions (outer product)
|
||||||
# Shape: (seq_len, num_indices)
|
# Shape: (seq_len, num_indices)
|
||||||
freqs = scaled_positions[:, None] * indices[None, :]
|
freqs = scaled_positions[:, None] * indices[None, :]
|
||||||
|
|
||||||
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
# Compute cos/sin
|
||||||
cos_freq = np.cos(freqs)
|
cos_freq = np.cos(freqs) # (seq_len, 1920)
|
||||||
sin_freq = np.sin(freqs)
|
sin_freq = np.sin(freqs)
|
||||||
|
|
||||||
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
# For SPLIT RoPE: pad to head_dim//2 = 64 per head, then reshape to (1, H, T, D//2)
|
||||||
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
# Current: (T, 1920) -> need (1, 30, T, 64)
|
||||||
cos_full = np.repeat(cos_freq, 2, axis=-1)
|
# 30 heads * 64 = 1920, so no padding needed
|
||||||
sin_full = np.repeat(sin_freq, 2, axis=-1)
|
|
||||||
|
|
||||||
# Add batch dimension and convert to MLX: (1, seq_len, dim)
|
# Reshape: (T, 1920) -> (T, 30, 64) -> (1, 30, T, 64)
|
||||||
cos_full = mx.array(cos_full[None, :, :].astype(np.float32))
|
cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
|
||||||
sin_full = mx.array(sin_full[None, :, :].astype(np.float32))
|
sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
|
||||||
|
|
||||||
|
# Transpose to (1, H, T, D//2)
|
||||||
|
cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...]
|
||||||
|
sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...]
|
||||||
|
|
||||||
|
# Convert to MLX
|
||||||
|
cos_full = mx.array(cos_freq.astype(np.float32))
|
||||||
|
sin_full = mx.array(sin_freq.astype(np.float32))
|
||||||
|
|
||||||
return cos_full.astype(dtype), sin_full.astype(dtype)
|
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||||
|
|
||||||
@@ -462,8 +503,8 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
|
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
|
||||||
|
|
||||||
# Process through transformer blocks
|
# Process through transformer blocks
|
||||||
for block in self.transformer_1d_blocks:
|
for i in range(len(self.transformer_1d_blocks)):
|
||||||
hidden_states = block(hidden_states, attention_mask, freqs_cis)
|
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
|
||||||
|
|
||||||
# Final RMS norm
|
# Final RMS norm
|
||||||
hidden_states = rms_norm(hidden_states)
|
hidden_states = rms_norm(hidden_states)
|
||||||
@@ -535,15 +576,28 @@ class GemmaFeaturesExtractor(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AudioEmbeddingsConnector(nn.Module):
|
||||||
|
"""Projects video embeddings to audio cross-attention dimension."""
|
||||||
|
|
||||||
|
def __init__(self, input_dim: int = 3840, output_dim: int = 2048):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(input_dim, output_dim, bias=True)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return self.linear(x)
|
||||||
|
|
||||||
|
|
||||||
class LTX2TextEncoder(nn.Module):
|
class LTX2TextEncoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_dim: int = 3840,
|
hidden_dim: int = 3840,
|
||||||
|
audio_dim: int = 2048,
|
||||||
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
|
self.audio_dim = audio_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.language_model = None
|
self.language_model = None
|
||||||
|
|
||||||
@@ -560,14 +614,26 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
head_dim=128,
|
head_dim=128,
|
||||||
num_layers=2,
|
num_layers=2,
|
||||||
num_learnable_registers=128,
|
num_learnable_registers=128,
|
||||||
|
positional_embedding_max_pos=[4096], # Match PyTorch
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio embeddings connector: separate 2-layer transformer (same architecture as video)
|
||||||
|
# Both connectors process the feature extractor output independently
|
||||||
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
|
dim=hidden_dim,
|
||||||
|
num_heads=30,
|
||||||
|
head_dim=128,
|
||||||
|
num_layers=2,
|
||||||
|
num_learnable_registers=128,
|
||||||
|
positional_embedding_max_pos=[4096], # Match PyTorch
|
||||||
)
|
)
|
||||||
|
|
||||||
self.processor = None
|
self.processor = None
|
||||||
|
|
||||||
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||||
|
|
||||||
if Path(text_encoder_path / "text_encoder").is_dir():
|
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
|
||||||
text_encoder_path = str(text_encoder_path / "text_encoder")
|
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
|
||||||
|
|
||||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||||
|
|
||||||
@@ -594,10 +660,14 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
# Map weight names to our structure
|
# Map weight names to our structure
|
||||||
mapped_weights = {}
|
mapped_weights = {}
|
||||||
for key, value in connector_weights.items():
|
for key, value in connector_weights.items():
|
||||||
# transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.*
|
new_key = key
|
||||||
# transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.*
|
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||||
# transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.*
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
mapped_weights[key] = value
|
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||||
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
|
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||||
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
|
mapped_weights[new_key] = value
|
||||||
|
|
||||||
self.video_embeddings_connector.load_weights(
|
self.video_embeddings_connector.load_weights(
|
||||||
list(mapped_weights.items()), strict=False
|
list(mapped_weights.items()), strict=False
|
||||||
@@ -607,6 +677,34 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
if "learnable_registers" in connector_weights:
|
if "learnable_registers" in connector_weights:
|
||||||
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||||
|
|
||||||
|
# Load audio_embeddings_connector weights (same structure as video connector)
|
||||||
|
audio_connector_weights = {}
|
||||||
|
for key, value in transformer_weights.items():
|
||||||
|
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||||
|
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
|
||||||
|
audio_connector_weights[new_key] = value
|
||||||
|
|
||||||
|
if audio_connector_weights:
|
||||||
|
# Map weight names to our structure (same as video connector)
|
||||||
|
mapped_audio_weights = {}
|
||||||
|
for key, value in audio_connector_weights.items():
|
||||||
|
new_key = key
|
||||||
|
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||||
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
|
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||||
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
|
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||||
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
|
mapped_audio_weights[new_key] = value
|
||||||
|
|
||||||
|
self.audio_embeddings_connector.load_weights(
|
||||||
|
list(mapped_audio_weights.items()), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||||
|
if "learnable_registers" in audio_connector_weights:
|
||||||
|
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
tokenizer_path = model_path / "tokenizer"
|
tokenizer_path = model_path / "tokenizer"
|
||||||
@@ -623,8 +721,20 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_length: int = 1024,
|
max_length: int = 1024,
|
||||||
|
return_audio_embeddings: bool = True,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Encode text prompt to video and audio embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Text prompt to encode
|
||||||
|
max_length: Maximum token length (default 1024 to match official PyTorch)
|
||||||
|
return_audio_embeddings: If True, returns (video_emb, audio_emb).
|
||||||
|
If False, returns (video_emb, attention_mask).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (video_embeddings, audio_embeddings) if return_audio_embeddings=True
|
||||||
|
Tuple of (video_embeddings, attention_mask) otherwise
|
||||||
|
"""
|
||||||
if self.processor is None:
|
if self.processor is None:
|
||||||
raise RuntimeError("Model not loaded. Call load() first.")
|
raise RuntimeError("Model not loaded. Call load() first.")
|
||||||
|
|
||||||
@@ -649,16 +759,33 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
|
|
||||||
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
video_embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||||
|
|
||||||
return embeddings, attention_mask
|
if return_audio_embeddings:
|
||||||
|
# Process features through audio connector independently (same input as video)
|
||||||
|
audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask)
|
||||||
|
return video_embeddings, audio_embeddings
|
||||||
|
else:
|
||||||
|
return video_embeddings, attention_mask
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_length: int = 1024,
|
max_length: int = 1024,
|
||||||
|
return_audio_embeddings: bool = True,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
return self.encode(prompt, max_length)
|
"""Encode text prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Text prompt to encode
|
||||||
|
max_length: Maximum token length (default 1024 to match official PyTorch)
|
||||||
|
return_audio_embeddings: If True, returns (video_emb, audio_emb).
|
||||||
|
If False, returns (video_emb, attention_mask).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of embeddings based on return_audio_embeddings flag
|
||||||
|
"""
|
||||||
|
return self.encode(prompt, max_length, return_audio_embeddings)
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def default_t2v_system_prompt(self) -> str:
|
def default_t2v_system_prompt(self) -> str:
|
||||||
@@ -756,7 +883,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
|
|
||||||
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 0.95), top_k=kwargs.get("top_k", -1))
|
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1))
|
||||||
logits_processors = make_logits_processors(
|
logits_processors = make_logits_processors(
|
||||||
kwargs.get("logit_bias", None),
|
kwargs.get("logit_bias", None),
|
||||||
kwargs.get("repetition_penalty", 1.3),
|
kwargs.get("repetition_penalty", 1.3),
|
||||||
@@ -833,7 +960,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||||
encoder = LTX2TextEncoder(model_path=model_path)
|
encoder = LTX2TextEncoder()
|
||||||
encoder.load()
|
encoder.load(model_path=model_path)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|||||||
@@ -34,10 +34,11 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
|
|||||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||||
|
|
||||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
|
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
||||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
||||||
|
x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6))
|
||||||
|
|
||||||
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
|
# Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W')
|
||||||
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
Reference in New Issue
Block a user