initial commit (LTX-2)

This commit is contained in:
Prince Canuma
2026-01-11 23:48:33 +01:00
parent 9f01d22750
commit d1ca36a315
29 changed files with 7124 additions and 0 deletions

13
mlx_video/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.generate import LTXVideoPipeline, GenerationConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights
__all__ = [
"LTXModel",
"LTXModelConfig",
"LTXVideoPipeline",
"GenerationConfig",
"load_transformer_weights",
"load_vae_weights",
]

457
mlx_video/convert.py Normal file
View File

@@ -0,0 +1,457 @@
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx.ltx import LTXModel
def get_model_path(
path_or_hf_repo: str,
revision: Optional[str] = None,
) -> Path:
"""Get local path to model, downloading if necessary.
Args:
path_or_hf_repo: Local path or HuggingFace repo ID
revision: Git revision for HF repo
Returns:
Path to model directory
"""
model_path = Path(path_or_hf_repo)
if model_path.exists():
return model_path
# Download from HuggingFace
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.safetensors",
"*.json",
"config.json",
],
)
)
return model_path
def load_safetensors(path: Path) -> Dict[str, mx.array]:
"""Load weights from safetensors file(s) using MLX.
Args:
path: Path to model directory or single safetensors file
Returns:
Dictionary of weights
"""
weights = {}
if path.is_file():
# Single file - use mx.load directly (handles bfloat16)
return mx.load(str(path))
else:
# Directory - load all safetensors files
safetensor_files = list(path.glob("*.safetensors"))
for sf_path in safetensor_files:
file_weights = mx.load(str(sf_path))
weights.update(file_weights)
return weights
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load transformer weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of transformer weights
"""
# Try distilled model first, then dev
weight_files = [
model_path / "ltx-2-19b-distilled.safetensors",
model_path / "ltx-2-19b-dev.safetensors",
]
for weight_file in weight_files:
if weight_file.exists():
print(f"Loading transformer weights from {weight_file.name}...")
return mx.load(str(weight_file))
raise FileNotFoundError(f"No transformer weights found in {model_path}")
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load VAE weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of VAE weights
"""
vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
if vae_path.exists():
print(f"Loading VAE weights from {vae_path}...")
return mx.load(str(vae_path))
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for transformer
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model."):
continue
# Remove 'model.diffusion_model.' prefix
new_key = key.replace("model.diffusion_model.", "")
# Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering)
new_key = new_key.replace(".to_out.0.", ".to_out.")
# Handle feed-forward net naming
# PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar)
# MLX FeedForward: uses proj_in, proj_out
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
# Handle AdaLN naming - keep emb wrapper, just fix linear naming
# PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle caption projection (keep linear1/linear2 naming for compatibility)
# These are already mapped correctly in the sanitization
sanitized[new_key] = value
return sanitized
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for VAE
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip position_ids (not needed)
if "position_ids" in key:
continue
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize weight names from PyTorch format to MLX format.
Generic function that handles both transformer and VAE weights.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip position_ids (not needed)
if "position_ids" in key:
continue
# Handle transformer weights
if key.startswith("model.diffusion_model."):
new_key = key.replace("model.diffusion_model.", "")
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def load_config(model_path: Path) -> Dict[str, Any]:
"""Load model configuration.
Args:
model_path: Path to model directory
Returns:
Configuration dictionary
"""
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
return json.load(f)
# Return default config
return {}
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
"""Create model instance from configuration.
Args:
config: Configuration dictionary
Returns:
LTXModel instance
"""
# Map config to LTXModelConfig
model_config = LTXModelConfig(
model_type=LTXModelType.AudioVideo,
num_attention_heads=config.get("num_attention_heads", 32),
attention_head_dim=config.get("attention_head_dim", 128),
in_channels=config.get("in_channels", 128),
out_channels=config.get("out_channels", 128),
num_layers=config.get("num_layers", 48),
cross_attention_dim=config.get("cross_attention_dim", 4096),
caption_channels=config.get("caption_channels", 3840),
audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
audio_in_channels=config.get("audio_in_channels", 128),
audio_out_channels=config.get("audio_out_channels", 128),
audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
norm_eps=config.get("norm_eps", 1e-6),
)
return LTXModel(model_config)
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
dtype: Optional[str] = None,
quantize: bool = False,
q_bits: int = 4,
q_group_size: int = 64,
) -> Path:
"""Convert HuggingFace model to MLX format.
Args:
hf_path: HuggingFace model path or repo ID
mlx_path: Output path for MLX model
dtype: Target dtype (float16, float32, bfloat16)
quantize: Whether to quantize the model
q_bits: Quantization bits
q_group_size: Quantization group size
Returns:
Path to converted model
"""
print(f"Loading model from {hf_path}...")
model_path = get_model_path(hf_path)
# Load config
config = load_config(model_path)
# Load weights
print("Loading weights...")
weights = load_safetensors(model_path)
# Sanitize weights
print("Sanitizing weights...")
weights = sanitize_weights(weights)
# Convert dtype if specified
if dtype is not None:
dtype_map = {
"float16": mx.float16,
"float32": mx.float32,
"bfloat16": mx.bfloat16,
}
target_dtype = dtype_map.get(dtype, mx.float16)
print(f"Converting to {dtype}...")
weights = {
k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v
for k, v in weights.items()
}
# Create output directory
output_path = Path(mlx_path)
output_path.mkdir(parents=True, exist_ok=True)
# Save weights
print(f"Saving weights to {output_path}...")
save_weights(output_path, weights)
# Save config
config_out_path = output_path / "config.json"
with open(config_out_path, "w") as f:
json.dump(config, f, indent=2)
print(f"Model converted successfully to {output_path}")
return output_path
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
"""Save weights in safetensors format.
Args:
path: Output directory
weights: Dictionary of weights
"""
from safetensors.numpy import save_file
import numpy as np
# Convert to numpy for safetensors
np_weights = {k: np.array(v) for k, v in weights.items()}
# Save to file
save_file(np_weights, path / "model.safetensors")
def load_model(
path_or_hf_repo: str,
lazy: bool = False,
) -> LTXModel:
"""Load LTX model from path or HuggingFace.
Args:
path_or_hf_repo: Path to model or HuggingFace repo ID
lazy: Whether to use lazy loading
Returns:
Loaded LTXModel
"""
model_path = get_model_path(path_or_hf_repo)
# Load config
config = load_config(model_path)
# Create model
model = create_model_from_config(config)
# Load weights
weights = load_safetensors(model_path)
# Sanitize if needed
weights = sanitize_weights(weights)
# Load weights into model
model.load_weights(list(weights.items()))
if not lazy:
mx.eval(model.parameters())
return model
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
parser.add_argument(
"--hf-path",
type=str,
default="Lightricks/LTX-2",
help="HuggingFace model path or repo ID",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Output path for MLX model",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float16", "float32", "bfloat16"],
default="float16",
help="Target dtype",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize the model",
)
parser.add_argument(
"--q-bits",
type=int,
default=4,
help="Quantization bits",
)
args = parser.parse_args()
convert(
hf_path=args.hf_path,
mlx_path=args.mlx_path,
dtype=args.dtype,
quantize=args.quantize,
q_bits=args.q_bits,
)

586
mlx_video/generate.py Normal file
View File

@@ -0,0 +1,586 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Iterator, Union
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.ltx import LTXModel, X0Model
from mlx_video.models.ltx.transformer import Modality
from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder
@dataclass
class GenerationConfig:
"""Configuration for video generation."""
# Video dimensions
height: int = 512
width: int = 512
num_frames: int = 33 # Must be 1 + 8*k
# Diffusion parameters
num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True)
guidance_scale: float = 3.0
use_distilled: bool = True # Use hardcoded sigma values for distilled model
# Latent dimensions (computed from video dimensions)
@property
def latent_height(self) -> int:
return self.height // 32
@property
def latent_width(self) -> int:
return self.width // 32
@property
def latent_frames(self) -> int:
return 1 + (self.num_frames - 1) // 8
# Hardcoded sigma values for distilled model (from LTX-2 pipeline)
# These were tuned to match the distillation process
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
# Scheduler constants for dynamic sigma computation (non-distilled models)
BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096
def get_sigmas(
num_steps: int,
num_tokens: int,
max_shift: float = 2.05,
base_shift: float = 0.95,
stretch: bool = True,
terminal: float = 0.1,
use_distilled: bool = True,
) -> mx.array:
"""Get sigma schedule for diffusion.
Args:
num_steps: Number of diffusion steps
num_tokens: Number of latent tokens (T * H * W)
max_shift: Maximum shift for sigma schedule
base_shift: Base shift for sigma schedule
stretch: Whether to stretch sigmas to terminal value
terminal: Terminal value for stretching
use_distilled: If True, use hardcoded distilled sigma values
Returns:
Array of sigma values
"""
import math
# For distilled model, use hardcoded sigma values
if use_distilled:
return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)
# For non-distilled models, compute dynamically using LTX2Scheduler logic
# Linear base schedule
sigmas = mx.linspace(1.0, 0.0, num_steps + 1)
# Compute token-dependent sigma shift
x1 = BASE_SHIFT_ANCHOR
x2 = MAX_SHIFT_ANCHOR
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = num_tokens * mm + b
# Apply exponential transformation
# sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1)
power = 1
exp_shift = math.exp(sigma_shift)
# Convert to numpy for computation then back to mx
sigmas_np = np.array(sigmas)
result = np.zeros_like(sigmas_np)
non_zero = sigmas_np != 0
result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power)
# Stretch sigmas so final value matches terminal
if stretch:
non_zero_mask = result != 0
non_zero_sigmas = result[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
result[non_zero_mask] = stretched
return mx.array(result, dtype=mx.float32)
def create_position_grid(
batch_size: int,
num_frames: int,
height: int,
width: int,
temporal_scale: int = 8,
spatial_scale: int = 32,
fps: float = 24.0,
causal_fix: bool = True,
) -> mx.array:
"""Create position grid for RoPE in pixel space.
Args:
batch_size: Batch size
num_frames: Number of frames (latent)
height: Height (latent)
width: Width (latent)
temporal_scale: VAE temporal scale factor (default 8)
spatial_scale: VAE spatial scale factor (default 32)
fps: Frames per second (default 24.0)
causal_fix: Apply causal fix for first frame (default True)
Returns:
Position grid of shape (B, 3, num_patches, 2) in pixel space
where dim 2 is [start, end) bounds for each patch
"""
# Patch size is (1, 1, 1) for LTX-2 - no spatial patching
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
# Generate grid coordinates for each dimension (frame, height, width)
# These are the starting coordinates for each patch in latent space
t_coords = np.arange(0, num_frames, patch_size_t)
h_coords = np.arange(0, height, patch_size_h)
w_coords = np.arange(0, width, patch_size_w)
# Create meshgrid with indexing='ij' for (frame, height, width) order
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
# Stack to get shape (3, grid_t, grid_h, grid_w)
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
# Calculate end coordinates (start + patch_size)
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
patch_ends = patch_starts + patch_size_delta
# Stack start and end: shape (3, grid_t, grid_h, grid_w, 2)
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
# Flatten spatial/temporal dims: (3, num_patches, 2)
num_patches = num_frames * height * width
latent_coords = latent_coords.reshape(3, num_patches, 2)
# Broadcast to batch: (batch, 3, num_patches, 2)
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
# Convert latent coords to pixel coords by scaling with VAE factors
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
# Apply causal fix for first frame temporal axis
if causal_fix:
# VAE temporal stride for first frame is 1 instead of temporal_scale
# Shift and clamp to keep first-frame timestamps non-negative
pixel_coords[:, 0, :, :] = np.clip(
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
a_min=0,
a_max=None
)
# Convert temporal to time in seconds by dividing by fps
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
return mx.array(pixel_coords, dtype=mx.float32)
class LTXVideoPipeline:
def __init__(
self,
transformer: LTXModel,
text_encoder: Optional[LTX2TextEncoder] = None,
tokenizer: Optional[any] = None,
vae_encoder: Optional[VideoEncoder] = None,
vae_decoder: Optional[VideoDecoder] = None,
):
"""Initialize pipeline.
Args:
transformer: LTX transformer model
text_encoder: Optional LTX text encoder
tokenizer: Optional tokenizer for text encoding
vae_encoder: Optional VAE encoder
vae_decoder: Optional VAE decoder
"""
self.transformer = transformer
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.vae_encoder = vae_encoder
self.vae_decoder = vae_decoder
self.x0_model = X0Model(transformer)
def prepare_latents(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
dtype: mx.Dtype = mx.float16,
) -> mx.array:
"""Prepare initial noise latents.
Args:
batch_size: Batch size
num_frames: Number of latent frames
height: Latent height
width: Latent width
dtype: Data type
Returns:
Random latent noise
"""
# Use in_channels from transformer config
in_channels = self.transformer.config.in_channels
shape = (batch_size, in_channels, num_frames, height, width)
latents = mx.random.normal(shape).astype(dtype)
return latents
def prepare_text_embeddings(
self,
prompt: Union[str, List[str]],
batch_size: int,
max_length: int = 1024,
) -> Tuple[mx.array, Optional[mx.array]]:
"""Prepare text embeddings.
Args:
prompt: Text prompt or list of prompts
batch_size: Batch size
max_length: Maximum sequence length for tokenization
Returns:
Tuple of (text_embeddings, attention_mask)
"""
# If text encoder is available, use it
if self.text_encoder is not None and self.tokenizer is not None:
# Handle single or multiple prompts
if isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = prompt
# Tokenize
tokens = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
input_ids = mx.array(tokens["input_ids"])
attention_mask = mx.array(tokens["attention_mask"])
# Encode
embeddings = self.text_encoder(input_ids, attention_mask)
mx.eval(embeddings)
return embeddings, None # Connector handles masking internally
# Fallback: random embeddings (for testing without text encoder)
print("Warning: No text encoder provided, using random embeddings")
seq_len = max_length + 128 # Account for learnable registers
embed_dim = self.transformer.config.caption_channels
embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
mask = mx.ones((batch_size, seq_len))
return embeddings, mask
def denoise_step(
self,
latents: mx.array,
sigma: float,
sigma_next: float,
text_embeddings: mx.array,
positions: mx.array,
text_mask: Optional[mx.array] = None,
) -> mx.array:
"""Perform one denoising step.
Args:
latents: Current noisy latents
sigma: Current noise level
sigma_next: Next noise level
text_embeddings: Text conditioning
positions: Position grid for RoPE
text_mask: Optional attention mask for text
Returns:
Denoised latents
"""
batch_size = latents.shape[0]
# Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
# Create timestep tensor
timesteps = mx.full((batch_size,), sigma)
# Create video modality input
video_modality = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=text_embeddings,
context_mask=text_mask,
enabled=True,
)
# Run denoising
denoised_video, _ = self.x0_model(video=video_modality, audio=None)
# Reshape back: (B, F*H*W, C) -> (B, C, F, H, W)
denoised_video = mx.transpose(denoised_video, (0, 2, 1))
denoised_video = mx.reshape(denoised_video, (b, c, f, h, w))
# Euler step
if sigma_next > 0:
# x_next = x0 + sigma_next * (x - x0) / sigma
noise = (latents - denoised_video) / sigma
latents = denoised_video + sigma_next * noise
else:
latents = denoised_video
return latents
def __call__(
self,
prompt: str,
config: Optional[GenerationConfig] = None,
seed: Optional[int] = None,
) -> mx.array:
"""Generate video from text prompt.
Args:
prompt: Text prompt
config: Generation configuration
seed: Random seed
Returns:
Generated video tensor of shape (B, C, F, H, W)
"""
if config is None:
config = GenerationConfig()
if seed is not None:
mx.random.seed(seed)
batch_size = 1
# Prepare text embeddings
text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)
# Prepare initial latents
latents = self.prepare_latents(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Prepare position grid
positions = create_position_grid(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Get sigma schedule
num_tokens = config.latent_frames * config.latent_height * config.latent_width
sigmas = get_sigmas(
config.num_inference_steps,
num_tokens,
use_distilled=config.use_distilled,
)
# Denoising loop
for i in range(len(sigmas) - 1):
sigma = float(sigmas[i])
sigma_next = float(sigmas[i + 1])
latents = self.denoise_step(
latents=latents,
sigma=sigma,
sigma_next=sigma_next,
text_embeddings=text_embeddings,
positions=positions,
text_mask=text_mask,
)
mx.eval(latents)
# Decode latents to video
if self.vae_decoder is not None:
video = self.vae_decoder(latents)
else:
video = latents
return video
def generate_video(
prompt: str,
transformer: LTXModel,
text_encoder: Optional[LTX2TextEncoder] = None,
tokenizer: Optional[any] = None,
vae_decoder: Optional[VideoDecoder] = None,
config: Optional[GenerationConfig] = None,
seed: Optional[int] = None,
) -> mx.array:
"""Generate video from text prompt.
Args:
prompt: Text prompt
transformer: LTX transformer model
text_encoder: Optional text encoder
tokenizer: Optional tokenizer
vae_decoder: Optional VAE decoder
config: Generation configuration
seed: Random seed
Returns:
Generated video tensor
"""
pipeline = LTXVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
)
return pipeline(prompt, config, seed)
def load_pipeline(
model_path: str,
text_encoder_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
load_text_encoder_weights: bool = True,
) -> LTXVideoPipeline:
"""Load complete LTX-2 video generation pipeline.
Args:
model_path: Path to LTX-2 model weights (safetensors)
text_encoder_path: Path to text encoder weights directory
tokenizer_path: Path to tokenizer directory
load_text_encoder_weights: Whether to load text encoder weights
Returns:
Configured LTXVideoPipeline
"""
from transformers import AutoTokenizer
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.convert import sanitize_transformer_weights
print("Loading LTX-2 pipeline...")
# Load transformer
print(" Loading transformer...")
raw_weights = mx.load(model_path)
sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
print(" Transformer loaded")
# Load VAE decoder
print(" Loading VAE decoder...")
vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True)
print(" VAE decoder loaded")
# Load text encoder if paths provided
text_encoder = None
tokenizer = None
if load_text_encoder_weights and text_encoder_path is not None:
print(" Loading text encoder...")
text_encoder = load_text_encoder(model_path, text_encoder_path)
print(" Text encoder loaded")
if tokenizer_path is not None:
print(" Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(" Tokenizer loaded")
print("Pipeline ready!")
return LTXVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
)
def video_to_numpy(video: mx.array) -> np.ndarray:
"""Convert video tensor to numpy array.
Args:
video: Video tensor of shape (B, C, F, H, W) in range [-1, 1]
Returns:
Numpy array of shape (B, F, H, W, C) in range [0, 255]
"""
# Clamp to [-1, 1]
video = mx.clip(video, -1.0, 1.0)
# Scale to [0, 255]
video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8)
# Rearrange: (B, C, F, H, W) -> (B, F, H, W, C)
video = mx.transpose(video, (0, 2, 3, 4, 1))
return np.array(video)
if __name__ == "__main__":
# Example usage
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
# Create a small test config
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_layers=2, # Reduced for testing
num_attention_heads=4,
attention_head_dim=32,
)
# Create model
model = LTXModel(config)
# Generate video
gen_config = GenerationConfig(
height=256,
width=256,
num_frames=9,
num_inference_steps=4,
)
print("Testing generation pipeline...")
pipeline = LTXVideoPipeline(transformer=model)
# This would require proper text embeddings in practice
# video = pipeline("A cat walking", gen_config, seed=42)
# print(f"Generated video shape: {video.shape}")
print("Pipeline initialized successfully!")

View File

@@ -0,0 +1,2 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig

View File

@@ -0,0 +1,7 @@
from mlx_video.models.ltx.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx.ltx import LTXModel, X0Model

View File

@@ -0,0 +1,161 @@
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module):
def __init__(
self,
embedding_dim: int,
embedding_coefficient: int = 6,
use_additional_conditions: bool = False,
):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim=embedding_dim,
size_emb_dim=0 if not use_additional_conditions else embedding_dim // 3,
use_additional_conditions=use_additional_conditions,
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
def __call__(
self,
timestep: mx.array,
added_cond_kwargs: dict | None = None,
batch_size: int | None = None,
hidden_dtype: mx.Dtype | None = None,
) -> Tuple[mx.array, mx.array]:
added_cond_kwargs = added_cond_kwargs or {}
embedded_timestep = self.emb(
timestep,
batch_size=batch_size,
hidden_dtype=hidden_dtype,
**added_cond_kwargs,
)
scale_shift_params = self.linear(self.silu(embedded_timestep))
return scale_shift_params, embedded_timestep
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(
self,
embedding_dim: int,
size_emb_dim: int = 0,
use_additional_conditions: bool = False,
timestep_proj_dim: int = 256,
):
super().__init__()
self.embedding_dim = embedding_dim
self.size_emb_dim = size_emb_dim
self.use_additional_conditions = use_additional_conditions
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
if use_additional_conditions and size_emb_dim > 0:
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
def __call__(
self,
timestep: mx.array,
resolution: mx.array | None = None,
aspect_ratio: mx.array | None = None,
batch_size: int | None = None,
hidden_dtype: mx.Dtype | None = None,
) -> mx.array:
# Project timestep
timesteps_proj = self.time_proj(timestep)
if hidden_dtype is not None:
timesteps_proj = timesteps_proj.astype(hidden_dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj)
# Add additional conditions if enabled
if self.use_additional_conditions and self.size_emb_dim > 0:
if resolution is not None and aspect_ratio is not None:
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
timesteps_emb = timesteps_emb + additional_embeds
return timesteps_emb
class Timesteps(nn.Module):
def __init__(
self,
num_channels: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1.0,
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def __call__(self, timesteps: mx.array) -> mx.array:
return get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int | None = None,
):
super().__init__()
out_dim = out_dim or time_embed_dim
self.linear1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
self.linear2 = nn.Linear(time_embed_dim, out_dim)
def __call__(self, sample: mx.array) -> mx.array:
sample = self.linear1(sample)
sample = self.act(sample)
sample = self.linear2(sample)
return sample
class ConditionEmbedding(nn.Module):
def __init__(self, size_emb_dim: int, embedding_dim: int):
super().__init__()
self.resolution_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
self.aspect_ratio_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
def __call__(
self,
resolution: mx.array,
aspect_ratio: mx.array,
hidden_dtype: mx.Dtype | None = None,
) -> mx.array:
resolution_emb = self.resolution_embedder(resolution)
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio)
if hidden_dtype is not None:
resolution_emb = resolution_emb.astype(hidden_dtype)
aspect_ratio_emb = aspect_ratio_emb.astype(hidden_dtype)
return resolution_emb + aspect_ratio_emb

View File

@@ -0,0 +1,142 @@
"""Attention module for LTX-2."""
import math
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType
from mlx_video.models.ltx.rope import apply_rotary_emb
def scaled_dot_product_attention(
q: mx.array,
k: mx.array,
v: mx.array,
heads: int,
mask: Optional[mx.array] = None,
) -> mx.array:
b, q_seq_len, dim = q.shape
_, kv_seq_len, _ = k.shape
dim_head = dim // heads
# Reshape to (B, seq_len, heads, dim_head)
q = mx.reshape(q, (b, q_seq_len, heads, dim_head))
k = mx.reshape(k, (b, kv_seq_len, heads, dim_head))
v = mx.reshape(v, (b, kv_seq_len, heads, dim_head))
# Transpose to (B, heads, seq_len, dim_head)
q = mx.swapaxes(q, 1, 2)
k = mx.swapaxes(k, 1, 2)
v = mx.swapaxes(v, 1, 2)
# Handle mask dimensions
if mask is not None:
# Add batch dimension if needed
if mask.ndim == 2:
mask = mx.expand_dims(mask, axis=0)
# Add heads dimension if needed
if mask.ndim == 3:
mask = mx.expand_dims(mask, axis=1)
# Compute scaled dot-product attention
scale = 1.0 / math.sqrt(dim_head)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
# Reshape back to (B, q_seq_len, heads * dim_head)
out = mx.swapaxes(out, 1, 2)
out = mx.reshape(out, (b, q_seq_len, heads * dim_head))
return out
class Attention(nn.Module):
"""Multi-head attention with rotary position embeddings.
Supports both self-attention and cross-attention.
"""
def __init__(
self,
query_dim: int,
context_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
):
"""Initialize attention module.
Args:
query_dim: Dimension of query input
context_dim: Dimension of context (key/value) input. If None, same as query_dim
heads: Number of attention heads
dim_head: Dimension per head
norm_eps: Epsilon for RMS normalization
rope_type: Type of rotary position embedding
"""
super().__init__()
self.rope_type = rope_type
self.heads = heads
self.dim_head = dim_head
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
# Q, K, V projections
self.to_q = nn.Linear(query_dim, inner_dim, bias=True)
self.to_k = nn.Linear(context_dim, inner_dim, bias=True)
self.to_v = nn.Linear(context_dim, inner_dim, bias=True)
# Q and K normalization
self.q_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
self.k_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
# Output projection
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
def __call__(
self,
x: mx.array,
context: Optional[mx.array] = None,
mask: Optional[mx.array] = None,
pe: Optional[Tuple[mx.array, mx.array]] = None,
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
"""Forward pass.
Args:
x: Query input of shape (B, seq_len, query_dim)
context: Context for cross-attention. If None, uses x (self-attention)
mask: Attention mask
pe: Position embeddings for query (and key if k_pe is None)
k_pe: Position embeddings for key (optional, uses pe if None)
Returns:
Attention output of shape (B, seq_len, query_dim)
"""
# Compute Q, K, V
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
# Apply normalization
q = self.q_norm(q)
k = self.k_norm(k)
# Apply rotary position embeddings
if pe is not None:
q = apply_rotary_emb(q, pe, self.rope_type)
k_pe_to_use = pe if k_pe is None else k_pe
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
# Compute attention
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
# Project output
return self.to_out(out)

View File

@@ -0,0 +1,181 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional
class LTXModelType(Enum):
AudioVideo = "ltx av model"
VideoOnly = "ltx video only model"
AudioOnly = "ltx audio only model"
def is_video_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
def is_audio_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
class LTXRopeType(Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
TWO_D = "2d"
class AttentionType(Enum):
DEFAULT = "default"
@dataclass
class BaseModelConfig:
@classmethod
def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
"""Create config from dictionary, filtering only valid parameters."""
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def to_dict(self) -> dict[str, Any]:
"""Export config to dictionary."""
result = {}
for k, v in self.__dict__.items():
if v is not None:
if isinstance(v, Enum):
result[k] = v.value
elif hasattr(v, 'to_dict'):
result[k] = v.to_dict()
else:
result[k] = v
return result
@dataclass
class TransformerConfig(BaseModelConfig):
dim: int
heads: int
d_head: int
context_dim: int
@dataclass
class VideoVAEConfig(BaseModelConfig):
convolution_dimensions: int = 3
in_channels: int = 3
out_channels: int = 128
latent_channels: int = 128
patch_size: int = 4
encoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
])
decoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
])
@dataclass
class LTXModelConfig(BaseModelConfig):
# Model type
model_type: LTXModelType = LTXModelType.AudioVideo
# Video transformer config
num_attention_heads: int = 32
attention_head_dim: int = 128
in_channels: int = 128
out_channels: int = 128
num_layers: int = 48
cross_attention_dim: int = 4096
caption_channels: int = 3840
# Audio transformer config
audio_num_attention_heads: int = 32
audio_attention_head_dim: int = 64
audio_in_channels: int = 128
audio_out_channels: int = 128
audio_cross_attention_dim: int = 2048
# Positional embedding config
positional_embedding_theta: float = 10000.0
positional_embedding_max_pos: Optional[List[int]] = None
audio_positional_embedding_max_pos: Optional[List[int]] = None
use_middle_indices_grid: bool = True
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
double_precision_rope: bool = False
# Timestep config
timestep_scale_multiplier: int = 1000
av_ca_timestep_scale_multiplier: int = 1
# Normalization
norm_eps: float = 1e-6
# Attention type
attention_type: AttentionType = AttentionType.DEFAULT
# VAE config
vae_config: Optional[VideoVAEConfig] = None
def __post_init__(self):
"""Set default values after initialization."""
if self.positional_embedding_max_pos is None:
self.positional_embedding_max_pos = [20, 2048, 2048]
if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [20]
# Convert string enum values if loading from dict
if isinstance(self.model_type, str):
self.model_type = LTXModelType(self.model_type)
if isinstance(self.rope_type, str):
self.rope_type = LTXRopeType(self.rope_type)
if isinstance(self.attention_type, str):
self.attention_type = AttentionType(self.attention_type)
@property
def inner_dim(self) -> int:
"""Video inner dimension."""
return self.num_attention_heads * self.attention_head_dim
@property
def audio_inner_dim(self) -> int:
"""Audio inner dimension."""
return self.audio_num_attention_heads * self.audio_attention_head_dim
def get_video_config(self) -> Optional[TransformerConfig]:
"""Get video transformer configuration."""
if not self.model_type.is_video_enabled():
return None
return TransformerConfig(
dim=self.inner_dim,
heads=self.num_attention_heads,
d_head=self.attention_head_dim,
context_dim=self.cross_attention_dim,
)
def get_audio_config(self) -> Optional[TransformerConfig]:
"""Get audio transformer configuration."""
if not self.model_type.is_audio_enabled():
return None
return TransformerConfig(
dim=self.audio_inner_dim,
heads=self.audio_num_attention_heads,
d_head=self.audio_attention_head_dim,
context_dim=self.audio_cross_attention_dim,
)

View File

@@ -0,0 +1,40 @@
import mlx.core as mx
import mlx.nn as nn
class GELU(nn.Module):
def __init__(self, approximate: str = "tanh"):
super().__init__()
self.approximate = approximate
def __call__(self, x: mx.array) -> mx.array:
if self.approximate == "tanh":
return nn.gelu_approx(x)
else:
return nn.gelu(x)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: int | None = None,
mult: int = 4,
bias: bool = True,
):
super().__init__()
dim_out = dim_out or dim
inner_dim = int(dim * mult)
self.proj_in = nn.Linear(dim, inner_dim, bias=bias)
self.act = GELU(approximate="tanh")
self.proj_out = nn.Linear(inner_dim, dim_out, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.proj_in(x)
x = self.act(x)
x = self.proj_out(x)
return x

518
mlx_video/models/ltx/ltx.py Normal file
View File

@@ -0,0 +1,518 @@
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import (
LTXModelConfig,
LTXModelType,
LTXRopeType,
TransformerConfig,
)
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
from mlx_video.models.ltx.rope import precompute_freqs_cis
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx.transformer import (
BasicAVTransformerBlock,
Modality,
TransformerArgs,
)
from mlx_video.utils import to_denoised
class TransformerArgsPreprocessor:
def __init__(
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
inner_dim: int,
max_pos: List[int],
num_attention_heads: int,
use_middle_indices_grid: bool,
timestep_scale_multiplier: int,
positional_embedding_theta: float,
rope_type: LTXRopeType,
double_precision_rope: bool = False,
):
self.patchify_proj = patchify_proj
self.adaln = adaln
self.caption_projection = caption_projection
self.inner_dim = inner_dim
self.max_pos = max_pos
self.num_attention_heads = num_attention_heads
self.use_middle_indices_grid = use_middle_indices_grid
self.timestep_scale_multiplier = timestep_scale_multiplier
self.positional_embedding_theta = positional_embedding_theta
self.rope_type = rope_type
self.double_precision_rope = double_precision_rope
def _prepare_timestep(
self,
timestep: mx.array,
batch_size: int,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
# Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
return timestep_emb, embedded_timestep
def _prepare_context(
self,
context: mx.array,
x: mx.array,
attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
return context, attention_mask
def _prepare_attention_mask(
self,
attention_mask: Optional[mx.array],
x_dtype: mx.Dtype,
) -> Optional[mx.array]:
if attention_mask is None:
return None
# Check if already float
if attention_mask.dtype in [mx.float16, mx.float32, mx.bfloat16]:
return attention_mask
# Convert boolean/int mask to float mask
# 0 -> -inf (masked), 1 -> 0 (not masked)
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
return mask
def _prepare_positional_embeddings(
self,
positions: mx.array,
inner_dim: int,
max_pos: List[int],
use_middle_indices_grid: bool,
num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
pe = precompute_freqs_cis(
positions,
dim=inner_dim,
theta=self.positional_embedding_theta,
max_pos=max_pos,
use_middle_indices_grid=use_middle_indices_grid,
num_attention_heads=num_attention_heads,
rope_type=self.rope_type,
double_precision=self.double_precision_rope,
)
return pe
def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return TransformerArgs(
x=x,
context=context,
context_mask=attention_mask,
timesteps=timestep,
embedded_timestep=embedded_timestep,
positional_embeddings=pe,
cross_positional_embeddings=None,
cross_scale_shift_timestep=None,
cross_gate_timestep=None,
enabled=modality.enabled,
)
class MultiModalTransformerArgsPreprocessor:
def __init__(
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
cross_scale_shift_adaln: AdaLayerNormSingle,
cross_gate_adaln: AdaLayerNormSingle,
inner_dim: int,
max_pos: List[int],
num_attention_heads: int,
cross_pe_max_pos: int,
use_middle_indices_grid: bool,
audio_cross_attention_dim: int,
timestep_scale_multiplier: int,
positional_embedding_theta: float,
rope_type: LTXRopeType,
av_ca_timestep_scale_multiplier: int,
double_precision_rope: bool = False,
):
self.simple_preprocessor = TransformerArgsPreprocessor(
patchify_proj=patchify_proj,
adaln=adaln,
caption_projection=caption_projection,
inner_dim=inner_dim,
max_pos=max_pos,
num_attention_heads=num_attention_heads,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
positional_embedding_theta=positional_embedding_theta,
rope_type=rope_type,
double_precision_rope=double_precision_rope,
)
self.cross_scale_shift_adaln = cross_scale_shift_adaln
self.cross_gate_adaln = cross_gate_adaln
self.cross_pe_max_pos = cross_pe_max_pos
self.audio_cross_attention_dim = audio_cross_attention_dim
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
def prepare(self, modality: Modality) -> TransformerArgs:
from dataclasses import replace
transformer_args = self.simple_preprocessor.prepare(modality)
# Prepare cross-modal positional embeddings
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
positions=modality.positions[:, 0:1, :],
inner_dim=self.audio_cross_attention_dim,
max_pos=[self.cross_pe_max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.simple_preprocessor.num_attention_heads,
)
# Prepare cross-attention timestep embeddings
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
)
return replace(
transformer_args,
cross_positional_embeddings=cross_pe,
cross_scale_shift_timestep=cross_scale_shift_timestep,
cross_gate_timestep=cross_gate_timestep,
)
def _prepare_cross_attention_timestep(
self,
timestep: mx.array,
timestep_scale_multiplier: int,
batch_size: int,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * timestep_scale_multiplier
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
return scale_shift_timestep, gate_timestep
class LTXModel(nn.Module):
def __init__(self, config: LTXModelConfig):
super().__init__()
self.config = config
self.model_type = config.model_type
self.use_middle_indices_grid = config.use_middle_indices_grid
self.rope_type = config.rope_type
self.timestep_scale_multiplier = config.timestep_scale_multiplier
self.positional_embedding_theta = config.positional_embedding_theta
cross_pe_max_pos = None
if config.model_type.is_video_enabled():
self.positional_embedding_max_pos = config.positional_embedding_max_pos
self.num_attention_heads = config.num_attention_heads
self.inner_dim = config.inner_dim
self._init_video(config)
if config.model_type.is_audio_enabled():
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
self.audio_num_attention_heads = config.audio_num_attention_heads
self.audio_inner_dim = config.audio_inner_dim
self._init_audio(config)
# Initialize cross-modal components
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
cross_pe_max_pos = max(
config.positional_embedding_max_pos[0],
config.audio_positional_embedding_max_pos[0],
)
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
self.audio_cross_attention_dim = config.audio_cross_attention_dim
self._init_audio_video(config)
self._init_preprocessors(config, cross_pe_max_pos)
self._init_transformer_blocks(config)
def _init_video(self, config: LTXModelConfig) -> None:
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.inner_dim,
)
self.scale_shift_table = mx.zeros((2, self.inner_dim))
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.audio_inner_dim,
)
# Output components
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
def _init_audio_video(self, config: LTXModelConfig) -> None:
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
embedding_coefficient=num_scale_shift_values,
)
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=num_scale_shift_values,
)
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
self.inner_dim,
embedding_coefficient=1,
)
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=1,
)
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
# Multi-modal preprocessors
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
inner_dim=self.inner_dim,
max_pos=config.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
cross_pe_max_pos=cross_pe_max_pos,
use_middle_indices_grid=config.use_middle_indices_grid,
audio_cross_attention_dim=config.audio_cross_attention_dim,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
)
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
inner_dim=self.audio_inner_dim,
max_pos=config.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
cross_pe_max_pos=cross_pe_max_pos,
use_middle_indices_grid=config.use_middle_indices_grid,
audio_cross_attention_dim=config.audio_cross_attention_dim,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
)
elif config.model_type.is_video_enabled():
self.video_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
inner_dim=self.inner_dim,
max_pos=config.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
use_middle_indices_grid=config.use_middle_indices_grid,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
)
elif config.model_type.is_audio_enabled():
self.audio_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
inner_dim=self.audio_inner_dim,
max_pos=config.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
use_middle_indices_grid=config.use_middle_indices_grid,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
)
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
video_config = config.get_video_config()
audio_config = config.get_audio_config()
self.transformer_blocks = [
BasicAVTransformerBlock(
idx=idx,
video=video_config,
audio=audio_config,
rope_type=config.rope_type,
norm_eps=config.norm_eps,
)
for idx in range(config.num_layers)
]
def _process_transformer_blocks(
self,
video: Optional[TransformerArgs],
audio: Optional[TransformerArgs],
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks."""
for block in self.transformer_blocks:
video, audio = block(video=video, audio=audio)
return video, audio
def _process_output(
self,
scale_shift_table: mx.array,
norm_out: nn.LayerNorm,
proj_out: nn.Linear,
x: mx.array,
embedded_timestep: mx.array,
) -> mx.array:
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
timestep_expanded = embedded_timestep[:, :, None, :] # (B, 1, 1, dim)
# Combine: (1, 1, 2, dim) + (B, 1, 1, dim) broadcasts to (B, 1, 2, dim)
scale_shift_values = table_expanded + timestep_expanded
# Extract shift and scale (first index is shift, second is scale)
shift = scale_shift_values[:, :, 0, :] # (B, 1, dim)
scale = scale_shift_values[:, :, 1, :] # (B, 1, dim)
x = norm_out(x)
x = x * (1 + scale) + shift # Broadcasts (B, 1, dim) to (B, seq, dim)
x = proj_out(x)
return x
def __call__(
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
# Validate inputs
if not self.model_type.is_video_enabled() and video is not None:
raise ValueError("Video is not enabled for this model")
if not self.model_type.is_audio_enabled() and audio is not None:
raise ValueError("Audio is not enabled for this model")
# Preprocess arguments
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
# Process transformer blocks
video_out, audio_out = self._process_transformer_blocks(
video=video_args,
audio=audio_args,
)
# Process outputs
vx = (
self._process_output(
self.scale_shift_table,
self.norm_out,
self.proj_out,
video_out.x,
video_out.embedded_timestep,
)
if video_out is not None
else None
)
ax = (
self._process_output(
self.audio_scale_shift_table,
self.audio_norm_out,
self.audio_proj_out,
audio_out.x,
audio_out.embedded_timestep,
)
if audio_out is not None
else None
)
return vx, ax
def sanitize(self, weights: dict) -> dict:
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle common remappings
# transformer_blocks.X -> transformer_blocks[X]
if "transformer_blocks." in new_key:
# Keep as-is for now, MLX handles this
pass
sanitized[new_key] = value
return sanitized
class X0Model(nn.Module):
def __init__(self, velocity_model: LTXModel):
super().__init__()
self.velocity_model = velocity_model
def __call__(
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model(video, audio)
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
return denoised_video, denoised_audio

View File

@@ -0,0 +1,508 @@
import math
from typing import Callable, List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.config import LTXRopeType
def apply_rotary_emb(
input_tensor: mx.array,
freqs_cis: Tuple[mx.array, mx.array],
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> mx.array:
"""Apply rotary position embeddings to input tensor.
Args:
input_tensor: Input tensor to apply RoPE to
freqs_cis: Tuple of (cos_freqs, sin_freqs)
rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT)
Returns:
Tensor with rotary embeddings applied
"""
if rope_type == LTXRopeType.INTERLEAVED:
return apply_interleaved_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
elif rope_type == LTXRopeType.SPLIT:
return apply_split_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
else:
raise ValueError(f"Invalid rope type: {rope_type}")
def apply_interleaved_rotary_emb(
input_tensor: mx.array,
cos_freqs: mx.array,
sin_freqs: mx.array,
) -> mx.array:
"""Apply interleaved rotary embeddings.
Pairs adjacent dimensions and applies rotation.
Pattern: [x0, x1, x2, x3, ...] -> rotate pairs (x0,x1), (x2,x3), ...
Args:
input_tensor: Input tensor of shape (..., dim)
cos_freqs: Cosine frequencies
sin_freqs: Sine frequencies
Returns:
Tensor with interleaved rotary embeddings applied
"""
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
shape = input_tensor.shape
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
# Extract pairs
t1 = input_tensor[..., 0] # Even indices
t2 = input_tensor[..., 1] # Odd indices
# Apply rotation: (-t2, t1) pattern
t_rot = mx.stack([-t2, t1], axis=-1)
# Flatten back: (..., dim/2, 2) -> (..., dim)
input_tensor = mx.reshape(input_tensor, shape)
t_rot = mx.reshape(t_rot, shape)
# Apply rotary embeddings
out = input_tensor * cos_freqs + t_rot * sin_freqs
return out
def rotate_half_interleaved(x: mx.array) -> mx.array:
"""Rotate for interleaved RoPE: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
PyTorch equivalent:
t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
return rearrange(t_dup, "... d r -> ... (d r)")
"""
# x: (..., dim) where dim is even
x_even = x[..., 0::2] # [x0, x2, x4, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
rotated = mx.stack([-x_odd, x_even], axis=-1)
return mx.reshape(rotated, x.shape)
def apply_rotary_emb_1d(
q: mx.array,
k: mx.array,
freqs_cis: mx.array,
) -> Tuple[mx.array, mx.array]:
"""Apply 1D rotary embeddings using precomputed frequencies (interleaved)."""
# freqs_cis: (1, seq_len, num_heads, head_dim, 2) where [..., 0] = cos, [..., 1] = sin
cos = freqs_cis[..., 0] # (1, seq_len, num_heads, head_dim)
sin = freqs_cis[..., 1]
# q, k: (batch, seq_len, num_heads, head_dim)
# Interleaved RoPE: pairs of adjacent dims rotate together
q_r = q * cos + rotate_half_interleaved(q) * sin
k_r = k * cos + rotate_half_interleaved(k) * sin
return q_r, k_r
def apply_split_rotary_emb(
input_tensor: mx.array,
cos_freqs: mx.array,
sin_freqs: mx.array,
) -> mx.array:
"""Apply split rotary embeddings.
Splits dimensions into two halves and applies rotation.
Pattern: split into first half and second half
Args:
input_tensor: Input tensor
cos_freqs: Cosine frequencies of shape (B, H, T, D//2)
sin_freqs: Sine frequencies of shape (B, H, T, D//2)
Returns:
Tensor with split rotary embeddings applied
"""
needs_reshape = False
original_shape = input_tensor.shape
# Handle dimension mismatch
if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
b, h, t, _ = cos_freqs.shape
# Reshape from (B, T, H*D) to (B, H, T, D)
input_tensor = mx.reshape(input_tensor, (b, t, h, -1))
input_tensor = mx.swapaxes(input_tensor, 1, 2)
needs_reshape = True
# Split into two halves: (..., dim) -> (..., 2, dim//2)
dim = input_tensor.shape[-1]
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
# Get first and second halves
first_half = split_input[..., 0, :] # (..., dim//2)
second_half = split_input[..., 1, :] # (..., dim//2)
# Apply cosine to both halves
output_first = first_half * cos_freqs
output_second = second_half * cos_freqs
# Apply sine cross-terms (addcmul pattern)
output_first = output_first - sin_freqs * second_half
output_second = output_second + sin_freqs * first_half
# Stack back together
output = mx.stack([output_first, output_second], axis=-2)
# Flatten: (..., 2, dim//2) -> (..., dim)
output = mx.reshape(output, input_tensor.shape)
if needs_reshape:
# Reshape back: (B, H, T, D) -> (B, T, H*D)
b, h, t, d = output.shape
output = mx.swapaxes(output, 1, 2)
output = mx.reshape(output, (b, t, h * d))
return output
def generate_freq_grid(
positional_embedding_theta: float,
positional_embedding_max_pos_count: int,
inner_dim: int,
) -> mx.array:
"""Generate frequency grid for RoPE.
Args:
positional_embedding_theta: Base theta value
positional_embedding_max_pos_count: Number of position dimensions
inner_dim: Inner dimension of the model
Returns:
Frequency indices tensor
"""
theta = positional_embedding_theta
start = 1.0
end = theta
n_elem = 2 * positional_embedding_max_pos_count
# Compute logarithmic spacing
log_start = math.log(start) / math.log(theta)
log_end = math.log(end) / math.log(theta)
num_indices = inner_dim // n_elem
if num_indices == 0:
num_indices = 1
# Create linearly spaced values in log space
lin_space = mx.linspace(log_start, log_end, num_indices)
# Compute power indices
pow_indices = mx.power(theta, lin_space)
# Scale by pi/2
return pow_indices * (math.pi / 2)
def get_fractional_positions(
indices_grid: mx.array,
max_pos: List[int],
) -> mx.array:
"""Convert indices to fractional positions.
Args:
indices_grid: Grid of position indices of shape (B, n_pos_dims, ...)
max_pos: Maximum position for each dimension
Returns:
Fractional positions in range [-1, 1] after scaling
"""
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), (
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
)
# Divide each dimension by its max position
fractional_positions = []
for i in range(n_pos_dims):
frac = indices_grid[:, i] / max_pos[i]
fractional_positions.append(frac)
return mx.stack(fractional_positions, axis=-1)
def generate_freqs(
indices: mx.array,
indices_grid: mx.array,
max_pos: List[int],
use_middle_indices_grid: bool,
) -> mx.array:
"""Generate frequencies from indices and position grid.
Args:
indices: Frequency indices
indices_grid: Position indices grid
max_pos: Maximum positions per dimension
use_middle_indices_grid: Whether to use middle of index ranges
Returns:
Frequency tensor
"""
# Handle middle indices grid
if use_middle_indices_grid:
# indices_grid shape: (B, n_dims, T, 2) where last dim is [start, end]
assert len(indices_grid.shape) == 4
assert indices_grid.shape[-1] == 2
indices_grid_start = indices_grid[..., 0]
indices_grid_end = indices_grid[..., 1]
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid.shape) == 4:
indices_grid = indices_grid[..., 0]
# Get fractional positions
fractional_positions = get_fractional_positions(indices_grid, max_pos)
# Compute frequencies
# fractional_positions: (B, T, n_dims)
# indices: (inner_dim // n_elem,)
# Result: (B, T, inner_dim // n_elem * n_dims)
# Scale fractional positions to [-1, 1]
scaled_positions = fractional_positions * 2 - 1 # (B, T, n_dims)
# Outer product with indices
# (B, T, n_dims, 1) * (1, 1, 1, n_indices) -> (B, T, n_dims, n_indices)
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.expand_dims(
mx.expand_dims(mx.expand_dims(indices, axis=0), axis=0), axis=0
)
# Transpose and flatten: (B, T, n_dims, n_indices) -> (B, T, n_indices * n_dims)
freqs = mx.swapaxes(freqs, -1, -2) # (B, T, n_indices, n_dims)
freqs = mx.reshape(freqs, freqs.shape[:-2] + (-1,))
return freqs
def split_freqs_cis(
freqs: mx.array,
pad_size: int,
num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
"""Prepare cos/sin frequencies for split RoPE.
Args:
freqs: Frequency tensor
pad_size: Padding size for dimension alignment
num_attention_heads: Number of attention heads
Returns:
Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2)
"""
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
# Add padding if needed
if pad_size != 0:
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape for multi-head attention
b, t = cos_freq.shape[0], cos_freq.shape[1]
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
# Swap axes: (B, T, H, D//2) -> (B, H, T, D//2)
cos_freq = mx.swapaxes(cos_freq, 1, 2)
sin_freq = mx.swapaxes(sin_freq, 1, 2)
return cos_freq, sin_freq
def interleaved_freqs_cis(
freqs: mx.array,
pad_size: int,
) -> Tuple[mx.array, mx.array]:
"""Prepare cos/sin frequencies for interleaved RoPE.
Args:
freqs: Frequency tensor of shape (B, T, dim//2)
pad_size: Padding size for dimension alignment
Returns:
Tuple of (cos_freq, sin_freq) with shape (B, T, dim)
"""
# Compute cos and sin
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
# Repeat interleave: each element repeated twice
# (B, T, D) -> (B, T, 2*D) with pattern [c0, c0, c1, c1, ...]
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
# Add padding if needed
if pad_size != 0:
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
return cos_freq, sin_freq
def precompute_freqs_cis(
indices_grid: mx.array,
dim: int,
theta: float = 10000.0,
max_pos: Optional[List[int]] = None,
use_middle_indices_grid: bool = False,
num_attention_heads: int = 32,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
double_precision: bool = False,
) -> Tuple[mx.array, mx.array]:
"""Precompute RoPE frequencies.
Args:
indices_grid: Position indices grid
dim: Dimension for RoPE
theta: Base theta value for frequency computation
max_pos: Maximum position per dimension
use_middle_indices_grid: Whether to use middle indices
num_attention_heads: Number of attention heads
rope_type: Type of RoPE (INTERLEAVED or SPLIT)
double_precision: If True, compute frequencies in float64 for higher precision
Returns:
Tuple of (cos_freq, sin_freq) tensors
"""
if max_pos is None:
max_pos = [20, 2048, 2048]
# For double precision, compute in numpy (float64) then convert back to MLX
# MLX GPU doesn't support float64, so we use numpy for high precision computation
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
num_attention_heads, rope_type
)
# Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
# Generate frequencies
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
# Prepare cos/sin based on rope type
if rope_type == LTXRopeType.SPLIT:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
else:
# Interleaved
n_elem = 2 * indices_grid.shape[1]
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq, sin_freq
def _precompute_freqs_cis_double_precision(
indices_grid: mx.array,
dim: int,
theta: float,
max_pos: List[int],
use_middle_indices_grid: bool,
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies in double precision using numpy.
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
"""
# Convert to numpy float64
indices_grid_np = np.array(indices_grid).astype(np.float64)
# Generate frequency indices in float64
n_pos_dims = indices_grid_np.shape[1]
n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies
log_start = math.log(1.0) / math.log(theta)
log_end = math.log(theta) / math.log(theta)
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = np.linspace(log_start, log_end, num_indices)
indices_np = np.power(theta, lin_space) * (math.pi / 2)
# Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
if use_middle_indices_grid:
assert len(indices_grid_np.shape) == 4
assert indices_grid_np.shape[-1] == 2
indices_grid_start = indices_grid_np[..., 0]
indices_grid_end = indices_grid_np[..., 1]
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid_np.shape) == 4:
indices_grid_np = indices_grid_np[..., 0]
# After handling: indices_grid_np shape is (B, n_dims, T)
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
batch_size = indices_grid_np.shape[0]
seq_len = indices_grid_np.shape[2]
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
for i in range(n_pos_dims):
# indices_grid_np[:, i, :] has shape (B, T)
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
# Scale to [-1, 1]
scaled_positions = fractional_positions * 2 - 1
# Compute frequencies: outer product
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
# Compute cos/sin in float64
cos_freq = np.cos(freqs)
sin_freq = np.sin(freqs)
# Prepare based on rope type
if rope_type == LTXRopeType.SPLIT:
expected_freqs = dim // 2
current_freqs = cos_freq.shape[-1]
pad_size = expected_freqs - current_freqs
# Add padding
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
b, t = cos_freq.shape[0], cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
cos_freq = np.swapaxes(cos_freq, 1, 2)
sin_freq = np.swapaxes(sin_freq, 1, 2)
else:
# Interleaved
cos_freq = np.repeat(cos_freq, 2, axis=-1)
sin_freq = np.repeat(sin_freq, 2, axis=-1)
pad_size = dim % n_elem
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
# Convert back to MLX (float32 for GPU compatibility)
cos_freq = mx.array(cos_freq.astype(np.float32))
sin_freq = mx.array(sin_freq.astype(np.float32))
return cos_freq, sin_freq

View File

@@ -0,0 +1,727 @@
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
@dataclass
class Gemma3Config:
"""Configuration for Gemma 3 text model."""
hidden_size: int = 3840
num_attention_heads: int = 16
num_key_value_heads: int = 8
head_dim: int = 256
intermediate_size: int = 15360
num_hidden_layers: int = 48
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
vocab_size: int = 262208
max_position_embeddings: int = 131072
class RMSNorm(nn.Module):
"""RMS Normalization (Gemma style with 1+weight scaling)."""
def __init__(self, dims: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# Gemma initializes to ones, but uses (1+weight) scaling
# After loading weights, weight will have the actual learned values
self.weight = mx.ones((dims,))
def __call__(self, x: mx.array) -> mx.array:
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
def apply_rotary_emb(
q: mx.array,
k: mx.array,
positions: mx.array,
head_dim: int,
rope_theta: float = 1000000.0,
) -> Tuple[mx.array, mx.array]:
"""Apply rotary position embeddings to Q and K."""
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
cos = mx.cos(freqs)
sin = mx.sin(freqs)
cos = cos[:, :, None, :]
sin = sin[:, :, None, :]
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate([-x2, x1], axis=-1)
cos_full = mx.concatenate([cos, cos], axis=-1)
sin_full = mx.concatenate([sin, sin], axis=-1)
q_embed = q * cos_full + rotate_half(q) * sin_full
k_embed = k * cos_full + rotate_half(k) * sin_full
return q_embed, k_embed
class Gemma3MLP(nn.Module):
"""Gemma 3 MLP with gated activation."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
gate = nn.gelu_approx(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class Gemma3Attention(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.scale = 1.0 / math.sqrt(config.head_dim)
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
positions: mx.array,
attention_mask: Optional[mx.array] = None,
) -> mx.array:
batch_size, seq_len, _ = hidden_states.shape
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# Create causal mask (lower triangular)
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
if attention_mask is not None:
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
out = mx.transpose(out, (0, 2, 1, 3))
out = mx.reshape(out, (batch_size, seq_len, -1))
return self.o_proj(out)
class Gemma3DecoderLayer(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.self_attn = Gemma3Attention(config)
self.mlp = Gemma3MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
positions: mx.array,
attention_mask: Optional[mx.array] = None,
) -> mx.array:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Gemma3TextModel(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gemma scales embeddings by sqrt(hidden_size)
self.embed_scale = config.hidden_size ** 0.5
def __call__(
self,
input_ids: mx.array,
attention_mask: Optional[mx.array] = None,
output_hidden_states: bool = True,
) -> Tuple[mx.array, List[mx.array]]:
batch_size, seq_len = input_ids.shape
# Gemma scales embeddings by sqrt(hidden_size)
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
all_hidden_states = [hidden_states] if output_hidden_states else []
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
positions = mx.broadcast_to(positions, (batch_size, seq_len))
for layer in self.layers:
hidden_states = layer(hidden_states, positions, attention_mask)
if output_hidden_states:
all_hidden_states.append(hidden_states)
hidden_states = self.norm(hidden_states)
return hidden_states, all_hidden_states
class ConnectorAttention(nn.Module):
def __init__(
self,
dim: int = 3840,
num_heads: int = 30,
head_dim: int = 128,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
inner_dim = num_heads * head_dim
self.scale = 1.0 / math.sqrt(head_dim)
self.to_q = nn.Linear(dim, inner_dim, bias=True)
self.to_k = nn.Linear(dim, inner_dim, bias=True)
self.to_v = nn.Linear(dim, inner_dim, bias=True)
self.to_out = [nn.Linear(inner_dim, dim, bias=True)]
# Standard RMSNorm (not Gemma-style) on full inner_dim
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6)
def __call__(
self,
x: mx.array,
attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None,
) -> mx.array:
batch_size, seq_len, _ = x.shape
# Project to Q, K, V
q = self.to_q(x) # (B, seq, inner_dim)
k = self.to_k(x)
v = self.to_v(x)
# QK normalization on full inner_dim BEFORE reshape (matches PyTorch)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
# pe: (1, seq_len, num_heads, head_dim, 2)
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
q, k = apply_rotary_emb_1d(q, k, pe)
# Reshape back for attention computation
q = mx.reshape(q, (batch_size, seq_len, -1))
k = mx.reshape(k, (batch_size, seq_len, -1))
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
if attention_mask is not None:
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
return self.to_out[0](out)
class GEGLU(nn.Module):
"""GELU-gated linear unit."""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.proj = nn.Linear(in_dim, out_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return nn.gelu_approx(self.proj(x))
class ConnectorFeedForward(nn.Module):
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
super().__init__()
inner_dim = dim * mult
self.net = [
GEGLU(dim, inner_dim),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias=True),
]
def __call__(self, x: mx.array) -> mx.array:
for layer in self.net:
x = layer(x)
return x
class ConnectorTransformerBlock(nn.Module):
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128):
super().__init__()
self.attn1 = ConnectorAttention(dim, num_heads, head_dim)
self.ff = ConnectorFeedForward(dim)
def __call__(
self,
x: mx.array,
attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None,
) -> mx.array:
# Pre-norm + attention + residual
norm_x = rms_norm(x)
if norm_x.ndim == 4:
norm_x = mx.squeeze(norm_x, axis=1)
attn_out = self.attn1(norm_x, attention_mask, pe)
x = x + attn_out
if x.ndim == 4:
x = mx.squeeze(x, axis=1)
# Pre-norm + FFN + residual
norm_x = rms_norm(x)
ff_out = self.ff(norm_x)
x = x + ff_out
if x.ndim == 4:
x = mx.squeeze(x, axis=1)
return x
class Embeddings1DConnector(nn.Module):
def __init__(
self,
dim: int = 3840,
num_heads: int = 30,
head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: int = 128,
positional_embedding_theta: float = 10000.0,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = head_dim
self.num_learnable_registers = num_learnable_registers
self.positional_embedding_theta = positional_embedding_theta
self.transformer_1d_blocks = [
ConnectorTransformerBlock(dim, num_heads, head_dim)
for _ in range(num_layers)
]
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
import math
dim = self.num_heads * self.head_dim
theta = self.positional_embedding_theta
n_elem = 2
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
indices = (theta ** linspace_vals) * (math.pi / 2)
positions = mx.arange(seq_len).astype(mx.float32)
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
cos = mx.cos(freqs) # (seq_len, dim//2)
sin = mx.sin(freqs)
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
return freqs_cis.astype(dtype)
def _replace_padded_with_registers(
self,
hidden_states: mx.array,
attention_mask: mx.array,
) -> Tuple[mx.array, mx.array]:
batch_size, seq_len, dim = hidden_states.shape
# Binary mask: 1 for valid tokens, 0 for padded
# attention_mask is additive: 0 for valid, large negative for padded
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
# Tile registers to match sequence length
num_tiles = seq_len // self.num_learnable_registers
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim)
# Process each batch item (PyTorch uses advanced indexing)
result_list = []
for b in range(batch_size):
mask_b = mask_binary[b] # (seq,)
hs_b = hidden_states[b] # (seq, dim)
# Count valid tokens
num_valid = int(mx.sum(mask_b))
# Extract valid tokens (where mask is 1)
# Since we have left-padded input, valid tokens are at the end
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
# Pad with zeros on the right to get back to seq_len
pad_length = seq_len - num_valid
if pad_length > 0:
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype)
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
else:
adjusted = valid_tokens
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
flipped_mask = mx.concatenate([
mx.ones((num_valid,), dtype=mx.int32),
mx.zeros((pad_length,), dtype=mx.int32)
], axis=0) # (seq,)
# Combine: valid tokens at front, registers at back
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1)
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
result_list.append(combined)
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
# Reset attention mask to all zeros (no masking after register replacement)
attention_mask = mx.zeros_like(attention_mask)
return hidden_states, attention_mask
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
# Replace padded tokens with learnable registers
if self.num_learnable_registers > 0 and attention_mask is not None:
hidden_states, attention_mask = self._replace_padded_with_registers(
hidden_states, attention_mask
)
# Compute RoPE frequencies
seq_len = hidden_states.shape[1]
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
# Process through transformer blocks
for block in self.transformer_1d_blocks:
hidden_states = block(hidden_states, attention_mask, freqs_cis)
# Final RMS norm
hidden_states = rms_norm(hidden_states)
return hidden_states, attention_mask
def norm_and_concat_hidden_states(
hidden_states: List[mx.array],
attention_mask: mx.array,
padding_side: str = "left",
) -> mx.array:
# Stack hidden states: (batch, seq, dim, num_layers)
stacked = mx.stack(hidden_states, axis=-1)
b, t, d, num_layers = stacked.shape
# Compute sequence lengths from attention mask
sequence_lengths = mx.sum(attention_mask, axis=-1) # (batch,)
# Build mask based on padding side
token_indices = mx.arange(t)[None, :] # (1, T)
if padding_side == "right":
mask = token_indices < sequence_lengths[:, None] # (B, T)
else: # left padding
start_indices = t - sequence_lengths[:, None] # (B, 1)
mask = token_indices >= start_indices # (B, T)
mask = mask[:, :, None, None] # (B, T, 1, 1)
eps = 1e-6
# Compute masked mean per layer
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
large_val = 1e9
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min
# Normalize: 8 * (x - mean) / range
normed = 8 * (stacked - mean) / (range_val + eps)
# Flatten layers into feature dimension: (B, T, D*L)
normed = mx.reshape(normed, (b, t, -1))
# Zero out padded positions
mask_flat = mx.broadcast_to(mask[:, :, :, 0], (b, t, d * num_layers))
normed = mx.where(mask_flat, normed, mx.zeros_like(normed))
return normed
class GemmaFeaturesExtractor(nn.Module):
def __init__(self, input_dim: int = 188160, output_dim: int = 3840):
super().__init__()
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.aggregate_embed(x)
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith("base_text_encoder.language_model."):
new_key = key.replace("base_text_encoder.language_model.", "")
elif key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "")
elif key.startswith("language_model."):
new_key = key.replace("language_model.", "")
else:
continue
if new_key is None:
continue
sanitized[new_key] = value
return sanitized
class LTX2TextEncoder(nn.Module):
def __init__(
self,
model_path: str = "Lightricks/LTX-2",
hidden_dim: int = 3840,
num_layers: int = 49, # 48 transformer layers + 1 embedding
):
super().__init__()
self._model_path = model_path
self.hidden_dim = hidden_dim
self.num_layers = num_layers
# Gemma 3 model
self.config = Gemma3Config()
self.model = Gemma3TextModel(self.config)
# Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor(
input_dim=hidden_dim * num_layers,
output_dim=hidden_dim,
)
# Video embeddings connector: 2-layer transformer
self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim,
num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
)
self.processor = None
def load(self, model_path: Optional[str] = None):
path = model_path or self._model_path
# Load Gemma weights from text_encoder subdirectory
if Path(path).is_dir():
text_encoder_path = Path(path) / "text_encoder"
if text_encoder_path.exists():
gemma_path = str(text_encoder_path)
else:
gemma_path = path
else:
gemma_path = path
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
all_weights = {}
for i, wf in enumerate(weight_files):
print(f" Loading weight file {i+1}/{len(weight_files)}...")
weights = mx.load(str(wf))
all_weights.update(weights)
# Sanitize and load Gemma weights
sanitized = sanitize_gemma3_weights(all_weights)
print(f" Sanitized Gemma weights: {len(sanitized)}")
self.model.load_weights(list(sanitized.items()), strict=False)
# Load transformer weights for feature extractor and connector
transformer_path = Path(model_path or self._model_path)
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
if transformer_files:
print(f"Loading transformer weights for text pipeline...")
transformer_weights = mx.load(str(transformer_files[0]))
# Load feature extractor (aggregate_embed)
if "text_embedding_projection.aggregate_embed.weight" in transformer_weights:
self.feature_extractor.aggregate_embed.weight = transformer_weights[
"text_embedding_projection.aggregate_embed.weight"
]
print(" Loaded aggregate_embed weights")
# Load video_embeddings_connector weights
connector_weights = {}
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
connector_weights[new_key] = value
if connector_weights:
# Map weight names to our structure
mapped_weights = {}
for key, value in connector_weights.items():
# transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.*
# transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.*
# transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.*
mapped_weights[key] = value
self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False
)
print(f" Loaded {len(connector_weights)} connector weights")
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
# Load tokenizer
from transformers import AutoTokenizer
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
else:
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
# Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left"
print("Text encoder loaded successfully")
def encode(
self,
prompt: str,
max_length: int = 1024,
) -> Tuple[mx.array, mx.array]:
if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.")
# Tokenize with left padding (as in PyTorch version)
inputs = self.processor(
prompt,
return_tensors="np",
max_length=max_length,
truncation=True,
padding="max_length",
)
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
# Get all hidden states from Gemma
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
# Normalize and concatenate all hidden states
concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left"
)
# Project through feature extractor
features = self.feature_extractor(concat_hidden)
# Convert attention mask to additive format for connector
additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
# Process through connector
# Note: connector replaces padding with learnable registers and resets mask to zeros
# This means all positions now have valid embeddings (no need for final masking)
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
# Return embeddings without zeroing - the connector's register replacement
# means all positions have meaningful values now
return embeddings, attention_mask
def __call__(
self,
prompt: str,
max_length: int = 1024,
) -> Tuple[mx.array, mx.array]:
return self.encode(prompt, max_length)
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder(model_path=model_path)
encoder.load()
return encoder

View File

@@ -0,0 +1,26 @@
import mlx.core as mx
import mlx.nn as nn
class PixArtAlphaTextProjection(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear1(x)
x = self.act(x)
x = self.linear2(x)
return x

View File

@@ -0,0 +1,359 @@
from dataclasses import dataclass, replace
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx.attention import Attention
from mlx_video.models.ltx.feed_forward import FeedForward
from mlx_video.utils import rms_norm
@dataclass(frozen=True)
class Modality:
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
enabled: bool = True
context_mask: Optional[mx.array] = None
@dataclass(frozen=True)
class TransformerArgs:
x: mx.array
context: mx.array
context_mask: Optional[mx.array]
timesteps: mx.array
embedded_timestep: mx.array
positional_embeddings: Tuple[mx.array, mx.array]
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
cross_scale_shift_timestep: Optional[mx.array]
cross_gate_timestep: Optional[mx.array]
enabled: bool
class BasicAVTransformerBlock(nn.Module):
"""Audio-Video transformer block with cross-modal attention.
Supports video-only, audio-only, or combined audio-video processing
with bidirectional cross-attention between modalities.
"""
def __init__(
self,
idx: int,
video: Optional[TransformerConfig] = None,
audio: Optional[TransformerConfig] = None,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
norm_eps: float = 1e-6,
):
"""Initialize transformer block.
Args:
idx: Block index
video: Video modality configuration
audio: Audio modality configuration
rope_type: Type of rotary position embedding
norm_eps: Epsilon for normalization
"""
super().__init__()
self.idx = idx
self.norm_eps = norm_eps
# Video components
if video is not None:
self.attn1 = Attention(
query_dim=video.dim,
heads=video.heads,
dim_head=video.d_head,
context_dim=None, # Self-attention
rope_type=rope_type,
norm_eps=norm_eps,
)
self.attn2 = Attention(
query_dim=video.dim,
context_dim=video.context_dim,
heads=video.heads,
dim_head=video.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
self.ff = FeedForward(video.dim, dim_out=video.dim)
# 6 scale-shift parameters: 3 for attention, 3 for MLP
self.scale_shift_table = mx.zeros((6, video.dim))
# Audio components
if audio is not None:
self.audio_attn1 = Attention(
query_dim=audio.dim,
heads=audio.heads,
dim_head=audio.d_head,
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
)
self.audio_attn2 = Attention(
query_dim=audio.dim,
context_dim=audio.context_dim,
heads=audio.heads,
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
self.audio_scale_shift_table = mx.zeros((6, audio.dim))
# Cross-modal attention (when both video and audio are enabled)
if audio is not None and video is not None:
# Audio-to-Video: Q from video, K/V from audio
self.audio_to_video_attn = Attention(
query_dim=video.dim,
context_dim=audio.dim,
heads=audio.heads,
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
# Video-to-Audio: Q from audio, K/V from video
self.video_to_audio_attn = Attention(
query_dim=audio.dim,
context_dim=video.dim,
heads=audio.heads,
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
# Scale-shift tables for cross-attention
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))
def get_ada_values(
self,
scale_shift_table: mx.array,
batch_size: int,
timestep: mx.array,
indices: slice,
) -> Tuple[mx.array, ...]:
"""Get adaptive normalization values from scale-shift table.
Args:
scale_shift_table: Table of shape (num_params, dim)
batch_size: Batch size
timestep: Timestep embeddings of shape (B, 1, num_params * dim) or similar
indices: Slice for which parameters to extract
Returns:
Tuple of scale-shift values
"""
num_ada_params = scale_shift_table.shape[0]
# scale_shift_table[indices]: (num_selected, dim)
# Add batch and sequence dimensions: (1, 1, num_selected, dim)
table_slice = scale_shift_table[indices]
table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
timestep_reshaped = mx.reshape(
timestep,
(batch_size, timestep.shape[1], num_ada_params, -1)
)
# Extract the relevant indices
timestep_slice = timestep_reshaped[:, :, indices, :]
# Add table values to timestep
ada_values = table_expanded + timestep_slice
# Unbind along the parameter dimension
# Result: tuple of tensors, each of shape (B, seq, dim)
num_sliced = ada_values.shape[2]
result = tuple(ada_values[:, :, i, :] for i in range(num_sliced))
return result
def get_av_ca_ada_values(
self,
scale_shift_table: mx.array,
batch_size: int,
scale_shift_timestep: mx.array,
gate_timestep: mx.array,
num_scale_shift_values: int = 4,
) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
"""Get adaptive values for cross-modal attention.
Args:
scale_shift_table: Table with 5 parameters (4 scale-shift + 1 gate)
batch_size: Batch size
scale_shift_timestep: Timestep for scale-shift
gate_timestep: Timestep for gating
num_scale_shift_values: Number of scale-shift values (default 4)
Returns:
Tuple of 5 tensors: (scale1, shift1, scale2, shift2, gate)
"""
# Get scale-shift values
scale_shift_ada = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :],
batch_size,
scale_shift_timestep,
slice(None, None),
)
# Get gate values
gate_ada = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :],
batch_size,
gate_timestep,
slice(None, None),
)
# Squeeze the sequence dimension if it's 1
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
return (*scale_shift_squeezed, *gate_squeezed)
def __call__(
self,
video: Optional[TransformerArgs] = None,
audio: Optional[TransformerArgs] = None,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Forward pass through transformer block.
Args:
video: Video modality arguments
audio: Audio modality arguments
Returns:
Tuple of (updated_video, updated_audio) TransformerArgs
"""
batch_size = video.x.shape[0] if video is not None else audio.x.shape[0]
vx = video.x if video is not None else None
ax = audio.x if audio is not None else None
# Check which modalities to run
run_vx = video is not None and video.enabled and vx.size > 0
run_ax = audio is not None and audio.enabled and ax.size > 0
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
# Process video self-attention and cross-attention with text
if run_vx:
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
)
# Self-attention with RoPE
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
# Cross-attention with text context
vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps),
context=video.context,
mask=video.context_mask,
)
# Process audio self-attention and cross-attention with text
if run_ax:
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
)
# Self-attention with RoPE
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
# Cross-attention with text context
ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps),
context=audio.context,
mask=audio.context_mask,
)
# Audio-Video cross-modal attention
if run_a2v or run_v2a:
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
# Get adaptive values for audio cross-attention
(
scale_ca_audio_a2v,
shift_ca_audio_a2v,
scale_ca_audio_v2a,
shift_ca_audio_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
)
# Get adaptive values for video cross-attention
(
scale_ca_video_a2v,
shift_ca_video_a2v,
scale_ca_video_v2a,
shift_ca_video_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
)
# Audio-to-Video cross-attention
if run_a2v:
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
vx = vx + (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=video.cross_positional_embeddings,
k_pe=audio.cross_positional_embeddings,
)
* gate_out_a2v
)
# Video-to-Audio cross-attention
if run_v2a:
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
ax = ax + (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=audio.cross_positional_embeddings,
k_pe=video.cross_positional_embeddings,
)
* gate_out_v2a
)
# Process video feed-forward
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
)
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
vx = vx + self.ff(vx_scaled) * vgate_mlp
# Process audio feed-forward
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
)
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
# Return updated TransformerArgs
video_out = replace(video, x=vx) if video is not None else None
audio_out = replace(audio, x=ax) if audio is not None else None
return video_out, audio_out

View File

@@ -0,0 +1,364 @@
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
class Conv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
bias: bool = True,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
# Weight shape: (C_out, KD, KH, KW, C_in)
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
)
if bias:
self.bias = mx.zeros((out_channels,))
else:
self.bias = None
def __call__(self, x: mx.array) -> mx.array:
"""Forward pass.
Args:
x: Input tensor of shape (N, D, H, W, C_in)
Returns:
Output tensor of shape (N, D', H', W', C_out)
"""
y = mx.conv3d(
x,
self.weight,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
if self.bias is not None:
y = y + self.bias
return y
class GroupNorm3d(nn.Module):
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
super().__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.weight = mx.ones((num_channels,))
self.bias = mx.zeros((num_channels,))
def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C)
n, d, h, w, c = x.shape
# Reshape to (N, D*H*W, num_groups, C//num_groups)
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
# Compute mean and var over spatial and channel group dims
mean = mx.mean(x, axis=(1, 3), keepdims=True)
var = mx.var(x, axis=(1, 3), keepdims=True)
# Normalize
x = (x - mean) / mx.sqrt(var + self.eps)
# Reshape back
x = mx.reshape(x, (n, d, h, w, c))
# Apply weight and bias
x = x * self.weight + self.bias
return x
class PixelShuffle2D(nn.Module):
"""Pixel shuffle for 2D spatial upsampling."""
def __init__(self, upscale_factor: int = 2):
super().__init__()
self.upscale_factor = upscale_factor
def __call__(self, x: mx.array) -> mx.array:
# x: (N, H, W, C) where C = out_channels * upscale_factor^2
n, h, w, c = x.shape
r = self.upscale_factor
out_c = c // (r * r)
# Reshape: (N, H, W, out_c, r, r)
x = mx.reshape(x, (n, h, w, out_c, r, r))
# Permute: (N, H, r, W, r, out_c)
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
# Reshape: (N, H*r, W*r, out_c)
x = mx.reshape(x, (n, h * r, w * r, out_c))
return x
class SpatialRationalResampler(nn.Module):
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
super().__init__()
self.scale = scale
# 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
# Blur kernel for antialiasing
self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
self.pixel_shuffle = PixelShuffle2D(2)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C) - channels last 3D format
n, d, h, w, c = x.shape
# Process frame by frame
# Reshape to (N*D, H, W, C) for 2D operations
x = mx.reshape(x, (n * d, h, w, c))
# Apply 2D conv
x = self.conv(x)
# Pixel shuffle for 2x upscaling
x = self.pixel_shuffle(x)
# Reshape back to (N, D, H*2, W*2, C)
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
return x
class ResBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.conv1 = Conv3d(channels, channels, kernel_size=3, padding=1)
self.norm1 = GroupNorm3d(32, channels)
self.conv2 = Conv3d(channels, channels, kernel_size=3, padding=1)
self.norm2 = GroupNorm3d(32, channels)
def __call__(self, x: mx.array) -> mx.array:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = nn.silu(x)
x = self.conv2(x)
x = self.norm2(x)
# Activation AFTER residual addition
x = nn.silu(x + residual)
return x
class LatentUpsampler(nn.Module):
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
# Initial projection
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = GroupNorm3d(32, mid_channels)
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
# Upsampler: 2D spatial upsampling (frame-by-frame)
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
# Final projection
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
"""Upsample latents by 2x spatially.
Args:
latent: Input tensor of shape (B, C, F, H, W) - channels first
debug: If True, print intermediate values for debugging
Returns:
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
"""
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
if debug:
print(" [DEBUG] LatentUpsampler forward pass:")
debug_stats("Input (channels first)", latent)
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
x = mx.transpose(latent, (0, 2, 3, 4, 1))
if debug:
debug_stats("After transpose to channels-last", x)
# Initial conv
x = self.initial_conv(x)
if debug:
debug_stats("After initial_conv", x)
x = self.initial_norm(x)
if debug:
debug_stats("After initial_norm", x)
x = nn.silu(x)
if debug:
debug_stats("After silu", x)
# Pre-upsample blocks
for i in sorted(self.res_blocks.keys()):
x = self.res_blocks[i](x)
if debug:
debug_stats(f"After res_blocks[{i}]", x)
# Upsample (2D spatial, frame-by-frame)
x = self.upsampler(x)
if debug:
debug_stats("After upsampler (spatial 2x)", x)
# Post-upsample blocks
for i in sorted(self.post_upsample_res_blocks.keys()):
x = self.post_upsample_res_blocks[i](x)
if debug:
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
# Final conv
x = self.final_conv(x)
if debug:
debug_stats("After final_conv", x)
# Convert back to channels first (B, C, F, H, W)
x = mx.transpose(x, (0, 4, 1, 2, 3))
if debug:
debug_stats("Output (channels first)", x)
return x
def upsample_latents(
latent: mx.array,
upsampler: LatentUpsampler,
latent_mean: mx.array,
latent_std: mx.array,
debug: bool = False,
) -> mx.array:
# Un-normalize: latent * std + mean
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
latent = latent * latent_std + latent_mean
# Upsample
latent = upsampler(latent, debug=debug)
# Re-normalize: (latent - mean) / std
latent = (latent - latent_mean) / latent_std
return latent
def load_upsampler(weights_path: str) -> LatentUpsampler:
"""Load upsampler from safetensors weights.
Args:
weights_path: Path to upsampler weights file
Returns:
Loaded LatentUpsampler model
"""
print(f"Loading spatial upsampler from {weights_path}...")
raw_weights = mx.load(weights_path)
# Check weight shapes to determine mid_channels
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
sample_key = "res_blocks.0.conv1.weight"
if sample_key in raw_weights:
mid_channels = raw_weights[sample_key].shape[0]
else:
mid_channels = 1024 # default
print(f" Detected mid_channels: {mid_channels}")
# Create model
upsampler = LatentUpsampler(
in_channels=128,
mid_channels=mid_channels,
num_blocks_per_stage=4,
)
# Sanitize weights - convert from PyTorch to MLX format
sanitized = {}
for key, value in raw_weights.items():
new_key = key
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
if key.startswith("upsampler."):
new_key = key # Keep as is for SpatialRationalResampler
sanitized[new_key] = value
# Load weights
upsampler.load_weights(list(sanitized.items()), strict=False)
print(f" Loaded {len(sanitized)} weights")
return upsampler

View File

@@ -0,0 +1 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder

View File

@@ -0,0 +1,294 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
class PaddingModeType(Enum):
ZEROS = "zeros"
REFLECT = "reflect"
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
"""Apply reflect padding to spatial dimensions of a 5D tensor.
Args:
x: Input tensor of shape (B, D, H, W, C) - channels last
pad_h: Padding for height dimension
pad_w: Padding for width dimension
Returns:
Padded tensor
"""
if pad_h == 0 and pad_w == 0:
return x
# Height padding (axis 2)
if pad_h > 0:
# Get reflection indices - exclude boundary
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
# Width padding (axis 3)
if pad_w > 0:
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
x = mx.concatenate([left_pad, x, right_pad], axis=3)
return x
def make_conv_nd(
dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
if dims == 2:
return CausalConv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
causal=causal,
spatial_padding_mode=spatial_padding_mode,
)
elif dims == 3:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
causal=causal,
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unsupported number of dimensions: {dims}")
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.causal = causal
self.spatial_padding_mode = spatial_padding_mode
# Normalize kernel_size and stride to tuples
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
self.kernel_size = kernel_size
self.stride = stride
self.time_kernel_size = kernel_size[0]
# Calculate spatial padding (temporal is handled separately via frame replication)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
self.spatial_padding = (height_pad, width_pad)
# Create the base convolution (without padding, we'll handle it manually)
self.conv = nn.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0, # We handle padding manually
bias=True,
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
use_causal = causal if causal is not None else self.causal
# Apply temporal padding via frame replication
# Only apply if kernel_size > 1
if self.time_kernel_size > 1:
if use_causal:
# Causal: replicate first frame kernel_size-1 times at the beginning
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
x = mx.concatenate([first_frame_pad, x], axis=2)
else:
# Non-causal: replicate first frame at start, last frame at end
pad_size = (self.time_kernel_size - 1) // 2
if pad_size > 0:
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
x = mx.transpose(x, (0, 2, 3, 4, 1))
# Apply spatial padding
pad_h, pad_w = self.spatial_padding
if pad_h > 0 or pad_w > 0:
if self.spatial_padding_mode == PaddingModeType.REFLECT:
# Use reflect padding for spatial dimensions
x = reflect_pad_2d(x, pad_h, pad_w)
else:
# Use zero padding for spatial dimensions
pad_width = [
(0, 0), # Batch
(0, 0), # D (temporal - already padded)
(pad_h, pad_h), # H
(pad_w, pad_w), # W
(0, 0), # C
]
x = mx.pad(x, pad_width)
# Apply convolution with chunking for large tensors
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
x = self._chunked_conv3d(x)
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
x = mx.transpose(x, (0, 4, 1, 2, 3))
return x
def _chunked_conv3d(self, x: mx.array) -> mx.array:
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
Args:
x: Input tensor of shape (B, D, H, W, C) in channels-last format
Returns:
Output tensor after conv3d
"""
b, d, h, w, c = x.shape
total_elements = d * h * w * c
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
if total_elements <= max_safe_elements:
return self.conv(x)
elements_per_frame = h * w * c
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
kernel_t = self.time_kernel_size
overlap = kernel_t - 1
expected_output_frames = d - overlap
outputs = []
out_idx = 0
# Process chunks
in_start = 0
while out_idx < expected_output_frames:
remaining = expected_output_frames - out_idx
out_frames_this_chunk = min(chunk_size, remaining)
in_frames_needed = out_frames_this_chunk + overlap
in_end = min(in_start + in_frames_needed, d)
chunk = x[:, in_start:in_end, :, :, :]
chunk_out = self.conv(chunk)
mx.eval(chunk_out)
outputs.append(chunk_out)
out_idx += chunk_out.shape[1]
in_start += chunk_out.shape[1]
# Concatenate all chunks
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=1)
class CausalConv2d(nn.Module):
"""2D convolution with optional causal padding."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize CausalConv2d."""
super().__init__()
self.causal = causal
self.spatial_padding_mode = spatial_padding_mode
# Normalize kernel_size and stride
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride)
self.kernel_size = kernel_size
self.stride = stride
# Calculate padding
if isinstance(padding, str) and padding == "same":
self.padding = (
(kernel_size[0] - 1) // 2,
(kernel_size[1] - 1) // 2,
)
elif isinstance(padding, int):
self.padding = (padding, padding)
else:
self.padding = padding
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
bias=True,
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
"""Forward pass."""
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
x = mx.transpose(x, (0, 2, 3, 1))
# Apply padding
pad_h, pad_w = self.padding
if pad_h != 0 or pad_w != 0:
pad_width = [
(0, 0), # Batch
(pad_h, pad_h), # H
(pad_w, pad_w), # W
(0, 0), # C
]
x = mx.pad(x, pad_width)
x = self.conv(x)
# Transpose back: (B, H, W, C) -> (B, C, H, W)
x = mx.transpose(x, (0, 3, 1, 2))
return x

View File

@@ -0,0 +1,524 @@
"""Video VAE Decoder for LTX-2 with timestep conditioning.
Architecture (from PyTorch weights):
- conv_in: 128 -> 1024
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
- pixel_norm + timestep modulation (last_scale_shift_table)
- conv_out: 128 -> 48
- unpatchify: 48 -> 3 with patch_size=4
"""
import math
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
def get_timestep_embedding(
timesteps: mx.array,
embedding_dim: int,
flip_sin_to_cos: bool = True,
downscale_freq_shift: float = 0,
scale: float = 1,
max_period: int = 10000,
) -> mx.array:
"""Create sinusoidal timestep embeddings."""
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = mx.exp(exponent)
emb = timesteps[:, None].astype(mx.float32) * emb[None, :]
emb = scale * emb
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
if flip_sin_to_cos:
emb = mx.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
if embedding_dim % 2 == 1:
emb = mx.pad(emb, [(0, 0), (0, 1)])
return emb
class TimestepEmbedding(nn.Module):
"""MLP for timestep embedding."""
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
self.act = nn.SiLU()
def __call__(self, sample: mx.array) -> mx.array:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class PixArtAlphaTimestepEmbedder(nn.Module):
"""Combined timestep embedding (sinusoidal + MLP)."""
def __init__(self, embedding_dim: int):
super().__init__()
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim
)
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
timesteps_proj = get_timestep_embedding(
timestep,
embedding_dim=256,
flip_sin_to_cos=True,
downscale_freq_shift=0
)
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
return timesteps_emb
class ResnetBlock3DSimple(nn.Module):
"""ResNet block with optional timestep conditioning.
Weight keys: conv1.conv, conv2.conv, scale_shift_table
"""
def __init__(
self,
channels: int,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = False,
):
super().__init__()
self.timestep_conditioning = timestep_conditioning
# Nested conv structure to match PyTorch naming: conv1.conv.weight
self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
self.act = nn.SiLU()
# Scale-shift table for timestep conditioning: [shift1, scale1, shift2, scale2]
if timestep_conditioning:
self.scale_shift_table = mx.zeros((4, channels))
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_ch,
out_channels=out_ch,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
def __call__(
self,
x: mx.array,
causal: bool = False,
timestep_embed: Optional[mx.array] = None,
) -> mx.array:
residual = x
batch_size = x.shape[0]
# Block 1 with optional timestep conditioning
x = self.pixel_norm(x)
if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
ada_values = ada_values + ts_reshaped
shift1 = ada_values[:, 0] # (B, C, 1, 1, 1)
scale1 = ada_values[:, 1]
shift2 = ada_values[:, 2]
scale2 = ada_values[:, 3]
x = x * (1 + scale1) + shift1
x = self.act(x)
x = self.conv1(x, causal=causal)
# Block 2 with optional timestep conditioning
x = self.pixel_norm(x)
if self.timestep_conditioning and timestep_embed is not None:
x = x * (1 + scale2) + shift2
x = self.act(x)
x = self.conv2(x, causal=causal)
return x + residual
class ResBlockGroup(nn.Module):
"""Group of ResNet blocks with shared timestep embedding.
PyTorch naming: res_blocks.0, res_blocks.1, etc.
"""
def __init__(
self,
channels: int,
num_layers: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = False,
):
super().__init__()
self.timestep_conditioning = timestep_conditioning
# Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=channels * 4
)
self.res_blocks = [
ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
)
for _ in range(num_layers)
]
def __call__(
self,
x: mx.array,
causal: bool = False,
timestep: Optional[mx.array] = None,
) -> mx.array:
timestep_embed = None
if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0]
timestep_embed = self.time_embedder(
timestep.flatten(),
hidden_dtype=x.dtype
)
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
for res_block in self.res_blocks:
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
return x
class LTX2VideoDecoder(nn.Module):
"""LTX-2 Video VAE Decoder with timestep conditioning.
Architecture:
- conv_in: 128 -> 1024
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
- up_blocks.1: Upsampler 1024 -> 512
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
- up_blocks.3: Upsampler 512 -> 256
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
- up_blocks.5: Upsampler 256 -> 128
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
"""
def __init__(
self,
in_channels: int = 128,
out_channels: int = 3,
patch_size: int = 4,
num_layers_per_block: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.timestep_conditioning = timestep_conditioning
# Decode parameters (configurable via constructor)
self.decode_noise_scale = 0.025 # Set to 0.0 to disable noise
self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights)
self.latents_mean = mx.zeros((in_channels,))
self.latents_std = mx.ones((in_channels,))
# Initial conv: 128 -> 1024
class ConvInWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_channels,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
self.up_blocks = [
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
dims=3,
in_channels=1024,
stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config!
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
dims=3,
in_channels=512,
stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config!
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config!
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
]
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=128,
out_channels=final_out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper()
self.act = nn.SiLU()
if timestep_conditioning:
self.timestep_scale_multiplier = mx.array(1000.0)
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=128 * 2 # 256, matches (2, 128) table
)
self.last_scale_shift_table = mx.zeros((2, 128))
def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics."""
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
std = self.latents_std.reshape(1, -1, 1, 1, 1)
return x * std + mean
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
def __call__(
self,
sample: mx.array,
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
) -> mx.array:
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
batch_size = sample.shape[0]
if debug:
debug_stats("Input", sample)
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
if debug:
debug_stats("After noise", sample)
if debug:
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
sample = self.denormalize(sample)
if debug:
debug_stats("After denormalize", sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
scaled_timestep = None
if self.timestep_conditioning and timestep is not None:
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
if debug:
debug_stats("After conv_in", x)
for i, block in enumerate(self.up_blocks):
if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep)
else:
x = block(x, causal=causal)
if debug:
block_type = type(block).__name__
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
x = self.pixel_norm(x)
if debug:
debug_stats("After pixel_norm", x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(),
hidden_dtype=x.dtype
)
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
ada_values = ada_values + ts_reshaped
shift = ada_values[:, 0] # (B, 128, 1, 1, 1)
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
if debug:
debug_stats("After timestep modulation", x)
x = self.act(x)
if debug:
debug_stats("After activation", x)
x = self.conv_out(x, causal=causal)
if debug:
debug_stats("After conv_out", x)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
if debug:
debug_stats("After unpatchify", x)
return x
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
from pathlib import Path
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.decoder."
stats_prefix = "vae.per_channel_statistics."
elif has_decoder_prefix:
prefix = "decoder."
stats_prefix = ""
else:
prefix = ""
stats_prefix = ""
# Load per-channel statistics for denormalization
# Note: use std-of-means (not mean-of-stds) for proper denormalization
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
if mean_key in weights:
decoder.latents_mean = weights[mean_key]
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
if std_key in weights:
decoder.latents_std = weights[std_key]
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
# Build decoder weights dict with key remapping
decoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
decoder_weights[new_key] = value
print(f" Found {len(decoder_weights)} decoder weights")
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
print(f" Found {len(ts_keys)} timestep conditioning weights")
# Load weights
decoder.load_weights(list(decoder_weights.items()), strict=False)
print("VAE decoder loaded successfully")
return decoder

View File

@@ -0,0 +1,120 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert video to patches.
Moves spatial pixels from H, W dimensions to channel dimension.
Args:
x: Input tensor of shape (B, C, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
"""
b, c, f, h, w = x.shape
# Check dimensions are divisible
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
assert f % patch_size_t == 0
# New dimensions
new_h = h // patch_size_hw
new_w = w // patch_size_hw
new_f = f // patch_size_t
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
return x
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert patches back to video.
Inverse of patchify - moves pixels from channel dimension back to spatial.
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
Args:
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
"""
b, c_packed, f, h, w = x.shape
# Calculate original channel count
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
# Permute to interleave patches with spatial dims:
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
return x
class PerChannelStatistics(nn.Module):
def __init__(self, latent_channels: int = 128):
super().__init__()
self.latent_channels = latent_channels
# Learnable per-channel mean and std
self.mean = mx.zeros((latent_channels,))
self.std = mx.ones((latent_channels,))
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latents using per-channel statistics.
Args:
x: Input tensor of shape (B, C, ...)
Returns:
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
return (x - mean) / std
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics.
Args:
x: Normalized tensor of shape (B, C, ...)
Returns:
Denormalized tensor
"""
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
return x * std + mean

View File

@@ -0,0 +1,171 @@
"""ResNet blocks for Video VAE."""
from enum import Enum
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.utils import PixelNorm
class NormLayerType(Enum):
GROUP_NORM = "group_norm"
PIXEL_NORM = "pixel_norm"
def get_norm_layer(
norm_type: NormLayerType,
num_channels: int,
num_groups: int = 32,
eps: float = 1e-6,
) -> nn.Module:
if norm_type == NormLayerType.GROUP_NORM:
return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
elif norm_type == NormLayerType.PIXEL_NORM:
return PixelNorm(eps=eps)
else:
raise ValueError(f"Unknown norm type: {norm_type}")
class ResnetBlock3D(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
out_channels: Optional[int] = None,
eps: float = 1e-6,
groups: int = 32,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.inject_noise = inject_noise
# First normalization and convolution
self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
self.conv1 = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
# Second normalization and convolution
self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
self.conv2 = CausalConv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
# Shortcut connection if channels change
if in_channels != out_channels:
self.shortcut = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
spatial_padding_mode=spatial_padding_mode,
)
else:
self.shortcut = None
# Activation
self.act = nn.SiLU()
def __call__(
self,
x: mx.array,
causal: bool = True,
generator: Optional[int] = None,
) -> mx.array:
residual = x
# First block
x = self.norm1(x)
x = self.act(x)
x = self.conv1(x, causal=causal)
# Inject noise if enabled
if self.inject_noise and generator is not None:
noise = mx.random.normal(x.shape)
x = x + noise * 0.01
# Second block
x = self.norm2(x)
x = self.act(x)
x = self.conv2(x, causal=causal)
# Shortcut
if self.shortcut is not None:
residual = self.shortcut(residual, causal=causal)
return x + residual
class UNetMidBlock3D(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
inject_noise: bool = False,
timestep_conditioning: bool = False,
attention_head_dim: Optional[int] = None,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.num_layers = num_layers
# Create ResNet blocks
self.resnets = [
ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
def __call__(
self,
x: mx.array,
causal: bool = True,
timestep: Optional[mx.array] = None,
generator: Optional[int] = None,
) -> mx.array:
for resnet in self.resnets:
x = resnet(x, causal=causal, generator=generator)
return x

View File

@@ -0,0 +1,173 @@
"""Sampling operations for Video VAE (upsampling/downsampling)."""
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
class SpaceToDepthDownsample(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]],
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
stride = (stride, stride, stride)
self.stride = stride
self.dims = dims
# Calculate the multiplier for channels
multiplier = stride[0] * stride[1] * stride[2]
intermediate_channels = in_channels * multiplier
# 1x1x1 convolution to adjust channels
self.conv = CausalConv3d(
in_channels=intermediate_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Pad if necessary to make dimensions divisible by stride
pad_d = (st - d % st) % st
pad_h = (sh - h % sh) % sh
pad_w = (sw - w % sw) % sw
if pad_d > 0 or pad_h > 0 or pad_w > 0:
# For causal, pad at the end of temporal dimension
if causal:
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
else:
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
b, c, d, h, w = x.shape
# Reshape to group spatial elements
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Permute to move stride elements to channel dim
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
# Apply 1x1 conv to adjust channels
x = self.conv(x, causal=causal)
return x
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
stride: Union[int, Tuple[int, int, int]],
residual: bool = False,
out_channels_reduction_factor: int = 1,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
stride = (stride, stride, stride)
self.stride = stride
self.dims = dims
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
# Calculate output channels
multiplier = stride[0] * stride[1] * stride[2]
out_channels = in_channels // out_channels_reduction_factor
self.out_channels = out_channels
# 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels * multiplier,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def _depth_to_space(self, x: mx.array) -> mx.array:
b, c_packed, d, h, w = x.shape
st, sh, sw = self.stride
c = c_packed // (st * sh * sw)
# (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
x = mx.reshape(x, (b, c, st, sh, sw, d, h, w))
# (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
# (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
x = mx.reshape(x, (b, c, d * st, h * sh, w * sw))
return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Compute residual path if enabled
x_residual = None
if self.residual:
# Reshape input: treat channels as spatial factors
# "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
x_residual = self._depth_to_space(x)
# Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
# num_repeat = prod(stride) / out_channels_reduction_factor
num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
# Remove first temporal frame if temporal upsampling
if st > 1:
x_residual = x_residual[:, :, 1:, :, :]
# Apply conv
x = self.conv(x, causal=causal)
# Depth to space rearrangement
x = self._depth_to_space(x)
# Remove first frame for causal temporal upsampling
if st > 1:
x = x[:, :, 1:, :, :]
# Add residual
if self.residual and x_residual is not None:
x = x + x_residual
return x

View File

@@ -0,0 +1,528 @@
"""Video VAE Encoder and Decoder for LTX-2."""
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx.video_vae.sampling import (
DepthToSpaceUpsample,
SpaceToDepthDownsample,
)
from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
"""Log variance mode for VAE."""
PER_CHANNEL = "per_channel"
UNIFORM = "uniform"
CONSTANT = "constant"
NONE = "none"
def _make_encoder_block(
block_name: str,
block_config: Dict[str, Any],
in_channels: int,
convolution_dimensions: int,
norm_layer: NormLayerType,
norm_num_groups: int,
spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
"""Create an encoder block.
Args:
block_name: Type of block
block_config: Block configuration
in_channels: Input channels
convolution_dimensions: Number of dimensions
norm_layer: Normalization layer type
norm_num_groups: Number of groups for group norm
spatial_padding_mode: Padding mode
Returns:
Tuple of (block, output_channels)
"""
out_channels = in_channels
if block_name == "res_x":
block = UNetMidBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
num_layers=block_config["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
out_channels = in_channels * block_config.get("multiplier", 2)
block = ResnetBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 1, 1),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(1, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
out_channels = in_channels * block_config.get("multiplier", 2)
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unknown encoder block: {block_name}")
return block, out_channels
def _make_decoder_block(
block_name: str,
block_config: Dict[str, Any],
in_channels: int,
convolution_dimensions: int,
norm_layer: NormLayerType,
timestep_conditioning: bool,
norm_num_groups: int,
spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
"""Create a decoder block."""
out_channels = in_channels
if block_name == "res_x":
block = UNetMidBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
num_layers=block_config["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_config.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
out_channels = in_channels // block_config.get("multiplier", 2)
block = ResnetBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_config.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
out_channels = in_channels // block_config.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 2, 2),
residual=block_config.get("residual", False),
out_channels_reduction_factor=block_config.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unknown decoder block: {block_name}")
return block, out_channels
class VideoEncoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 3,
out_channels: int = 128,
encoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize VideoEncoder.
Args:
convolution_dimensions: Number of dimensions (3 for video)
in_channels: Input channels (3 for RGB)
out_channels: Output latent channels
encoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
latent_log_var: Log variance mode
encoder_spatial_padding_mode: Padding mode
"""
super().__init__()
if encoder_blocks is None:
encoder_blocks = []
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
# After patchify, channels increase by patch_size^2
in_channels = in_channels * patch_size ** 2
feature_channels = out_channels
# Initial convolution
self.conv_in = CausalConv3d(
in_channels=in_channels,
out_channels=feature_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks
self.down_blocks = []
for block_name, block_params in encoder_blocks:
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks.append(block)
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
# Calculate output convolution channels
conv_out_channels = out_channels
if latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
in_channels=feature_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=encoder_spatial_padding_mode,
)
def __call__(self, sample: mx.array) -> mx.array:
"""Encode video to latent representation.
Args:
sample: Input video of shape (B, C, F, H, W).
F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)
Returns:
Normalized latent means of shape (B, 128, F', H', W')
"""
# Validate frame count
frames_count = sample.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError(
"Invalid number of frames: Encode input must have 1 + 8 * x frames "
f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
)
# Initial patchify
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks:
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else:
sample = down_block(sample, causal=True)
# Output processing
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=True)
# Handle log variance modes
if self.latent_log_var == LogVarianceType.UNIFORM:
means = sample[:, :-1, ...]
logvar = sample[:, -1:, ...]
num_channels = means.shape[1]
repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
sample = mx.concatenate([means, repeated_logvar], axis=1)
elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...]
approx_ln_0 = -30
sample = mx.concatenate([
sample,
mx.full_like(sample, approx_ln_0),
], axis=1)
# Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
class VideoDecoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 128,
out_channels: int = 3,
decoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
causal: bool = False,
timestep_conditioning: bool = False,
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
):
"""Initialize VideoDecoder.
Args:
convolution_dimensions: Number of dimensions
in_channels: Input latent channels
out_channels: Output channels (3 for RGB)
decoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
causal: Whether to use causal convolutions
timestep_conditioning: Whether to use timestep conditioning
decoder_spatial_padding_mode: Padding mode
"""
super().__init__()
if decoder_blocks is None:
decoder_blocks = []
self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
# Per-channel statistics for denormalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Noise and timestep parameters
self.decode_noise_scale = 0.025
self.decode_timestep = 0.05
# Compute initial feature channels
feature_channels = in_channels
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
feature_channels = feature_channels * block_config.get("multiplier", 2)
if block_name == "compress_all":
feature_channels = feature_channels * block_config.get("multiplier", 1)
# Initial convolution
self.conv_in = CausalConv3d(
in_channels=in_channels,
out_channels=feature_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
# Build decoder blocks (reversed order)
self.up_blocks = []
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_decoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
timestep_conditioning=timestep_conditioning,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks.append(block)
# Output normalization
if norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(
in_channels=feature_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
def __call__(
self,
sample: mx.array,
timestep: Optional[mx.array] = None,
) -> mx.array:
"""Decode latent to video.
Args:
sample: Latent tensor of shape (B, 128, F', H', W')
timestep: Optional timestep for conditioning
Returns:
Decoded video of shape (B, 3, F, H, W)
"""
batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
# Denormalize latents
sample = self.per_channel_statistics.un_normalize(sample)
# Use default timestep if not provided
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
# Initial convolution
sample = self.conv_in(sample, causal=self.causal)
# Process through decoder blocks
for up_block in self.up_blocks:
if isinstance(up_block, UNetMidBlock3D):
sample = up_block(sample, causal=self.causal)
elif isinstance(up_block, ResnetBlock3D):
sample = up_block(sample, causal=self.causal)
else:
sample = up_block(sample, causal=self.causal)
# Output processing
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
# Unpatchify to restore spatial resolution
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample

165
mlx_video/postprocess.py Normal file
View File

@@ -0,0 +1,165 @@
import numpy as np
from typing import Optional
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
Args:
image: Input image as uint8 numpy array (H, W, C)
d: Diameter of each pixel neighborhood
sigma_color: Filter sigma in the color space
sigma_space: Filter sigma in the coordinate space
Returns:
Filtered image
"""
try:
import cv2
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
except ImportError:
# Fallback to simple Gaussian blur if cv2 not available
return gaussian_blur(image, kernel_size=3)
def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
"""Apply Gaussian blur.
Args:
image: Input image as uint8 numpy array (H, W, C)
kernel_size: Size of the Gaussian kernel (must be odd)
Returns:
Blurred image
"""
try:
import cv2
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
except ImportError:
# Simple box blur fallback
from scipy.ndimage import uniform_filter
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
"""Apply unsharp masking to enhance edges after blur.
Args:
image: Input image as uint8 numpy array
kernel_size: Size of the Gaussian kernel
sigma: Gaussian sigma
amount: Strength of sharpening
Returns:
Sharpened image
"""
try:
import cv2
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
return np.clip(sharpened, 0, 255).astype(np.uint8)
except ImportError:
return image
def reduce_grid_artifacts(
video: np.ndarray,
method: str = "bilateral",
strength: float = 1.0,
) -> np.ndarray:
"""Reduce grid artifacts in video frames.
Args:
video: Video as numpy array (F, H, W, C) uint8
method: "bilateral", "gaussian", or "frequency"
strength: How strong to apply the filter (0-1)
Returns:
Processed video
"""
if method == "bilateral":
d = max(3, int(5 * strength))
sigma = 50 + 50 * strength
processed = np.stack([
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
for frame in video
])
elif method == "gaussian":
kernel_size = max(3, int(3 + 4 * strength))
if kernel_size % 2 == 0:
kernel_size += 1
processed = np.stack([
gaussian_blur(frame, kernel_size=kernel_size)
for frame in video
])
elif method == "frequency":
processed = np.stack([
remove_grid_frequency(frame, grid_size=8)
for frame in video
])
else:
raise ValueError(f"Unknown method: {method}")
# Optionally sharpen to recover some detail
if strength < 1.0:
# Blend with original based on strength
alpha = strength
processed = (alpha * processed + (1 - alpha) * video).astype(np.uint8)
return processed
def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
"""Remove grid-frequency components using FFT.
Args:
frame: Input frame (H, W, C) uint8
grid_size: Expected grid periodicity in pixels
Returns:
Filtered frame
"""
result = np.zeros_like(frame)
for c in range(frame.shape[2]):
channel = frame[:, :, c].astype(np.float32)
h, w = channel.shape
# FFT
fft = np.fft.fft2(channel)
fft_shifted = np.fft.fftshift(fft)
# Create notch filter at grid frequencies
cy, cx = h // 2, w // 2
mask = np.ones((h, w), dtype=np.float32)
# Attenuate frequencies at grid periodicity
freq_y = h // grid_size
freq_x = w // grid_size
for fy in range(-2, 3):
for fx in range(-2, 3):
if fy == 0 and fx == 0:
continue
y_pos = cy + fy * freq_y
x_pos = cx + fx * freq_x
if 0 <= y_pos < h and 0 <= x_pos < w:
# Gaussian attenuation around the frequency
for dy in range(-2, 3):
for dx in range(-2, 3):
yy, xx = y_pos + dy, x_pos + dx
if 0 <= yy < h and 0 <= xx < w:
dist = np.sqrt(dy**2 + dx**2)
mask[yy, xx] *= min(1.0, dist / 3.0)
# Apply mask and inverse FFT
fft_filtered = fft_shifted * mask
channel_filtered = np.fft.ifft2(np.fft.ifftshift(fft_filtered)).real
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
return result

View File

@@ -0,0 +1,26 @@
import mlx.core as mx
import mlx.nn as nn
class PixArtAlphaTextProjection(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear1(x)
x = self.act(x)
x = self.linear2(x)
return x

127
mlx_video/utils.py Normal file
View File

@@ -0,0 +1,127 @@
"""Utility functions for MLX Video."""
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from functools import partial
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
@partial(mx.compile, shapeless=True)
def to_denoised(
noisy: mx.array,
velocity: mx.array,
sigma: mx.array | float
) -> mx.array:
"""Convert velocity prediction to denoised output.
Given noisy input x_t and velocity prediction v, compute denoised x_0:
x_0 = x_t - sigma * v
Args:
noisy: Noisy input tensor x_t
velocity: Velocity prediction v
sigma: Noise level (scalar or per-sample)
Returns:
Denoised tensor x_0
"""
if isinstance(sigma, (int, float)):
return noisy - sigma * velocity
else:
# sigma is per-sample
while sigma.ndim < velocity.ndim:
sigma = mx.expand_dims(sigma, axis=-1)
return noisy - sigma * velocity
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array:
"""Repeat elements of tensor along an axis, similar to torch.repeat_interleave.
Args:
x: Input tensor
repeats: Number of repetitions for each element
axis: The axis along which to repeat values
Returns:
Tensor with repeated values
"""
# Handle negative axis
if axis < 0:
axis = x.ndim + axis
# Get shape
shape = list(x.shape)
# Expand dims, repeat, then reshape
x = mx.expand_dims(x, axis=axis + 1)
# Create tile pattern
tile_pattern = [1] * x.ndim
tile_pattern[axis + 1] = repeats
x = mx.tile(x, tile_pattern)
# Reshape to merge the repeated dimension
new_shape = shape.copy()
new_shape[axis] *= repeats
return mx.reshape(x, new_shape)
class PixelNorm(nn.Module):
def __init__(self, eps: float = 1e-6):
super().__init__()
self.eps = eps
def __call__(self, x: mx.array) -> mx.array:
return x / mx.sqrt(mx.mean(x * x, axis=1, keepdims=True) + self.eps)
def get_timestep_embedding(
timesteps: mx.array,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1.0,
scale: float = 1.0,
max_period: int = 10000,
) -> mx.array:
"""Create sinusoidal timestep embeddings.
Args:
timesteps: 1D tensor of timesteps
embedding_dim: Dimension of the embeddings to create
flip_sin_to_cos: If True, flip sin and cos ordering
downscale_freq_shift: Frequency shift factor
scale: Scale factor for timesteps
max_period: Maximum period for the sinusoids
Returns:
Tensor of shape (len(timesteps), embedding_dim)
"""
assert timesteps.ndim == 1, "Timesteps should be 1D"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = mx.exp(exponent)
emb = (timesteps[:, None].astype(mx.float32) * scale) * emb[None, :]
# Compute sin and cos embeddings
if flip_sin_to_cos:
emb = mx.concatenate([mx.cos(emb), mx.sin(emb)], axis=-1)
else:
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
# Zero pad if odd embedding dimension
if embedding_dim % 2 == 1:
emb = mx.pad(emb, [(0, 0), (0, 1)])
return emb