initial commit (LTX-2)

This commit is contained in:
Prince Canuma
2026-01-11 23:48:33 +01:00
parent 9f01d22750
commit d1ca36a315
29 changed files with 7124 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
.env
claude.md
.DS_Store
**.pyc
__pycache__/*

394
main.py Normal file
View File

@@ -0,0 +1,394 @@
import mlx.core as mx
import numpy as np
from pathlib import Path
from PIL import Image
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid
from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
# Paths
from huggingface_hub import snapshot_download
from pathlib import Path
import os
LTX2_REPO = "Lightricks/LTX-2"
def get_ltx2_cache_dir():
# Try to get local cache (local_only), will not download files
try:
ref_path = snapshot_download(
repo_id=LTX2_REPO,
local_files_only=True,
allow_patterns=["*"],
ignore_patterns=[],
# leave as default revision and cache_dir, only local
)
return ref_path
except Exception:
# If not present locally, download from hub
return snapshot_download(
repo_id=LTX2_REPO,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
ignore_patterns=[]
)
LTX2_PATH = Path(get_ltx2_cache_dir())
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
# Distilled sigma schedules (from PyTorch)
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
def denoise_loop(
latents: mx.array,
positions: mx.array,
text_embeddings: mx.array,
transformer: LTXModel,
sigma_schedule: list,
stage_name: str = "Stage",
negative_embeddings: mx.array = None,
cfg_scale: float = 1.0,
) -> mx.array:
"""Run denoising loop for given sigma schedule.
Args:
latents: Noisy latent tensor
positions: Position embeddings
text_embeddings: Positive prompt embeddings
transformer: The transformer model
sigma_schedule: List of sigma values for each step
stage_name: Name for logging
negative_embeddings: Negative prompt embeddings for CFG (optional)
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
"""
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
for i in range(len(sigma_schedule) - 1):
sigma = sigma_schedule[i]
sigma_next = sigma_schedule[i + 1]
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
timesteps = mx.full((1,), sigma)
# Positive (conditioned) prediction
video_modality = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=text_embeddings,
context_mask=None,
enabled=True,
)
vx_cond, _ = transformer(video=video_modality, audio=None)
mx.eval(vx_cond)
if use_cfg:
# Negative (unconditioned) prediction
video_modality_neg = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=negative_embeddings,
context_mask=None,
enabled=True,
)
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
mx.eval(vx_uncond)
# CFG: output = uncond + cfg_scale * (cond - uncond)
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
else:
vx = vx_cond
vx_reshaped = mx.transpose(vx, (0, 2, 1))
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
# Debug: Print velocity stats
vx_np = np.array(vx_reshaped)
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
# Get denoised prediction: x_0 = x_t - sigma * velocity
denoised = to_denoised(latents, vx_reshaped, sigma)
mx.eval(denoised)
# Debug: Print denoised stats
denoised_np = np.array(denoised)
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
if sigma_next > 0:
velocity = (latents - denoised) / sigma
latents = denoised + sigma_next * velocity
else:
latents = denoised
mx.eval(latents)
# Debug: Print latents after step
latents_np = np.array(latents)
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
return latents
def main():
print("="*60)
print("MLX LTX-2 Video Generation (Two-Stage)")
print("="*60)
# Config - same as PyTorch reference
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
negative_prompt = "" # PyTorch script doesn't use negative prompt
cfg_scale = 1.0 # No CFG in the distilled pipeline
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage
seed = 123
# Stage 1: Half resolution
stage1_height = height // 2
stage1_width = width // 2
stage1_latent_height = stage1_height // 32
stage1_latent_width = stage1_width // 32
latent_frames = 1 + (num_frames - 1) // 8
# Stage 2: Full resolution
latent_height = height // 32
latent_width = width // 32
print(f"\nConfig:")
print(f" Prompt: {prompt}")
print(f" Negative prompt: '{negative_prompt}'")
print(f" CFG scale: {cfg_scale}")
print(f" Final resolution: {width}x{height}, {num_frames} frames")
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
print(f" Seed: {seed}")
mx.random.seed(seed)
# Load text encoder
print("\nLoading text encoder...")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
text_encoder.load(str(LTX2_PATH))
mx.eval(text_encoder.parameters())
# Encode positive prompt
print("Encoding text...")
text_embeddings, attention_mask = text_encoder(prompt)
mx.eval(text_embeddings)
print(f" Positive embeddings: {text_embeddings.shape}")
# Encode negative prompt for CFG
negative_embeddings, _ = text_encoder(negative_prompt)
mx.eval(negative_embeddings)
print(f" Negative embeddings: {negative_embeddings.shape}")
# Free text encoder memory
del text_encoder
mx.clear_cache()
# Load transformer
print("\nLoading transformer...")
raw_weights = mx.load(MODEL_PATH)
sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
print(" Transformer loaded!")
# ========================================
# Stage 1: Generate at half resolution
# ========================================
print("\n" + "="*60)
print("Stage 1: Generating at half resolution")
print("="*60)
mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
mx.eval(latents)
print(f" Initial latents: {latents.shape}")
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
mx.eval(positions)
latents = denoise_loop(
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
stage_name="Stage 1",
negative_embeddings=negative_embeddings,
cfg_scale=cfg_scale,
)
print(f"\nStage 1 latents: {latents.shape}")
latents_np = np.array(latents)
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# ========================================
# Upsample latents 2x
# ========================================
print("\n" + "="*60)
print("Upsampling latents 2x")
print("="*60)
# Load upsampler
print(" Loading spatial upsampler...")
upsampler = load_upsampler(UPSAMPLER_PATH)
mx.eval(upsampler.parameters())
# Load latent statistics for normalization
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
# EXPERIMENT: Disable VAE decode noise for sharper output
# vae_decoder.decode_noise_scale = 0.0
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
latent_mean = vae_decoder.latents_mean
latent_std = vae_decoder.latents_std
# Upsample
print(" Upsampling...")
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
mx.eval(latents)
print(f" Upsampled latents: {latents.shape}")
# Free upsampler memory
del upsampler
mx.clear_cache()
# ========================================
# Stage 2: Refine at full resolution
# ========================================
print("\n" + "="*60)
print("Stage 2: Refining at full resolution")
print("="*60)
# Debug: Print upsampled latent stats before adding noise
latents_np = np.array(latents)
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Create new position grid for full resolution
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
mx.eval(positions)
# Add noise at initial sigma for stage 2
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
# NOT addition: noisy = clean + scale * noise
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
# Debug: Print latents after adding noise
latents_np = np.array(latents)
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
latents = denoise_loop(
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
stage_name="Stage 2",
)
print(f"\nFinal latents: {latents.shape}")
latents_np = np.array(latents)
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Save latents for PyTorch comparison
np.save("mlx_final_latents.npy", latents_np)
print(" Saved latents to mlx_final_latents.npy")
# Free transformer memory
del transformer
mx.clear_cache()
# ========================================
# Decode to video
# ========================================
print("\n" + "="*60)
print("Decoding with VAE")
print("="*60)
# Decode latents to video
video = vae_decoder(latents, debug=True)
mx.eval(video)
print(f" Video shape: {video.shape}")
# Convert to frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
# Debug: check raw RGB values before conversion
video_raw = np.array(video)
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1]
video = mx.clip(video, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")
# Save first frame
output_path = Path("mlx_output_frame0_2.png")
Image.fromarray(video_np[0]).save(output_path)
print(f"\nSaved first frame to {output_path}")
# Save video
try:
import imageio
video_path = "mlx_output_video_2.mp4"
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
print(f"Saved video to {video_path}")
except Exception as e:
print(f"Could not save video: {e}")
print("\nDone!")
if __name__ == "__main__":
import time
start_time = time.time()
main()
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

13
mlx_video/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.generate import LTXVideoPipeline, GenerationConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights
__all__ = [
"LTXModel",
"LTXModelConfig",
"LTXVideoPipeline",
"GenerationConfig",
"load_transformer_weights",
"load_vae_weights",
]

457
mlx_video/convert.py Normal file
View File

@@ -0,0 +1,457 @@
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx.ltx import LTXModel
def get_model_path(
path_or_hf_repo: str,
revision: Optional[str] = None,
) -> Path:
"""Get local path to model, downloading if necessary.
Args:
path_or_hf_repo: Local path or HuggingFace repo ID
revision: Git revision for HF repo
Returns:
Path to model directory
"""
model_path = Path(path_or_hf_repo)
if model_path.exists():
return model_path
# Download from HuggingFace
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.safetensors",
"*.json",
"config.json",
],
)
)
return model_path
def load_safetensors(path: Path) -> Dict[str, mx.array]:
"""Load weights from safetensors file(s) using MLX.
Args:
path: Path to model directory or single safetensors file
Returns:
Dictionary of weights
"""
weights = {}
if path.is_file():
# Single file - use mx.load directly (handles bfloat16)
return mx.load(str(path))
else:
# Directory - load all safetensors files
safetensor_files = list(path.glob("*.safetensors"))
for sf_path in safetensor_files:
file_weights = mx.load(str(sf_path))
weights.update(file_weights)
return weights
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load transformer weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of transformer weights
"""
# Try distilled model first, then dev
weight_files = [
model_path / "ltx-2-19b-distilled.safetensors",
model_path / "ltx-2-19b-dev.safetensors",
]
for weight_file in weight_files:
if weight_file.exists():
print(f"Loading transformer weights from {weight_file.name}...")
return mx.load(str(weight_file))
raise FileNotFoundError(f"No transformer weights found in {model_path}")
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load VAE weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of VAE weights
"""
vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
if vae_path.exists():
print(f"Loading VAE weights from {vae_path}...")
return mx.load(str(vae_path))
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for transformer
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model."):
continue
# Remove 'model.diffusion_model.' prefix
new_key = key.replace("model.diffusion_model.", "")
# Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering)
new_key = new_key.replace(".to_out.0.", ".to_out.")
# Handle feed-forward net naming
# PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar)
# MLX FeedForward: uses proj_in, proj_out
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
# Handle AdaLN naming - keep emb wrapper, just fix linear naming
# PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle caption projection (keep linear1/linear2 naming for compatibility)
# These are already mapped correctly in the sanitization
sanitized[new_key] = value
return sanitized
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for VAE
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip position_ids (not needed)
if "position_ids" in key:
continue
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize weight names from PyTorch format to MLX format.
Generic function that handles both transformer and VAE weights.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip position_ids (not needed)
if "position_ids" in key:
continue
# Handle transformer weights
if key.startswith("model.diffusion_model."):
new_key = key.replace("model.diffusion_model.", "")
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def load_config(model_path: Path) -> Dict[str, Any]:
"""Load model configuration.
Args:
model_path: Path to model directory
Returns:
Configuration dictionary
"""
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
return json.load(f)
# Return default config
return {}
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
"""Create model instance from configuration.
Args:
config: Configuration dictionary
Returns:
LTXModel instance
"""
# Map config to LTXModelConfig
model_config = LTXModelConfig(
model_type=LTXModelType.AudioVideo,
num_attention_heads=config.get("num_attention_heads", 32),
attention_head_dim=config.get("attention_head_dim", 128),
in_channels=config.get("in_channels", 128),
out_channels=config.get("out_channels", 128),
num_layers=config.get("num_layers", 48),
cross_attention_dim=config.get("cross_attention_dim", 4096),
caption_channels=config.get("caption_channels", 3840),
audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
audio_in_channels=config.get("audio_in_channels", 128),
audio_out_channels=config.get("audio_out_channels", 128),
audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
norm_eps=config.get("norm_eps", 1e-6),
)
return LTXModel(model_config)
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
dtype: Optional[str] = None,
quantize: bool = False,
q_bits: int = 4,
q_group_size: int = 64,
) -> Path:
"""Convert HuggingFace model to MLX format.
Args:
hf_path: HuggingFace model path or repo ID
mlx_path: Output path for MLX model
dtype: Target dtype (float16, float32, bfloat16)
quantize: Whether to quantize the model
q_bits: Quantization bits
q_group_size: Quantization group size
Returns:
Path to converted model
"""
print(f"Loading model from {hf_path}...")
model_path = get_model_path(hf_path)
# Load config
config = load_config(model_path)
# Load weights
print("Loading weights...")
weights = load_safetensors(model_path)
# Sanitize weights
print("Sanitizing weights...")
weights = sanitize_weights(weights)
# Convert dtype if specified
if dtype is not None:
dtype_map = {
"float16": mx.float16,
"float32": mx.float32,
"bfloat16": mx.bfloat16,
}
target_dtype = dtype_map.get(dtype, mx.float16)
print(f"Converting to {dtype}...")
weights = {
k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v
for k, v in weights.items()
}
# Create output directory
output_path = Path(mlx_path)
output_path.mkdir(parents=True, exist_ok=True)
# Save weights
print(f"Saving weights to {output_path}...")
save_weights(output_path, weights)
# Save config
config_out_path = output_path / "config.json"
with open(config_out_path, "w") as f:
json.dump(config, f, indent=2)
print(f"Model converted successfully to {output_path}")
return output_path
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
"""Save weights in safetensors format.
Args:
path: Output directory
weights: Dictionary of weights
"""
from safetensors.numpy import save_file
import numpy as np
# Convert to numpy for safetensors
np_weights = {k: np.array(v) for k, v in weights.items()}
# Save to file
save_file(np_weights, path / "model.safetensors")
def load_model(
path_or_hf_repo: str,
lazy: bool = False,
) -> LTXModel:
"""Load LTX model from path or HuggingFace.
Args:
path_or_hf_repo: Path to model or HuggingFace repo ID
lazy: Whether to use lazy loading
Returns:
Loaded LTXModel
"""
model_path = get_model_path(path_or_hf_repo)
# Load config
config = load_config(model_path)
# Create model
model = create_model_from_config(config)
# Load weights
weights = load_safetensors(model_path)
# Sanitize if needed
weights = sanitize_weights(weights)
# Load weights into model
model.load_weights(list(weights.items()))
if not lazy:
mx.eval(model.parameters())
return model
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
parser.add_argument(
"--hf-path",
type=str,
default="Lightricks/LTX-2",
help="HuggingFace model path or repo ID",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Output path for MLX model",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float16", "float32", "bfloat16"],
default="float16",
help="Target dtype",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize the model",
)
parser.add_argument(
"--q-bits",
type=int,
default=4,
help="Quantization bits",
)
args = parser.parse_args()
convert(
hf_path=args.hf_path,
mlx_path=args.mlx_path,
dtype=args.dtype,
quantize=args.quantize,
q_bits=args.q_bits,
)

586
mlx_video/generate.py Normal file
View File

@@ -0,0 +1,586 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Iterator, Union
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.ltx import LTXModel, X0Model
from mlx_video.models.ltx.transformer import Modality
from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder
@dataclass
class GenerationConfig:
"""Configuration for video generation."""
# Video dimensions
height: int = 512
width: int = 512
num_frames: int = 33 # Must be 1 + 8*k
# Diffusion parameters
num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True)
guidance_scale: float = 3.0
use_distilled: bool = True # Use hardcoded sigma values for distilled model
# Latent dimensions (computed from video dimensions)
@property
def latent_height(self) -> int:
return self.height // 32
@property
def latent_width(self) -> int:
return self.width // 32
@property
def latent_frames(self) -> int:
return 1 + (self.num_frames - 1) // 8
# Hardcoded sigma values for distilled model (from LTX-2 pipeline)
# These were tuned to match the distillation process
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
# Scheduler constants for dynamic sigma computation (non-distilled models)
BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096
def get_sigmas(
num_steps: int,
num_tokens: int,
max_shift: float = 2.05,
base_shift: float = 0.95,
stretch: bool = True,
terminal: float = 0.1,
use_distilled: bool = True,
) -> mx.array:
"""Get sigma schedule for diffusion.
Args:
num_steps: Number of diffusion steps
num_tokens: Number of latent tokens (T * H * W)
max_shift: Maximum shift for sigma schedule
base_shift: Base shift for sigma schedule
stretch: Whether to stretch sigmas to terminal value
terminal: Terminal value for stretching
use_distilled: If True, use hardcoded distilled sigma values
Returns:
Array of sigma values
"""
import math
# For distilled model, use hardcoded sigma values
if use_distilled:
return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)
# For non-distilled models, compute dynamically using LTX2Scheduler logic
# Linear base schedule
sigmas = mx.linspace(1.0, 0.0, num_steps + 1)
# Compute token-dependent sigma shift
x1 = BASE_SHIFT_ANCHOR
x2 = MAX_SHIFT_ANCHOR
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = num_tokens * mm + b
# Apply exponential transformation
# sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1)
power = 1
exp_shift = math.exp(sigma_shift)
# Convert to numpy for computation then back to mx
sigmas_np = np.array(sigmas)
result = np.zeros_like(sigmas_np)
non_zero = sigmas_np != 0
result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power)
# Stretch sigmas so final value matches terminal
if stretch:
non_zero_mask = result != 0
non_zero_sigmas = result[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
result[non_zero_mask] = stretched
return mx.array(result, dtype=mx.float32)
def create_position_grid(
batch_size: int,
num_frames: int,
height: int,
width: int,
temporal_scale: int = 8,
spatial_scale: int = 32,
fps: float = 24.0,
causal_fix: bool = True,
) -> mx.array:
"""Create position grid for RoPE in pixel space.
Args:
batch_size: Batch size
num_frames: Number of frames (latent)
height: Height (latent)
width: Width (latent)
temporal_scale: VAE temporal scale factor (default 8)
spatial_scale: VAE spatial scale factor (default 32)
fps: Frames per second (default 24.0)
causal_fix: Apply causal fix for first frame (default True)
Returns:
Position grid of shape (B, 3, num_patches, 2) in pixel space
where dim 2 is [start, end) bounds for each patch
"""
# Patch size is (1, 1, 1) for LTX-2 - no spatial patching
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
# Generate grid coordinates for each dimension (frame, height, width)
# These are the starting coordinates for each patch in latent space
t_coords = np.arange(0, num_frames, patch_size_t)
h_coords = np.arange(0, height, patch_size_h)
w_coords = np.arange(0, width, patch_size_w)
# Create meshgrid with indexing='ij' for (frame, height, width) order
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
# Stack to get shape (3, grid_t, grid_h, grid_w)
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
# Calculate end coordinates (start + patch_size)
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
patch_ends = patch_starts + patch_size_delta
# Stack start and end: shape (3, grid_t, grid_h, grid_w, 2)
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
# Flatten spatial/temporal dims: (3, num_patches, 2)
num_patches = num_frames * height * width
latent_coords = latent_coords.reshape(3, num_patches, 2)
# Broadcast to batch: (batch, 3, num_patches, 2)
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
# Convert latent coords to pixel coords by scaling with VAE factors
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
# Apply causal fix for first frame temporal axis
if causal_fix:
# VAE temporal stride for first frame is 1 instead of temporal_scale
# Shift and clamp to keep first-frame timestamps non-negative
pixel_coords[:, 0, :, :] = np.clip(
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
a_min=0,
a_max=None
)
# Convert temporal to time in seconds by dividing by fps
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
return mx.array(pixel_coords, dtype=mx.float32)
class LTXVideoPipeline:
def __init__(
self,
transformer: LTXModel,
text_encoder: Optional[LTX2TextEncoder] = None,
tokenizer: Optional[any] = None,
vae_encoder: Optional[VideoEncoder] = None,
vae_decoder: Optional[VideoDecoder] = None,
):
"""Initialize pipeline.
Args:
transformer: LTX transformer model
text_encoder: Optional LTX text encoder
tokenizer: Optional tokenizer for text encoding
vae_encoder: Optional VAE encoder
vae_decoder: Optional VAE decoder
"""
self.transformer = transformer
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.vae_encoder = vae_encoder
self.vae_decoder = vae_decoder
self.x0_model = X0Model(transformer)
def prepare_latents(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
dtype: mx.Dtype = mx.float16,
) -> mx.array:
"""Prepare initial noise latents.
Args:
batch_size: Batch size
num_frames: Number of latent frames
height: Latent height
width: Latent width
dtype: Data type
Returns:
Random latent noise
"""
# Use in_channels from transformer config
in_channels = self.transformer.config.in_channels
shape = (batch_size, in_channels, num_frames, height, width)
latents = mx.random.normal(shape).astype(dtype)
return latents
def prepare_text_embeddings(
self,
prompt: Union[str, List[str]],
batch_size: int,
max_length: int = 1024,
) -> Tuple[mx.array, Optional[mx.array]]:
"""Prepare text embeddings.
Args:
prompt: Text prompt or list of prompts
batch_size: Batch size
max_length: Maximum sequence length for tokenization
Returns:
Tuple of (text_embeddings, attention_mask)
"""
# If text encoder is available, use it
if self.text_encoder is not None and self.tokenizer is not None:
# Handle single or multiple prompts
if isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = prompt
# Tokenize
tokens = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
input_ids = mx.array(tokens["input_ids"])
attention_mask = mx.array(tokens["attention_mask"])
# Encode
embeddings = self.text_encoder(input_ids, attention_mask)
mx.eval(embeddings)
return embeddings, None # Connector handles masking internally
# Fallback: random embeddings (for testing without text encoder)
print("Warning: No text encoder provided, using random embeddings")
seq_len = max_length + 128 # Account for learnable registers
embed_dim = self.transformer.config.caption_channels
embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
mask = mx.ones((batch_size, seq_len))
return embeddings, mask
def denoise_step(
self,
latents: mx.array,
sigma: float,
sigma_next: float,
text_embeddings: mx.array,
positions: mx.array,
text_mask: Optional[mx.array] = None,
) -> mx.array:
"""Perform one denoising step.
Args:
latents: Current noisy latents
sigma: Current noise level
sigma_next: Next noise level
text_embeddings: Text conditioning
positions: Position grid for RoPE
text_mask: Optional attention mask for text
Returns:
Denoised latents
"""
batch_size = latents.shape[0]
# Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
# Create timestep tensor
timesteps = mx.full((batch_size,), sigma)
# Create video modality input
video_modality = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=text_embeddings,
context_mask=text_mask,
enabled=True,
)
# Run denoising
denoised_video, _ = self.x0_model(video=video_modality, audio=None)
# Reshape back: (B, F*H*W, C) -> (B, C, F, H, W)
denoised_video = mx.transpose(denoised_video, (0, 2, 1))
denoised_video = mx.reshape(denoised_video, (b, c, f, h, w))
# Euler step
if sigma_next > 0:
# x_next = x0 + sigma_next * (x - x0) / sigma
noise = (latents - denoised_video) / sigma
latents = denoised_video + sigma_next * noise
else:
latents = denoised_video
return latents
def __call__(
self,
prompt: str,
config: Optional[GenerationConfig] = None,
seed: Optional[int] = None,
) -> mx.array:
"""Generate video from text prompt.
Args:
prompt: Text prompt
config: Generation configuration
seed: Random seed
Returns:
Generated video tensor of shape (B, C, F, H, W)
"""
if config is None:
config = GenerationConfig()
if seed is not None:
mx.random.seed(seed)
batch_size = 1
# Prepare text embeddings
text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)
# Prepare initial latents
latents = self.prepare_latents(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Prepare position grid
positions = create_position_grid(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Get sigma schedule
num_tokens = config.latent_frames * config.latent_height * config.latent_width
sigmas = get_sigmas(
config.num_inference_steps,
num_tokens,
use_distilled=config.use_distilled,
)
# Denoising loop
for i in range(len(sigmas) - 1):
sigma = float(sigmas[i])
sigma_next = float(sigmas[i + 1])
latents = self.denoise_step(
latents=latents,
sigma=sigma,
sigma_next=sigma_next,
text_embeddings=text_embeddings,
positions=positions,
text_mask=text_mask,
)
mx.eval(latents)
# Decode latents to video
if self.vae_decoder is not None:
video = self.vae_decoder(latents)
else:
video = latents
return video
def generate_video(
prompt: str,
transformer: LTXModel,
text_encoder: Optional[LTX2TextEncoder] = None,
tokenizer: Optional[any] = None,
vae_decoder: Optional[VideoDecoder] = None,
config: Optional[GenerationConfig] = None,
seed: Optional[int] = None,
) -> mx.array:
"""Generate video from text prompt.
Args:
prompt: Text prompt
transformer: LTX transformer model
text_encoder: Optional text encoder
tokenizer: Optional tokenizer
vae_decoder: Optional VAE decoder
config: Generation configuration
seed: Random seed
Returns:
Generated video tensor
"""
pipeline = LTXVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
)
return pipeline(prompt, config, seed)
def load_pipeline(
model_path: str,
text_encoder_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
load_text_encoder_weights: bool = True,
) -> LTXVideoPipeline:
"""Load complete LTX-2 video generation pipeline.
Args:
model_path: Path to LTX-2 model weights (safetensors)
text_encoder_path: Path to text encoder weights directory
tokenizer_path: Path to tokenizer directory
load_text_encoder_weights: Whether to load text encoder weights
Returns:
Configured LTXVideoPipeline
"""
from transformers import AutoTokenizer
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.convert import sanitize_transformer_weights
print("Loading LTX-2 pipeline...")
# Load transformer
print(" Loading transformer...")
raw_weights = mx.load(model_path)
sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
print(" Transformer loaded")
# Load VAE decoder
print(" Loading VAE decoder...")
vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True)
print(" VAE decoder loaded")
# Load text encoder if paths provided
text_encoder = None
tokenizer = None
if load_text_encoder_weights and text_encoder_path is not None:
print(" Loading text encoder...")
text_encoder = load_text_encoder(model_path, text_encoder_path)
print(" Text encoder loaded")
if tokenizer_path is not None:
print(" Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(" Tokenizer loaded")
print("Pipeline ready!")
return LTXVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
)
def video_to_numpy(video: mx.array) -> np.ndarray:
"""Convert video tensor to numpy array.
Args:
video: Video tensor of shape (B, C, F, H, W) in range [-1, 1]
Returns:
Numpy array of shape (B, F, H, W, C) in range [0, 255]
"""
# Clamp to [-1, 1]
video = mx.clip(video, -1.0, 1.0)
# Scale to [0, 255]
video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8)
# Rearrange: (B, C, F, H, W) -> (B, F, H, W, C)
video = mx.transpose(video, (0, 2, 3, 4, 1))
return np.array(video)
if __name__ == "__main__":
# Example usage
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
# Create a small test config
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_layers=2, # Reduced for testing
num_attention_heads=4,
attention_head_dim=32,
)
# Create model
model = LTXModel(config)
# Generate video
gen_config = GenerationConfig(
height=256,
width=256,
num_frames=9,
num_inference_steps=4,
)
print("Testing generation pipeline...")
pipeline = LTXVideoPipeline(transformer=model)
# This would require proper text embeddings in practice
# video = pipeline("A cat walking", gen_config, seed=42)
# print(f"Generated video shape: {video.shape}")
print("Pipeline initialized successfully!")

View File

@@ -0,0 +1,2 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig

View File

@@ -0,0 +1,7 @@
from mlx_video.models.ltx.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx.ltx import LTXModel, X0Model

View File

@@ -0,0 +1,161 @@
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module):
def __init__(
self,
embedding_dim: int,
embedding_coefficient: int = 6,
use_additional_conditions: bool = False,
):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim=embedding_dim,
size_emb_dim=0 if not use_additional_conditions else embedding_dim // 3,
use_additional_conditions=use_additional_conditions,
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
def __call__(
self,
timestep: mx.array,
added_cond_kwargs: dict | None = None,
batch_size: int | None = None,
hidden_dtype: mx.Dtype | None = None,
) -> Tuple[mx.array, mx.array]:
added_cond_kwargs = added_cond_kwargs or {}
embedded_timestep = self.emb(
timestep,
batch_size=batch_size,
hidden_dtype=hidden_dtype,
**added_cond_kwargs,
)
scale_shift_params = self.linear(self.silu(embedded_timestep))
return scale_shift_params, embedded_timestep
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(
self,
embedding_dim: int,
size_emb_dim: int = 0,
use_additional_conditions: bool = False,
timestep_proj_dim: int = 256,
):
super().__init__()
self.embedding_dim = embedding_dim
self.size_emb_dim = size_emb_dim
self.use_additional_conditions = use_additional_conditions
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
if use_additional_conditions and size_emb_dim > 0:
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
def __call__(
self,
timestep: mx.array,
resolution: mx.array | None = None,
aspect_ratio: mx.array | None = None,
batch_size: int | None = None,
hidden_dtype: mx.Dtype | None = None,
) -> mx.array:
# Project timestep
timesteps_proj = self.time_proj(timestep)
if hidden_dtype is not None:
timesteps_proj = timesteps_proj.astype(hidden_dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj)
# Add additional conditions if enabled
if self.use_additional_conditions and self.size_emb_dim > 0:
if resolution is not None and aspect_ratio is not None:
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
timesteps_emb = timesteps_emb + additional_embeds
return timesteps_emb
class Timesteps(nn.Module):
def __init__(
self,
num_channels: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1.0,
):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def __call__(self, timesteps: mx.array) -> mx.array:
return get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int | None = None,
):
super().__init__()
out_dim = out_dim or time_embed_dim
self.linear1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
self.linear2 = nn.Linear(time_embed_dim, out_dim)
def __call__(self, sample: mx.array) -> mx.array:
sample = self.linear1(sample)
sample = self.act(sample)
sample = self.linear2(sample)
return sample
class ConditionEmbedding(nn.Module):
def __init__(self, size_emb_dim: int, embedding_dim: int):
super().__init__()
self.resolution_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
self.aspect_ratio_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
def __call__(
self,
resolution: mx.array,
aspect_ratio: mx.array,
hidden_dtype: mx.Dtype | None = None,
) -> mx.array:
resolution_emb = self.resolution_embedder(resolution)
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio)
if hidden_dtype is not None:
resolution_emb = resolution_emb.astype(hidden_dtype)
aspect_ratio_emb = aspect_ratio_emb.astype(hidden_dtype)
return resolution_emb + aspect_ratio_emb

View File

@@ -0,0 +1,142 @@
"""Attention module for LTX-2."""
import math
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType
from mlx_video.models.ltx.rope import apply_rotary_emb
def scaled_dot_product_attention(
q: mx.array,
k: mx.array,
v: mx.array,
heads: int,
mask: Optional[mx.array] = None,
) -> mx.array:
b, q_seq_len, dim = q.shape
_, kv_seq_len, _ = k.shape
dim_head = dim // heads
# Reshape to (B, seq_len, heads, dim_head)
q = mx.reshape(q, (b, q_seq_len, heads, dim_head))
k = mx.reshape(k, (b, kv_seq_len, heads, dim_head))
v = mx.reshape(v, (b, kv_seq_len, heads, dim_head))
# Transpose to (B, heads, seq_len, dim_head)
q = mx.swapaxes(q, 1, 2)
k = mx.swapaxes(k, 1, 2)
v = mx.swapaxes(v, 1, 2)
# Handle mask dimensions
if mask is not None:
# Add batch dimension if needed
if mask.ndim == 2:
mask = mx.expand_dims(mask, axis=0)
# Add heads dimension if needed
if mask.ndim == 3:
mask = mx.expand_dims(mask, axis=1)
# Compute scaled dot-product attention
scale = 1.0 / math.sqrt(dim_head)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
# Reshape back to (B, q_seq_len, heads * dim_head)
out = mx.swapaxes(out, 1, 2)
out = mx.reshape(out, (b, q_seq_len, heads * dim_head))
return out
class Attention(nn.Module):
"""Multi-head attention with rotary position embeddings.
Supports both self-attention and cross-attention.
"""
def __init__(
self,
query_dim: int,
context_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
):
"""Initialize attention module.
Args:
query_dim: Dimension of query input
context_dim: Dimension of context (key/value) input. If None, same as query_dim
heads: Number of attention heads
dim_head: Dimension per head
norm_eps: Epsilon for RMS normalization
rope_type: Type of rotary position embedding
"""
super().__init__()
self.rope_type = rope_type
self.heads = heads
self.dim_head = dim_head
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
# Q, K, V projections
self.to_q = nn.Linear(query_dim, inner_dim, bias=True)
self.to_k = nn.Linear(context_dim, inner_dim, bias=True)
self.to_v = nn.Linear(context_dim, inner_dim, bias=True)
# Q and K normalization
self.q_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
self.k_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
# Output projection
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
def __call__(
self,
x: mx.array,
context: Optional[mx.array] = None,
mask: Optional[mx.array] = None,
pe: Optional[Tuple[mx.array, mx.array]] = None,
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
"""Forward pass.
Args:
x: Query input of shape (B, seq_len, query_dim)
context: Context for cross-attention. If None, uses x (self-attention)
mask: Attention mask
pe: Position embeddings for query (and key if k_pe is None)
k_pe: Position embeddings for key (optional, uses pe if None)
Returns:
Attention output of shape (B, seq_len, query_dim)
"""
# Compute Q, K, V
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
# Apply normalization
q = self.q_norm(q)
k = self.k_norm(k)
# Apply rotary position embeddings
if pe is not None:
q = apply_rotary_emb(q, pe, self.rope_type)
k_pe_to_use = pe if k_pe is None else k_pe
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
# Compute attention
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
# Project output
return self.to_out(out)

View File

@@ -0,0 +1,181 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional
class LTXModelType(Enum):
AudioVideo = "ltx av model"
VideoOnly = "ltx video only model"
AudioOnly = "ltx audio only model"
def is_video_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
def is_audio_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
class LTXRopeType(Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
TWO_D = "2d"
class AttentionType(Enum):
DEFAULT = "default"
@dataclass
class BaseModelConfig:
@classmethod
def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
"""Create config from dictionary, filtering only valid parameters."""
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def to_dict(self) -> dict[str, Any]:
"""Export config to dictionary."""
result = {}
for k, v in self.__dict__.items():
if v is not None:
if isinstance(v, Enum):
result[k] = v.value
elif hasattr(v, 'to_dict'):
result[k] = v.to_dict()
else:
result[k] = v
return result
@dataclass
class TransformerConfig(BaseModelConfig):
dim: int
heads: int
d_head: int
context_dim: int
@dataclass
class VideoVAEConfig(BaseModelConfig):
convolution_dimensions: int = 3
in_channels: int = 3
out_channels: int = 128
latent_channels: int = 128
patch_size: int = 4
encoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
])
decoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
])
@dataclass
class LTXModelConfig(BaseModelConfig):
# Model type
model_type: LTXModelType = LTXModelType.AudioVideo
# Video transformer config
num_attention_heads: int = 32
attention_head_dim: int = 128
in_channels: int = 128
out_channels: int = 128
num_layers: int = 48
cross_attention_dim: int = 4096
caption_channels: int = 3840
# Audio transformer config
audio_num_attention_heads: int = 32
audio_attention_head_dim: int = 64
audio_in_channels: int = 128
audio_out_channels: int = 128
audio_cross_attention_dim: int = 2048
# Positional embedding config
positional_embedding_theta: float = 10000.0
positional_embedding_max_pos: Optional[List[int]] = None
audio_positional_embedding_max_pos: Optional[List[int]] = None
use_middle_indices_grid: bool = True
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
double_precision_rope: bool = False
# Timestep config
timestep_scale_multiplier: int = 1000
av_ca_timestep_scale_multiplier: int = 1
# Normalization
norm_eps: float = 1e-6
# Attention type
attention_type: AttentionType = AttentionType.DEFAULT
# VAE config
vae_config: Optional[VideoVAEConfig] = None
def __post_init__(self):
"""Set default values after initialization."""
if self.positional_embedding_max_pos is None:
self.positional_embedding_max_pos = [20, 2048, 2048]
if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [20]
# Convert string enum values if loading from dict
if isinstance(self.model_type, str):
self.model_type = LTXModelType(self.model_type)
if isinstance(self.rope_type, str):
self.rope_type = LTXRopeType(self.rope_type)
if isinstance(self.attention_type, str):
self.attention_type = AttentionType(self.attention_type)
@property
def inner_dim(self) -> int:
"""Video inner dimension."""
return self.num_attention_heads * self.attention_head_dim
@property
def audio_inner_dim(self) -> int:
"""Audio inner dimension."""
return self.audio_num_attention_heads * self.audio_attention_head_dim
def get_video_config(self) -> Optional[TransformerConfig]:
"""Get video transformer configuration."""
if not self.model_type.is_video_enabled():
return None
return TransformerConfig(
dim=self.inner_dim,
heads=self.num_attention_heads,
d_head=self.attention_head_dim,
context_dim=self.cross_attention_dim,
)
def get_audio_config(self) -> Optional[TransformerConfig]:
"""Get audio transformer configuration."""
if not self.model_type.is_audio_enabled():
return None
return TransformerConfig(
dim=self.audio_inner_dim,
heads=self.audio_num_attention_heads,
d_head=self.audio_attention_head_dim,
context_dim=self.audio_cross_attention_dim,
)

View File

@@ -0,0 +1,40 @@
import mlx.core as mx
import mlx.nn as nn
class GELU(nn.Module):
def __init__(self, approximate: str = "tanh"):
super().__init__()
self.approximate = approximate
def __call__(self, x: mx.array) -> mx.array:
if self.approximate == "tanh":
return nn.gelu_approx(x)
else:
return nn.gelu(x)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: int | None = None,
mult: int = 4,
bias: bool = True,
):
super().__init__()
dim_out = dim_out or dim
inner_dim = int(dim * mult)
self.proj_in = nn.Linear(dim, inner_dim, bias=bias)
self.act = GELU(approximate="tanh")
self.proj_out = nn.Linear(inner_dim, dim_out, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.proj_in(x)
x = self.act(x)
x = self.proj_out(x)
return x

518
mlx_video/models/ltx/ltx.py Normal file
View File

@@ -0,0 +1,518 @@
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import (
LTXModelConfig,
LTXModelType,
LTXRopeType,
TransformerConfig,
)
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
from mlx_video.models.ltx.rope import precompute_freqs_cis
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx.transformer import (
BasicAVTransformerBlock,
Modality,
TransformerArgs,
)
from mlx_video.utils import to_denoised
class TransformerArgsPreprocessor:
def __init__(
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
inner_dim: int,
max_pos: List[int],
num_attention_heads: int,
use_middle_indices_grid: bool,
timestep_scale_multiplier: int,
positional_embedding_theta: float,
rope_type: LTXRopeType,
double_precision_rope: bool = False,
):
self.patchify_proj = patchify_proj
self.adaln = adaln
self.caption_projection = caption_projection
self.inner_dim = inner_dim
self.max_pos = max_pos
self.num_attention_heads = num_attention_heads
self.use_middle_indices_grid = use_middle_indices_grid
self.timestep_scale_multiplier = timestep_scale_multiplier
self.positional_embedding_theta = positional_embedding_theta
self.rope_type = rope_type
self.double_precision_rope = double_precision_rope
def _prepare_timestep(
self,
timestep: mx.array,
batch_size: int,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
# Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
return timestep_emb, embedded_timestep
def _prepare_context(
self,
context: mx.array,
x: mx.array,
attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
return context, attention_mask
def _prepare_attention_mask(
self,
attention_mask: Optional[mx.array],
x_dtype: mx.Dtype,
) -> Optional[mx.array]:
if attention_mask is None:
return None
# Check if already float
if attention_mask.dtype in [mx.float16, mx.float32, mx.bfloat16]:
return attention_mask
# Convert boolean/int mask to float mask
# 0 -> -inf (masked), 1 -> 0 (not masked)
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
return mask
def _prepare_positional_embeddings(
self,
positions: mx.array,
inner_dim: int,
max_pos: List[int],
use_middle_indices_grid: bool,
num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
pe = precompute_freqs_cis(
positions,
dim=inner_dim,
theta=self.positional_embedding_theta,
max_pos=max_pos,
use_middle_indices_grid=use_middle_indices_grid,
num_attention_heads=num_attention_heads,
rope_type=self.rope_type,
double_precision=self.double_precision_rope,
)
return pe
def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return TransformerArgs(
x=x,
context=context,
context_mask=attention_mask,
timesteps=timestep,
embedded_timestep=embedded_timestep,
positional_embeddings=pe,
cross_positional_embeddings=None,
cross_scale_shift_timestep=None,
cross_gate_timestep=None,
enabled=modality.enabled,
)
class MultiModalTransformerArgsPreprocessor:
def __init__(
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
cross_scale_shift_adaln: AdaLayerNormSingle,
cross_gate_adaln: AdaLayerNormSingle,
inner_dim: int,
max_pos: List[int],
num_attention_heads: int,
cross_pe_max_pos: int,
use_middle_indices_grid: bool,
audio_cross_attention_dim: int,
timestep_scale_multiplier: int,
positional_embedding_theta: float,
rope_type: LTXRopeType,
av_ca_timestep_scale_multiplier: int,
double_precision_rope: bool = False,
):
self.simple_preprocessor = TransformerArgsPreprocessor(
patchify_proj=patchify_proj,
adaln=adaln,
caption_projection=caption_projection,
inner_dim=inner_dim,
max_pos=max_pos,
num_attention_heads=num_attention_heads,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
positional_embedding_theta=positional_embedding_theta,
rope_type=rope_type,
double_precision_rope=double_precision_rope,
)
self.cross_scale_shift_adaln = cross_scale_shift_adaln
self.cross_gate_adaln = cross_gate_adaln
self.cross_pe_max_pos = cross_pe_max_pos
self.audio_cross_attention_dim = audio_cross_attention_dim
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
def prepare(self, modality: Modality) -> TransformerArgs:
from dataclasses import replace
transformer_args = self.simple_preprocessor.prepare(modality)
# Prepare cross-modal positional embeddings
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
positions=modality.positions[:, 0:1, :],
inner_dim=self.audio_cross_attention_dim,
max_pos=[self.cross_pe_max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.simple_preprocessor.num_attention_heads,
)
# Prepare cross-attention timestep embeddings
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
)
return replace(
transformer_args,
cross_positional_embeddings=cross_pe,
cross_scale_shift_timestep=cross_scale_shift_timestep,
cross_gate_timestep=cross_gate_timestep,
)
def _prepare_cross_attention_timestep(
self,
timestep: mx.array,
timestep_scale_multiplier: int,
batch_size: int,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * timestep_scale_multiplier
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
return scale_shift_timestep, gate_timestep
class LTXModel(nn.Module):
def __init__(self, config: LTXModelConfig):
super().__init__()
self.config = config
self.model_type = config.model_type
self.use_middle_indices_grid = config.use_middle_indices_grid
self.rope_type = config.rope_type
self.timestep_scale_multiplier = config.timestep_scale_multiplier
self.positional_embedding_theta = config.positional_embedding_theta
cross_pe_max_pos = None
if config.model_type.is_video_enabled():
self.positional_embedding_max_pos = config.positional_embedding_max_pos
self.num_attention_heads = config.num_attention_heads
self.inner_dim = config.inner_dim
self._init_video(config)
if config.model_type.is_audio_enabled():
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
self.audio_num_attention_heads = config.audio_num_attention_heads
self.audio_inner_dim = config.audio_inner_dim
self._init_audio(config)
# Initialize cross-modal components
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
cross_pe_max_pos = max(
config.positional_embedding_max_pos[0],
config.audio_positional_embedding_max_pos[0],
)
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
self.audio_cross_attention_dim = config.audio_cross_attention_dim
self._init_audio_video(config)
self._init_preprocessors(config, cross_pe_max_pos)
self._init_transformer_blocks(config)
def _init_video(self, config: LTXModelConfig) -> None:
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.inner_dim,
)
self.scale_shift_table = mx.zeros((2, self.inner_dim))
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.audio_inner_dim,
)
# Output components
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
def _init_audio_video(self, config: LTXModelConfig) -> None:
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
embedding_coefficient=num_scale_shift_values,
)
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=num_scale_shift_values,
)
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
self.inner_dim,
embedding_coefficient=1,
)
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=1,
)
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
# Multi-modal preprocessors
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
inner_dim=self.inner_dim,
max_pos=config.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
cross_pe_max_pos=cross_pe_max_pos,
use_middle_indices_grid=config.use_middle_indices_grid,
audio_cross_attention_dim=config.audio_cross_attention_dim,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
)
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
inner_dim=self.audio_inner_dim,
max_pos=config.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
cross_pe_max_pos=cross_pe_max_pos,
use_middle_indices_grid=config.use_middle_indices_grid,
audio_cross_attention_dim=config.audio_cross_attention_dim,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
)
elif config.model_type.is_video_enabled():
self.video_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
inner_dim=self.inner_dim,
max_pos=config.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
use_middle_indices_grid=config.use_middle_indices_grid,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
)
elif config.model_type.is_audio_enabled():
self.audio_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
inner_dim=self.audio_inner_dim,
max_pos=config.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
use_middle_indices_grid=config.use_middle_indices_grid,
timestep_scale_multiplier=config.timestep_scale_multiplier,
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
)
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
video_config = config.get_video_config()
audio_config = config.get_audio_config()
self.transformer_blocks = [
BasicAVTransformerBlock(
idx=idx,
video=video_config,
audio=audio_config,
rope_type=config.rope_type,
norm_eps=config.norm_eps,
)
for idx in range(config.num_layers)
]
def _process_transformer_blocks(
self,
video: Optional[TransformerArgs],
audio: Optional[TransformerArgs],
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks."""
for block in self.transformer_blocks:
video, audio = block(video=video, audio=audio)
return video, audio
def _process_output(
self,
scale_shift_table: mx.array,
norm_out: nn.LayerNorm,
proj_out: nn.Linear,
x: mx.array,
embedded_timestep: mx.array,
) -> mx.array:
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
timestep_expanded = embedded_timestep[:, :, None, :] # (B, 1, 1, dim)
# Combine: (1, 1, 2, dim) + (B, 1, 1, dim) broadcasts to (B, 1, 2, dim)
scale_shift_values = table_expanded + timestep_expanded
# Extract shift and scale (first index is shift, second is scale)
shift = scale_shift_values[:, :, 0, :] # (B, 1, dim)
scale = scale_shift_values[:, :, 1, :] # (B, 1, dim)
x = norm_out(x)
x = x * (1 + scale) + shift # Broadcasts (B, 1, dim) to (B, seq, dim)
x = proj_out(x)
return x
def __call__(
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
# Validate inputs
if not self.model_type.is_video_enabled() and video is not None:
raise ValueError("Video is not enabled for this model")
if not self.model_type.is_audio_enabled() and audio is not None:
raise ValueError("Audio is not enabled for this model")
# Preprocess arguments
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
# Process transformer blocks
video_out, audio_out = self._process_transformer_blocks(
video=video_args,
audio=audio_args,
)
# Process outputs
vx = (
self._process_output(
self.scale_shift_table,
self.norm_out,
self.proj_out,
video_out.x,
video_out.embedded_timestep,
)
if video_out is not None
else None
)
ax = (
self._process_output(
self.audio_scale_shift_table,
self.audio_norm_out,
self.audio_proj_out,
audio_out.x,
audio_out.embedded_timestep,
)
if audio_out is not None
else None
)
return vx, ax
def sanitize(self, weights: dict) -> dict:
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle common remappings
# transformer_blocks.X -> transformer_blocks[X]
if "transformer_blocks." in new_key:
# Keep as-is for now, MLX handles this
pass
sanitized[new_key] = value
return sanitized
class X0Model(nn.Module):
def __init__(self, velocity_model: LTXModel):
super().__init__()
self.velocity_model = velocity_model
def __call__(
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model(video, audio)
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
return denoised_video, denoised_audio

View File

@@ -0,0 +1,508 @@
import math
from typing import Callable, List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.config import LTXRopeType
def apply_rotary_emb(
input_tensor: mx.array,
freqs_cis: Tuple[mx.array, mx.array],
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> mx.array:
"""Apply rotary position embeddings to input tensor.
Args:
input_tensor: Input tensor to apply RoPE to
freqs_cis: Tuple of (cos_freqs, sin_freqs)
rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT)
Returns:
Tensor with rotary embeddings applied
"""
if rope_type == LTXRopeType.INTERLEAVED:
return apply_interleaved_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
elif rope_type == LTXRopeType.SPLIT:
return apply_split_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
else:
raise ValueError(f"Invalid rope type: {rope_type}")
def apply_interleaved_rotary_emb(
input_tensor: mx.array,
cos_freqs: mx.array,
sin_freqs: mx.array,
) -> mx.array:
"""Apply interleaved rotary embeddings.
Pairs adjacent dimensions and applies rotation.
Pattern: [x0, x1, x2, x3, ...] -> rotate pairs (x0,x1), (x2,x3), ...
Args:
input_tensor: Input tensor of shape (..., dim)
cos_freqs: Cosine frequencies
sin_freqs: Sine frequencies
Returns:
Tensor with interleaved rotary embeddings applied
"""
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
shape = input_tensor.shape
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
# Extract pairs
t1 = input_tensor[..., 0] # Even indices
t2 = input_tensor[..., 1] # Odd indices
# Apply rotation: (-t2, t1) pattern
t_rot = mx.stack([-t2, t1], axis=-1)
# Flatten back: (..., dim/2, 2) -> (..., dim)
input_tensor = mx.reshape(input_tensor, shape)
t_rot = mx.reshape(t_rot, shape)
# Apply rotary embeddings
out = input_tensor * cos_freqs + t_rot * sin_freqs
return out
def rotate_half_interleaved(x: mx.array) -> mx.array:
"""Rotate for interleaved RoPE: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
PyTorch equivalent:
t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
return rearrange(t_dup, "... d r -> ... (d r)")
"""
# x: (..., dim) where dim is even
x_even = x[..., 0::2] # [x0, x2, x4, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
rotated = mx.stack([-x_odd, x_even], axis=-1)
return mx.reshape(rotated, x.shape)
def apply_rotary_emb_1d(
q: mx.array,
k: mx.array,
freqs_cis: mx.array,
) -> Tuple[mx.array, mx.array]:
"""Apply 1D rotary embeddings using precomputed frequencies (interleaved)."""
# freqs_cis: (1, seq_len, num_heads, head_dim, 2) where [..., 0] = cos, [..., 1] = sin
cos = freqs_cis[..., 0] # (1, seq_len, num_heads, head_dim)
sin = freqs_cis[..., 1]
# q, k: (batch, seq_len, num_heads, head_dim)
# Interleaved RoPE: pairs of adjacent dims rotate together
q_r = q * cos + rotate_half_interleaved(q) * sin
k_r = k * cos + rotate_half_interleaved(k) * sin
return q_r, k_r
def apply_split_rotary_emb(
input_tensor: mx.array,
cos_freqs: mx.array,
sin_freqs: mx.array,
) -> mx.array:
"""Apply split rotary embeddings.
Splits dimensions into two halves and applies rotation.
Pattern: split into first half and second half
Args:
input_tensor: Input tensor
cos_freqs: Cosine frequencies of shape (B, H, T, D//2)
sin_freqs: Sine frequencies of shape (B, H, T, D//2)
Returns:
Tensor with split rotary embeddings applied
"""
needs_reshape = False
original_shape = input_tensor.shape
# Handle dimension mismatch
if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
b, h, t, _ = cos_freqs.shape
# Reshape from (B, T, H*D) to (B, H, T, D)
input_tensor = mx.reshape(input_tensor, (b, t, h, -1))
input_tensor = mx.swapaxes(input_tensor, 1, 2)
needs_reshape = True
# Split into two halves: (..., dim) -> (..., 2, dim//2)
dim = input_tensor.shape[-1]
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
# Get first and second halves
first_half = split_input[..., 0, :] # (..., dim//2)
second_half = split_input[..., 1, :] # (..., dim//2)
# Apply cosine to both halves
output_first = first_half * cos_freqs
output_second = second_half * cos_freqs
# Apply sine cross-terms (addcmul pattern)
output_first = output_first - sin_freqs * second_half
output_second = output_second + sin_freqs * first_half
# Stack back together
output = mx.stack([output_first, output_second], axis=-2)
# Flatten: (..., 2, dim//2) -> (..., dim)
output = mx.reshape(output, input_tensor.shape)
if needs_reshape:
# Reshape back: (B, H, T, D) -> (B, T, H*D)
b, h, t, d = output.shape
output = mx.swapaxes(output, 1, 2)
output = mx.reshape(output, (b, t, h * d))
return output
def generate_freq_grid(
positional_embedding_theta: float,
positional_embedding_max_pos_count: int,
inner_dim: int,
) -> mx.array:
"""Generate frequency grid for RoPE.
Args:
positional_embedding_theta: Base theta value
positional_embedding_max_pos_count: Number of position dimensions
inner_dim: Inner dimension of the model
Returns:
Frequency indices tensor
"""
theta = positional_embedding_theta
start = 1.0
end = theta
n_elem = 2 * positional_embedding_max_pos_count
# Compute logarithmic spacing
log_start = math.log(start) / math.log(theta)
log_end = math.log(end) / math.log(theta)
num_indices = inner_dim // n_elem
if num_indices == 0:
num_indices = 1
# Create linearly spaced values in log space
lin_space = mx.linspace(log_start, log_end, num_indices)
# Compute power indices
pow_indices = mx.power(theta, lin_space)
# Scale by pi/2
return pow_indices * (math.pi / 2)
def get_fractional_positions(
indices_grid: mx.array,
max_pos: List[int],
) -> mx.array:
"""Convert indices to fractional positions.
Args:
indices_grid: Grid of position indices of shape (B, n_pos_dims, ...)
max_pos: Maximum position for each dimension
Returns:
Fractional positions in range [-1, 1] after scaling
"""
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), (
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
)
# Divide each dimension by its max position
fractional_positions = []
for i in range(n_pos_dims):
frac = indices_grid[:, i] / max_pos[i]
fractional_positions.append(frac)
return mx.stack(fractional_positions, axis=-1)
def generate_freqs(
indices: mx.array,
indices_grid: mx.array,
max_pos: List[int],
use_middle_indices_grid: bool,
) -> mx.array:
"""Generate frequencies from indices and position grid.
Args:
indices: Frequency indices
indices_grid: Position indices grid
max_pos: Maximum positions per dimension
use_middle_indices_grid: Whether to use middle of index ranges
Returns:
Frequency tensor
"""
# Handle middle indices grid
if use_middle_indices_grid:
# indices_grid shape: (B, n_dims, T, 2) where last dim is [start, end]
assert len(indices_grid.shape) == 4
assert indices_grid.shape[-1] == 2
indices_grid_start = indices_grid[..., 0]
indices_grid_end = indices_grid[..., 1]
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid.shape) == 4:
indices_grid = indices_grid[..., 0]
# Get fractional positions
fractional_positions = get_fractional_positions(indices_grid, max_pos)
# Compute frequencies
# fractional_positions: (B, T, n_dims)
# indices: (inner_dim // n_elem,)
# Result: (B, T, inner_dim // n_elem * n_dims)
# Scale fractional positions to [-1, 1]
scaled_positions = fractional_positions * 2 - 1 # (B, T, n_dims)
# Outer product with indices
# (B, T, n_dims, 1) * (1, 1, 1, n_indices) -> (B, T, n_dims, n_indices)
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.expand_dims(
mx.expand_dims(mx.expand_dims(indices, axis=0), axis=0), axis=0
)
# Transpose and flatten: (B, T, n_dims, n_indices) -> (B, T, n_indices * n_dims)
freqs = mx.swapaxes(freqs, -1, -2) # (B, T, n_indices, n_dims)
freqs = mx.reshape(freqs, freqs.shape[:-2] + (-1,))
return freqs
def split_freqs_cis(
freqs: mx.array,
pad_size: int,
num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
"""Prepare cos/sin frequencies for split RoPE.
Args:
freqs: Frequency tensor
pad_size: Padding size for dimension alignment
num_attention_heads: Number of attention heads
Returns:
Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2)
"""
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
# Add padding if needed
if pad_size != 0:
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape for multi-head attention
b, t = cos_freq.shape[0], cos_freq.shape[1]
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
# Swap axes: (B, T, H, D//2) -> (B, H, T, D//2)
cos_freq = mx.swapaxes(cos_freq, 1, 2)
sin_freq = mx.swapaxes(sin_freq, 1, 2)
return cos_freq, sin_freq
def interleaved_freqs_cis(
freqs: mx.array,
pad_size: int,
) -> Tuple[mx.array, mx.array]:
"""Prepare cos/sin frequencies for interleaved RoPE.
Args:
freqs: Frequency tensor of shape (B, T, dim//2)
pad_size: Padding size for dimension alignment
Returns:
Tuple of (cos_freq, sin_freq) with shape (B, T, dim)
"""
# Compute cos and sin
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
# Repeat interleave: each element repeated twice
# (B, T, D) -> (B, T, 2*D) with pattern [c0, c0, c1, c1, ...]
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
# Add padding if needed
if pad_size != 0:
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
return cos_freq, sin_freq
def precompute_freqs_cis(
indices_grid: mx.array,
dim: int,
theta: float = 10000.0,
max_pos: Optional[List[int]] = None,
use_middle_indices_grid: bool = False,
num_attention_heads: int = 32,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
double_precision: bool = False,
) -> Tuple[mx.array, mx.array]:
"""Precompute RoPE frequencies.
Args:
indices_grid: Position indices grid
dim: Dimension for RoPE
theta: Base theta value for frequency computation
max_pos: Maximum position per dimension
use_middle_indices_grid: Whether to use middle indices
num_attention_heads: Number of attention heads
rope_type: Type of RoPE (INTERLEAVED or SPLIT)
double_precision: If True, compute frequencies in float64 for higher precision
Returns:
Tuple of (cos_freq, sin_freq) tensors
"""
if max_pos is None:
max_pos = [20, 2048, 2048]
# For double precision, compute in numpy (float64) then convert back to MLX
# MLX GPU doesn't support float64, so we use numpy for high precision computation
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
num_attention_heads, rope_type
)
# Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
# Generate frequencies
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
# Prepare cos/sin based on rope type
if rope_type == LTXRopeType.SPLIT:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
else:
# Interleaved
n_elem = 2 * indices_grid.shape[1]
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq, sin_freq
def _precompute_freqs_cis_double_precision(
indices_grid: mx.array,
dim: int,
theta: float,
max_pos: List[int],
use_middle_indices_grid: bool,
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies in double precision using numpy.
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
"""
# Convert to numpy float64
indices_grid_np = np.array(indices_grid).astype(np.float64)
# Generate frequency indices in float64
n_pos_dims = indices_grid_np.shape[1]
n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies
log_start = math.log(1.0) / math.log(theta)
log_end = math.log(theta) / math.log(theta)
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = np.linspace(log_start, log_end, num_indices)
indices_np = np.power(theta, lin_space) * (math.pi / 2)
# Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
if use_middle_indices_grid:
assert len(indices_grid_np.shape) == 4
assert indices_grid_np.shape[-1] == 2
indices_grid_start = indices_grid_np[..., 0]
indices_grid_end = indices_grid_np[..., 1]
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid_np.shape) == 4:
indices_grid_np = indices_grid_np[..., 0]
# After handling: indices_grid_np shape is (B, n_dims, T)
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
batch_size = indices_grid_np.shape[0]
seq_len = indices_grid_np.shape[2]
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
for i in range(n_pos_dims):
# indices_grid_np[:, i, :] has shape (B, T)
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
# Scale to [-1, 1]
scaled_positions = fractional_positions * 2 - 1
# Compute frequencies: outer product
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
# Compute cos/sin in float64
cos_freq = np.cos(freqs)
sin_freq = np.sin(freqs)
# Prepare based on rope type
if rope_type == LTXRopeType.SPLIT:
expected_freqs = dim // 2
current_freqs = cos_freq.shape[-1]
pad_size = expected_freqs - current_freqs
# Add padding
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
b, t = cos_freq.shape[0], cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
cos_freq = np.swapaxes(cos_freq, 1, 2)
sin_freq = np.swapaxes(sin_freq, 1, 2)
else:
# Interleaved
cos_freq = np.repeat(cos_freq, 2, axis=-1)
sin_freq = np.repeat(sin_freq, 2, axis=-1)
pad_size = dim % n_elem
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
# Convert back to MLX (float32 for GPU compatibility)
cos_freq = mx.array(cos_freq.astype(np.float32))
sin_freq = mx.array(sin_freq.astype(np.float32))
return cos_freq, sin_freq

View File

@@ -0,0 +1,727 @@
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
@dataclass
class Gemma3Config:
"""Configuration for Gemma 3 text model."""
hidden_size: int = 3840
num_attention_heads: int = 16
num_key_value_heads: int = 8
head_dim: int = 256
intermediate_size: int = 15360
num_hidden_layers: int = 48
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
vocab_size: int = 262208
max_position_embeddings: int = 131072
class RMSNorm(nn.Module):
"""RMS Normalization (Gemma style with 1+weight scaling)."""
def __init__(self, dims: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# Gemma initializes to ones, but uses (1+weight) scaling
# After loading weights, weight will have the actual learned values
self.weight = mx.ones((dims,))
def __call__(self, x: mx.array) -> mx.array:
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
def apply_rotary_emb(
q: mx.array,
k: mx.array,
positions: mx.array,
head_dim: int,
rope_theta: float = 1000000.0,
) -> Tuple[mx.array, mx.array]:
"""Apply rotary position embeddings to Q and K."""
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
cos = mx.cos(freqs)
sin = mx.sin(freqs)
cos = cos[:, :, None, :]
sin = sin[:, :, None, :]
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate([-x2, x1], axis=-1)
cos_full = mx.concatenate([cos, cos], axis=-1)
sin_full = mx.concatenate([sin, sin], axis=-1)
q_embed = q * cos_full + rotate_half(q) * sin_full
k_embed = k * cos_full + rotate_half(k) * sin_full
return q_embed, k_embed
class Gemma3MLP(nn.Module):
"""Gemma 3 MLP with gated activation."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
gate = nn.gelu_approx(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class Gemma3Attention(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.scale = 1.0 / math.sqrt(config.head_dim)
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
positions: mx.array,
attention_mask: Optional[mx.array] = None,
) -> mx.array:
batch_size, seq_len, _ = hidden_states.shape
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# Create causal mask (lower triangular)
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
if attention_mask is not None:
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
out = mx.transpose(out, (0, 2, 1, 3))
out = mx.reshape(out, (batch_size, seq_len, -1))
return self.o_proj(out)
class Gemma3DecoderLayer(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.self_attn = Gemma3Attention(config)
self.mlp = Gemma3MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
positions: mx.array,
attention_mask: Optional[mx.array] = None,
) -> mx.array:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Gemma3TextModel(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gemma scales embeddings by sqrt(hidden_size)
self.embed_scale = config.hidden_size ** 0.5
def __call__(
self,
input_ids: mx.array,
attention_mask: Optional[mx.array] = None,
output_hidden_states: bool = True,
) -> Tuple[mx.array, List[mx.array]]:
batch_size, seq_len = input_ids.shape
# Gemma scales embeddings by sqrt(hidden_size)
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
all_hidden_states = [hidden_states] if output_hidden_states else []
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
positions = mx.broadcast_to(positions, (batch_size, seq_len))
for layer in self.layers:
hidden_states = layer(hidden_states, positions, attention_mask)
if output_hidden_states:
all_hidden_states.append(hidden_states)
hidden_states = self.norm(hidden_states)
return hidden_states, all_hidden_states
class ConnectorAttention(nn.Module):
def __init__(
self,
dim: int = 3840,
num_heads: int = 30,
head_dim: int = 128,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
inner_dim = num_heads * head_dim
self.scale = 1.0 / math.sqrt(head_dim)
self.to_q = nn.Linear(dim, inner_dim, bias=True)
self.to_k = nn.Linear(dim, inner_dim, bias=True)
self.to_v = nn.Linear(dim, inner_dim, bias=True)
self.to_out = [nn.Linear(inner_dim, dim, bias=True)]
# Standard RMSNorm (not Gemma-style) on full inner_dim
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6)
def __call__(
self,
x: mx.array,
attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None,
) -> mx.array:
batch_size, seq_len, _ = x.shape
# Project to Q, K, V
q = self.to_q(x) # (B, seq, inner_dim)
k = self.to_k(x)
v = self.to_v(x)
# QK normalization on full inner_dim BEFORE reshape (matches PyTorch)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
# pe: (1, seq_len, num_heads, head_dim, 2)
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
q, k = apply_rotary_emb_1d(q, k, pe)
# Reshape back for attention computation
q = mx.reshape(q, (batch_size, seq_len, -1))
k = mx.reshape(k, (batch_size, seq_len, -1))
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
if attention_mask is not None:
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
return self.to_out[0](out)
class GEGLU(nn.Module):
"""GELU-gated linear unit."""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.proj = nn.Linear(in_dim, out_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return nn.gelu_approx(self.proj(x))
class ConnectorFeedForward(nn.Module):
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
super().__init__()
inner_dim = dim * mult
self.net = [
GEGLU(dim, inner_dim),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias=True),
]
def __call__(self, x: mx.array) -> mx.array:
for layer in self.net:
x = layer(x)
return x
class ConnectorTransformerBlock(nn.Module):
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128):
super().__init__()
self.attn1 = ConnectorAttention(dim, num_heads, head_dim)
self.ff = ConnectorFeedForward(dim)
def __call__(
self,
x: mx.array,
attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None,
) -> mx.array:
# Pre-norm + attention + residual
norm_x = rms_norm(x)
if norm_x.ndim == 4:
norm_x = mx.squeeze(norm_x, axis=1)
attn_out = self.attn1(norm_x, attention_mask, pe)
x = x + attn_out
if x.ndim == 4:
x = mx.squeeze(x, axis=1)
# Pre-norm + FFN + residual
norm_x = rms_norm(x)
ff_out = self.ff(norm_x)
x = x + ff_out
if x.ndim == 4:
x = mx.squeeze(x, axis=1)
return x
class Embeddings1DConnector(nn.Module):
def __init__(
self,
dim: int = 3840,
num_heads: int = 30,
head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: int = 128,
positional_embedding_theta: float = 10000.0,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = head_dim
self.num_learnable_registers = num_learnable_registers
self.positional_embedding_theta = positional_embedding_theta
self.transformer_1d_blocks = [
ConnectorTransformerBlock(dim, num_heads, head_dim)
for _ in range(num_layers)
]
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
import math
dim = self.num_heads * self.head_dim
theta = self.positional_embedding_theta
n_elem = 2
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
indices = (theta ** linspace_vals) * (math.pi / 2)
positions = mx.arange(seq_len).astype(mx.float32)
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
cos = mx.cos(freqs) # (seq_len, dim//2)
sin = mx.sin(freqs)
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
return freqs_cis.astype(dtype)
def _replace_padded_with_registers(
self,
hidden_states: mx.array,
attention_mask: mx.array,
) -> Tuple[mx.array, mx.array]:
batch_size, seq_len, dim = hidden_states.shape
# Binary mask: 1 for valid tokens, 0 for padded
# attention_mask is additive: 0 for valid, large negative for padded
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
# Tile registers to match sequence length
num_tiles = seq_len // self.num_learnable_registers
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim)
# Process each batch item (PyTorch uses advanced indexing)
result_list = []
for b in range(batch_size):
mask_b = mask_binary[b] # (seq,)
hs_b = hidden_states[b] # (seq, dim)
# Count valid tokens
num_valid = int(mx.sum(mask_b))
# Extract valid tokens (where mask is 1)
# Since we have left-padded input, valid tokens are at the end
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
# Pad with zeros on the right to get back to seq_len
pad_length = seq_len - num_valid
if pad_length > 0:
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype)
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
else:
adjusted = valid_tokens
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
flipped_mask = mx.concatenate([
mx.ones((num_valid,), dtype=mx.int32),
mx.zeros((pad_length,), dtype=mx.int32)
], axis=0) # (seq,)
# Combine: valid tokens at front, registers at back
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1)
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
result_list.append(combined)
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
# Reset attention mask to all zeros (no masking after register replacement)
attention_mask = mx.zeros_like(attention_mask)
return hidden_states, attention_mask
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
# Replace padded tokens with learnable registers
if self.num_learnable_registers > 0 and attention_mask is not None:
hidden_states, attention_mask = self._replace_padded_with_registers(
hidden_states, attention_mask
)
# Compute RoPE frequencies
seq_len = hidden_states.shape[1]
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
# Process through transformer blocks
for block in self.transformer_1d_blocks:
hidden_states = block(hidden_states, attention_mask, freqs_cis)
# Final RMS norm
hidden_states = rms_norm(hidden_states)
return hidden_states, attention_mask
def norm_and_concat_hidden_states(
hidden_states: List[mx.array],
attention_mask: mx.array,
padding_side: str = "left",
) -> mx.array:
# Stack hidden states: (batch, seq, dim, num_layers)
stacked = mx.stack(hidden_states, axis=-1)
b, t, d, num_layers = stacked.shape
# Compute sequence lengths from attention mask
sequence_lengths = mx.sum(attention_mask, axis=-1) # (batch,)
# Build mask based on padding side
token_indices = mx.arange(t)[None, :] # (1, T)
if padding_side == "right":
mask = token_indices < sequence_lengths[:, None] # (B, T)
else: # left padding
start_indices = t - sequence_lengths[:, None] # (B, 1)
mask = token_indices >= start_indices # (B, T)
mask = mask[:, :, None, None] # (B, T, 1, 1)
eps = 1e-6
# Compute masked mean per layer
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
large_val = 1e9
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min
# Normalize: 8 * (x - mean) / range
normed = 8 * (stacked - mean) / (range_val + eps)
# Flatten layers into feature dimension: (B, T, D*L)
normed = mx.reshape(normed, (b, t, -1))
# Zero out padded positions
mask_flat = mx.broadcast_to(mask[:, :, :, 0], (b, t, d * num_layers))
normed = mx.where(mask_flat, normed, mx.zeros_like(normed))
return normed
class GemmaFeaturesExtractor(nn.Module):
def __init__(self, input_dim: int = 188160, output_dim: int = 3840):
super().__init__()
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.aggregate_embed(x)
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith("base_text_encoder.language_model."):
new_key = key.replace("base_text_encoder.language_model.", "")
elif key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "")
elif key.startswith("language_model."):
new_key = key.replace("language_model.", "")
else:
continue
if new_key is None:
continue
sanitized[new_key] = value
return sanitized
class LTX2TextEncoder(nn.Module):
def __init__(
self,
model_path: str = "Lightricks/LTX-2",
hidden_dim: int = 3840,
num_layers: int = 49, # 48 transformer layers + 1 embedding
):
super().__init__()
self._model_path = model_path
self.hidden_dim = hidden_dim
self.num_layers = num_layers
# Gemma 3 model
self.config = Gemma3Config()
self.model = Gemma3TextModel(self.config)
# Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor(
input_dim=hidden_dim * num_layers,
output_dim=hidden_dim,
)
# Video embeddings connector: 2-layer transformer
self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim,
num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
)
self.processor = None
def load(self, model_path: Optional[str] = None):
path = model_path or self._model_path
# Load Gemma weights from text_encoder subdirectory
if Path(path).is_dir():
text_encoder_path = Path(path) / "text_encoder"
if text_encoder_path.exists():
gemma_path = str(text_encoder_path)
else:
gemma_path = path
else:
gemma_path = path
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
all_weights = {}
for i, wf in enumerate(weight_files):
print(f" Loading weight file {i+1}/{len(weight_files)}...")
weights = mx.load(str(wf))
all_weights.update(weights)
# Sanitize and load Gemma weights
sanitized = sanitize_gemma3_weights(all_weights)
print(f" Sanitized Gemma weights: {len(sanitized)}")
self.model.load_weights(list(sanitized.items()), strict=False)
# Load transformer weights for feature extractor and connector
transformer_path = Path(model_path or self._model_path)
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
if transformer_files:
print(f"Loading transformer weights for text pipeline...")
transformer_weights = mx.load(str(transformer_files[0]))
# Load feature extractor (aggregate_embed)
if "text_embedding_projection.aggregate_embed.weight" in transformer_weights:
self.feature_extractor.aggregate_embed.weight = transformer_weights[
"text_embedding_projection.aggregate_embed.weight"
]
print(" Loaded aggregate_embed weights")
# Load video_embeddings_connector weights
connector_weights = {}
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.video_embeddings_connector."):
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
connector_weights[new_key] = value
if connector_weights:
# Map weight names to our structure
mapped_weights = {}
for key, value in connector_weights.items():
# transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.*
# transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.*
# transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.*
mapped_weights[key] = value
self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False
)
print(f" Loaded {len(connector_weights)} connector weights")
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
# Load tokenizer
from transformers import AutoTokenizer
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
else:
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
# Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left"
print("Text encoder loaded successfully")
def encode(
self,
prompt: str,
max_length: int = 1024,
) -> Tuple[mx.array, mx.array]:
if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.")
# Tokenize with left padding (as in PyTorch version)
inputs = self.processor(
prompt,
return_tensors="np",
max_length=max_length,
truncation=True,
padding="max_length",
)
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
# Get all hidden states from Gemma
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
# Normalize and concatenate all hidden states
concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left"
)
# Project through feature extractor
features = self.feature_extractor(concat_hidden)
# Convert attention mask to additive format for connector
additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
# Process through connector
# Note: connector replaces padding with learnable registers and resets mask to zeros
# This means all positions now have valid embeddings (no need for final masking)
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
# Return embeddings without zeroing - the connector's register replacement
# means all positions have meaningful values now
return embeddings, attention_mask
def __call__(
self,
prompt: str,
max_length: int = 1024,
) -> Tuple[mx.array, mx.array]:
return self.encode(prompt, max_length)
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder(model_path=model_path)
encoder.load()
return encoder

View File

@@ -0,0 +1,26 @@
import mlx.core as mx
import mlx.nn as nn
class PixArtAlphaTextProjection(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear1(x)
x = self.act(x)
x = self.linear2(x)
return x

View File

@@ -0,0 +1,359 @@
from dataclasses import dataclass, replace
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx.attention import Attention
from mlx_video.models.ltx.feed_forward import FeedForward
from mlx_video.utils import rms_norm
@dataclass(frozen=True)
class Modality:
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
enabled: bool = True
context_mask: Optional[mx.array] = None
@dataclass(frozen=True)
class TransformerArgs:
x: mx.array
context: mx.array
context_mask: Optional[mx.array]
timesteps: mx.array
embedded_timestep: mx.array
positional_embeddings: Tuple[mx.array, mx.array]
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
cross_scale_shift_timestep: Optional[mx.array]
cross_gate_timestep: Optional[mx.array]
enabled: bool
class BasicAVTransformerBlock(nn.Module):
"""Audio-Video transformer block with cross-modal attention.
Supports video-only, audio-only, or combined audio-video processing
with bidirectional cross-attention between modalities.
"""
def __init__(
self,
idx: int,
video: Optional[TransformerConfig] = None,
audio: Optional[TransformerConfig] = None,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
norm_eps: float = 1e-6,
):
"""Initialize transformer block.
Args:
idx: Block index
video: Video modality configuration
audio: Audio modality configuration
rope_type: Type of rotary position embedding
norm_eps: Epsilon for normalization
"""
super().__init__()
self.idx = idx
self.norm_eps = norm_eps
# Video components
if video is not None:
self.attn1 = Attention(
query_dim=video.dim,
heads=video.heads,
dim_head=video.d_head,
context_dim=None, # Self-attention
rope_type=rope_type,
norm_eps=norm_eps,
)
self.attn2 = Attention(
query_dim=video.dim,
context_dim=video.context_dim,
heads=video.heads,
dim_head=video.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
self.ff = FeedForward(video.dim, dim_out=video.dim)
# 6 scale-shift parameters: 3 for attention, 3 for MLP
self.scale_shift_table = mx.zeros((6, video.dim))
# Audio components
if audio is not None:
self.audio_attn1 = Attention(
query_dim=audio.dim,
heads=audio.heads,
dim_head=audio.d_head,
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
)
self.audio_attn2 = Attention(
query_dim=audio.dim,
context_dim=audio.context_dim,
heads=audio.heads,
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
self.audio_scale_shift_table = mx.zeros((6, audio.dim))
# Cross-modal attention (when both video and audio are enabled)
if audio is not None and video is not None:
# Audio-to-Video: Q from video, K/V from audio
self.audio_to_video_attn = Attention(
query_dim=video.dim,
context_dim=audio.dim,
heads=audio.heads,
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
# Video-to-Audio: Q from audio, K/V from video
self.video_to_audio_attn = Attention(
query_dim=audio.dim,
context_dim=video.dim,
heads=audio.heads,
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
)
# Scale-shift tables for cross-attention
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))
def get_ada_values(
self,
scale_shift_table: mx.array,
batch_size: int,
timestep: mx.array,
indices: slice,
) -> Tuple[mx.array, ...]:
"""Get adaptive normalization values from scale-shift table.
Args:
scale_shift_table: Table of shape (num_params, dim)
batch_size: Batch size
timestep: Timestep embeddings of shape (B, 1, num_params * dim) or similar
indices: Slice for which parameters to extract
Returns:
Tuple of scale-shift values
"""
num_ada_params = scale_shift_table.shape[0]
# scale_shift_table[indices]: (num_selected, dim)
# Add batch and sequence dimensions: (1, 1, num_selected, dim)
table_slice = scale_shift_table[indices]
table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
timestep_reshaped = mx.reshape(
timestep,
(batch_size, timestep.shape[1], num_ada_params, -1)
)
# Extract the relevant indices
timestep_slice = timestep_reshaped[:, :, indices, :]
# Add table values to timestep
ada_values = table_expanded + timestep_slice
# Unbind along the parameter dimension
# Result: tuple of tensors, each of shape (B, seq, dim)
num_sliced = ada_values.shape[2]
result = tuple(ada_values[:, :, i, :] for i in range(num_sliced))
return result
def get_av_ca_ada_values(
self,
scale_shift_table: mx.array,
batch_size: int,
scale_shift_timestep: mx.array,
gate_timestep: mx.array,
num_scale_shift_values: int = 4,
) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
"""Get adaptive values for cross-modal attention.
Args:
scale_shift_table: Table with 5 parameters (4 scale-shift + 1 gate)
batch_size: Batch size
scale_shift_timestep: Timestep for scale-shift
gate_timestep: Timestep for gating
num_scale_shift_values: Number of scale-shift values (default 4)
Returns:
Tuple of 5 tensors: (scale1, shift1, scale2, shift2, gate)
"""
# Get scale-shift values
scale_shift_ada = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :],
batch_size,
scale_shift_timestep,
slice(None, None),
)
# Get gate values
gate_ada = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :],
batch_size,
gate_timestep,
slice(None, None),
)
# Squeeze the sequence dimension if it's 1
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
return (*scale_shift_squeezed, *gate_squeezed)
def __call__(
self,
video: Optional[TransformerArgs] = None,
audio: Optional[TransformerArgs] = None,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Forward pass through transformer block.
Args:
video: Video modality arguments
audio: Audio modality arguments
Returns:
Tuple of (updated_video, updated_audio) TransformerArgs
"""
batch_size = video.x.shape[0] if video is not None else audio.x.shape[0]
vx = video.x if video is not None else None
ax = audio.x if audio is not None else None
# Check which modalities to run
run_vx = video is not None and video.enabled and vx.size > 0
run_ax = audio is not None and audio.enabled and ax.size > 0
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
# Process video self-attention and cross-attention with text
if run_vx:
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
)
# Self-attention with RoPE
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
# Cross-attention with text context
vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps),
context=video.context,
mask=video.context_mask,
)
# Process audio self-attention and cross-attention with text
if run_ax:
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
)
# Self-attention with RoPE
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
# Cross-attention with text context
ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps),
context=audio.context,
mask=audio.context_mask,
)
# Audio-Video cross-modal attention
if run_a2v or run_v2a:
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
# Get adaptive values for audio cross-attention
(
scale_ca_audio_a2v,
shift_ca_audio_a2v,
scale_ca_audio_v2a,
shift_ca_audio_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
audio.cross_scale_shift_timestep,
audio.cross_gate_timestep,
)
# Get adaptive values for video cross-attention
(
scale_ca_video_a2v,
shift_ca_video_a2v,
scale_ca_video_v2a,
shift_ca_video_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
video.cross_scale_shift_timestep,
video.cross_gate_timestep,
)
# Audio-to-Video cross-attention
if run_a2v:
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
vx = vx + (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=video.cross_positional_embeddings,
k_pe=audio.cross_positional_embeddings,
)
* gate_out_a2v
)
# Video-to-Audio cross-attention
if run_v2a:
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
ax = ax + (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=audio.cross_positional_embeddings,
k_pe=video.cross_positional_embeddings,
)
* gate_out_v2a
)
# Process video feed-forward
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
)
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
vx = vx + self.ff(vx_scaled) * vgate_mlp
# Process audio feed-forward
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
)
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
# Return updated TransformerArgs
video_out = replace(video, x=vx) if video is not None else None
audio_out = replace(audio, x=ax) if audio is not None else None
return video_out, audio_out

View File

@@ -0,0 +1,364 @@
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
class Conv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
bias: bool = True,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
# Weight shape: (C_out, KD, KH, KW, C_in)
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
)
if bias:
self.bias = mx.zeros((out_channels,))
else:
self.bias = None
def __call__(self, x: mx.array) -> mx.array:
"""Forward pass.
Args:
x: Input tensor of shape (N, D, H, W, C_in)
Returns:
Output tensor of shape (N, D', H', W', C_out)
"""
y = mx.conv3d(
x,
self.weight,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
if self.bias is not None:
y = y + self.bias
return y
class GroupNorm3d(nn.Module):
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
super().__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.weight = mx.ones((num_channels,))
self.bias = mx.zeros((num_channels,))
def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C)
n, d, h, w, c = x.shape
# Reshape to (N, D*H*W, num_groups, C//num_groups)
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
# Compute mean and var over spatial and channel group dims
mean = mx.mean(x, axis=(1, 3), keepdims=True)
var = mx.var(x, axis=(1, 3), keepdims=True)
# Normalize
x = (x - mean) / mx.sqrt(var + self.eps)
# Reshape back
x = mx.reshape(x, (n, d, h, w, c))
# Apply weight and bias
x = x * self.weight + self.bias
return x
class PixelShuffle2D(nn.Module):
"""Pixel shuffle for 2D spatial upsampling."""
def __init__(self, upscale_factor: int = 2):
super().__init__()
self.upscale_factor = upscale_factor
def __call__(self, x: mx.array) -> mx.array:
# x: (N, H, W, C) where C = out_channels * upscale_factor^2
n, h, w, c = x.shape
r = self.upscale_factor
out_c = c // (r * r)
# Reshape: (N, H, W, out_c, r, r)
x = mx.reshape(x, (n, h, w, out_c, r, r))
# Permute: (N, H, r, W, r, out_c)
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
# Reshape: (N, H*r, W*r, out_c)
x = mx.reshape(x, (n, h * r, w * r, out_c))
return x
class SpatialRationalResampler(nn.Module):
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
super().__init__()
self.scale = scale
# 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
# Blur kernel for antialiasing
self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
self.pixel_shuffle = PixelShuffle2D(2)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C) - channels last 3D format
n, d, h, w, c = x.shape
# Process frame by frame
# Reshape to (N*D, H, W, C) for 2D operations
x = mx.reshape(x, (n * d, h, w, c))
# Apply 2D conv
x = self.conv(x)
# Pixel shuffle for 2x upscaling
x = self.pixel_shuffle(x)
# Reshape back to (N, D, H*2, W*2, C)
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
return x
class ResBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.conv1 = Conv3d(channels, channels, kernel_size=3, padding=1)
self.norm1 = GroupNorm3d(32, channels)
self.conv2 = Conv3d(channels, channels, kernel_size=3, padding=1)
self.norm2 = GroupNorm3d(32, channels)
def __call__(self, x: mx.array) -> mx.array:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = nn.silu(x)
x = self.conv2(x)
x = self.norm2(x)
# Activation AFTER residual addition
x = nn.silu(x + residual)
return x
class LatentUpsampler(nn.Module):
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
# Initial projection
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = GroupNorm3d(32, mid_channels)
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
# Upsampler: 2D spatial upsampling (frame-by-frame)
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
# Final projection
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
"""Upsample latents by 2x spatially.
Args:
latent: Input tensor of shape (B, C, F, H, W) - channels first
debug: If True, print intermediate values for debugging
Returns:
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
"""
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
if debug:
print(" [DEBUG] LatentUpsampler forward pass:")
debug_stats("Input (channels first)", latent)
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
x = mx.transpose(latent, (0, 2, 3, 4, 1))
if debug:
debug_stats("After transpose to channels-last", x)
# Initial conv
x = self.initial_conv(x)
if debug:
debug_stats("After initial_conv", x)
x = self.initial_norm(x)
if debug:
debug_stats("After initial_norm", x)
x = nn.silu(x)
if debug:
debug_stats("After silu", x)
# Pre-upsample blocks
for i in sorted(self.res_blocks.keys()):
x = self.res_blocks[i](x)
if debug:
debug_stats(f"After res_blocks[{i}]", x)
# Upsample (2D spatial, frame-by-frame)
x = self.upsampler(x)
if debug:
debug_stats("After upsampler (spatial 2x)", x)
# Post-upsample blocks
for i in sorted(self.post_upsample_res_blocks.keys()):
x = self.post_upsample_res_blocks[i](x)
if debug:
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
# Final conv
x = self.final_conv(x)
if debug:
debug_stats("After final_conv", x)
# Convert back to channels first (B, C, F, H, W)
x = mx.transpose(x, (0, 4, 1, 2, 3))
if debug:
debug_stats("Output (channels first)", x)
return x
def upsample_latents(
latent: mx.array,
upsampler: LatentUpsampler,
latent_mean: mx.array,
latent_std: mx.array,
debug: bool = False,
) -> mx.array:
# Un-normalize: latent * std + mean
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
latent = latent * latent_std + latent_mean
# Upsample
latent = upsampler(latent, debug=debug)
# Re-normalize: (latent - mean) / std
latent = (latent - latent_mean) / latent_std
return latent
def load_upsampler(weights_path: str) -> LatentUpsampler:
"""Load upsampler from safetensors weights.
Args:
weights_path: Path to upsampler weights file
Returns:
Loaded LatentUpsampler model
"""
print(f"Loading spatial upsampler from {weights_path}...")
raw_weights = mx.load(weights_path)
# Check weight shapes to determine mid_channels
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
sample_key = "res_blocks.0.conv1.weight"
if sample_key in raw_weights:
mid_channels = raw_weights[sample_key].shape[0]
else:
mid_channels = 1024 # default
print(f" Detected mid_channels: {mid_channels}")
# Create model
upsampler = LatentUpsampler(
in_channels=128,
mid_channels=mid_channels,
num_blocks_per_stage=4,
)
# Sanitize weights - convert from PyTorch to MLX format
sanitized = {}
for key, value in raw_weights.items():
new_key = key
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
if key.startswith("upsampler."):
new_key = key # Keep as is for SpatialRationalResampler
sanitized[new_key] = value
# Load weights
upsampler.load_weights(list(sanitized.items()), strict=False)
print(f" Loaded {len(sanitized)} weights")
return upsampler

View File

@@ -0,0 +1 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder

View File

@@ -0,0 +1,294 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
class PaddingModeType(Enum):
ZEROS = "zeros"
REFLECT = "reflect"
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
"""Apply reflect padding to spatial dimensions of a 5D tensor.
Args:
x: Input tensor of shape (B, D, H, W, C) - channels last
pad_h: Padding for height dimension
pad_w: Padding for width dimension
Returns:
Padded tensor
"""
if pad_h == 0 and pad_w == 0:
return x
# Height padding (axis 2)
if pad_h > 0:
# Get reflection indices - exclude boundary
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
# Width padding (axis 3)
if pad_w > 0:
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
x = mx.concatenate([left_pad, x, right_pad], axis=3)
return x
def make_conv_nd(
dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
stride: Union[int, Tuple[int, ...]] = 1,
padding: Union[int, Tuple[int, ...], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
if dims == 2:
return CausalConv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
causal=causal,
spatial_padding_mode=spatial_padding_mode,
)
elif dims == 3:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
causal=causal,
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unsupported number of dimensions: {dims}")
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.causal = causal
self.spatial_padding_mode = spatial_padding_mode
# Normalize kernel_size and stride to tuples
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
self.kernel_size = kernel_size
self.stride = stride
self.time_kernel_size = kernel_size[0]
# Calculate spatial padding (temporal is handled separately via frame replication)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
self.spatial_padding = (height_pad, width_pad)
# Create the base convolution (without padding, we'll handle it manually)
self.conv = nn.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0, # We handle padding manually
bias=True,
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
use_causal = causal if causal is not None else self.causal
# Apply temporal padding via frame replication
# Only apply if kernel_size > 1
if self.time_kernel_size > 1:
if use_causal:
# Causal: replicate first frame kernel_size-1 times at the beginning
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
x = mx.concatenate([first_frame_pad, x], axis=2)
else:
# Non-causal: replicate first frame at start, last frame at end
pad_size = (self.time_kernel_size - 1) // 2
if pad_size > 0:
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
x = mx.transpose(x, (0, 2, 3, 4, 1))
# Apply spatial padding
pad_h, pad_w = self.spatial_padding
if pad_h > 0 or pad_w > 0:
if self.spatial_padding_mode == PaddingModeType.REFLECT:
# Use reflect padding for spatial dimensions
x = reflect_pad_2d(x, pad_h, pad_w)
else:
# Use zero padding for spatial dimensions
pad_width = [
(0, 0), # Batch
(0, 0), # D (temporal - already padded)
(pad_h, pad_h), # H
(pad_w, pad_w), # W
(0, 0), # C
]
x = mx.pad(x, pad_width)
# Apply convolution with chunking for large tensors
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
x = self._chunked_conv3d(x)
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
x = mx.transpose(x, (0, 4, 1, 2, 3))
return x
def _chunked_conv3d(self, x: mx.array) -> mx.array:
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
Args:
x: Input tensor of shape (B, D, H, W, C) in channels-last format
Returns:
Output tensor after conv3d
"""
b, d, h, w, c = x.shape
total_elements = d * h * w * c
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
if total_elements <= max_safe_elements:
return self.conv(x)
elements_per_frame = h * w * c
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
kernel_t = self.time_kernel_size
overlap = kernel_t - 1
expected_output_frames = d - overlap
outputs = []
out_idx = 0
# Process chunks
in_start = 0
while out_idx < expected_output_frames:
remaining = expected_output_frames - out_idx
out_frames_this_chunk = min(chunk_size, remaining)
in_frames_needed = out_frames_this_chunk + overlap
in_end = min(in_start + in_frames_needed, d)
chunk = x[:, in_start:in_end, :, :, :]
chunk_out = self.conv(chunk)
mx.eval(chunk_out)
outputs.append(chunk_out)
out_idx += chunk_out.shape[1]
in_start += chunk_out.shape[1]
# Concatenate all chunks
if len(outputs) == 1:
return outputs[0]
return mx.concatenate(outputs, axis=1)
class CausalConv2d(nn.Module):
"""2D convolution with optional causal padding."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int], str] = 0,
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize CausalConv2d."""
super().__init__()
self.causal = causal
self.spatial_padding_mode = spatial_padding_mode
# Normalize kernel_size and stride
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride)
self.kernel_size = kernel_size
self.stride = stride
# Calculate padding
if isinstance(padding, str) and padding == "same":
self.padding = (
(kernel_size[0] - 1) // 2,
(kernel_size[1] - 1) // 2,
)
elif isinstance(padding, int):
self.padding = (padding, padding)
else:
self.padding = padding
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
bias=True,
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
"""Forward pass."""
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
x = mx.transpose(x, (0, 2, 3, 1))
# Apply padding
pad_h, pad_w = self.padding
if pad_h != 0 or pad_w != 0:
pad_width = [
(0, 0), # Batch
(pad_h, pad_h), # H
(pad_w, pad_w), # W
(0, 0), # C
]
x = mx.pad(x, pad_width)
x = self.conv(x)
# Transpose back: (B, H, W, C) -> (B, C, H, W)
x = mx.transpose(x, (0, 3, 1, 2))
return x

View File

@@ -0,0 +1,524 @@
"""Video VAE Decoder for LTX-2 with timestep conditioning.
Architecture (from PyTorch weights):
- conv_in: 128 -> 1024
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
- pixel_norm + timestep modulation (last_scale_shift_table)
- conv_out: 128 -> 48
- unpatchify: 48 -> 3 with patch_size=4
"""
import math
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
def get_timestep_embedding(
timesteps: mx.array,
embedding_dim: int,
flip_sin_to_cos: bool = True,
downscale_freq_shift: float = 0,
scale: float = 1,
max_period: int = 10000,
) -> mx.array:
"""Create sinusoidal timestep embeddings."""
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = mx.exp(exponent)
emb = timesteps[:, None].astype(mx.float32) * emb[None, :]
emb = scale * emb
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
if flip_sin_to_cos:
emb = mx.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
if embedding_dim % 2 == 1:
emb = mx.pad(emb, [(0, 0), (0, 1)])
return emb
class TimestepEmbedding(nn.Module):
"""MLP for timestep embedding."""
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
self.act = nn.SiLU()
def __call__(self, sample: mx.array) -> mx.array:
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class PixArtAlphaTimestepEmbedder(nn.Module):
"""Combined timestep embedding (sinusoidal + MLP)."""
def __init__(self, embedding_dim: int):
super().__init__()
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim
)
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
timesteps_proj = get_timestep_embedding(
timestep,
embedding_dim=256,
flip_sin_to_cos=True,
downscale_freq_shift=0
)
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
return timesteps_emb
class ResnetBlock3DSimple(nn.Module):
"""ResNet block with optional timestep conditioning.
Weight keys: conv1.conv, conv2.conv, scale_shift_table
"""
def __init__(
self,
channels: int,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = False,
):
super().__init__()
self.timestep_conditioning = timestep_conditioning
# Nested conv structure to match PyTorch naming: conv1.conv.weight
self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
self.act = nn.SiLU()
# Scale-shift table for timestep conditioning: [shift1, scale1, shift2, scale2]
if timestep_conditioning:
self.scale_shift_table = mx.zeros((4, channels))
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_ch,
out_channels=out_ch,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
def __call__(
self,
x: mx.array,
causal: bool = False,
timestep_embed: Optional[mx.array] = None,
) -> mx.array:
residual = x
batch_size = x.shape[0]
# Block 1 with optional timestep conditioning
x = self.pixel_norm(x)
if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
ada_values = ada_values + ts_reshaped
shift1 = ada_values[:, 0] # (B, C, 1, 1, 1)
scale1 = ada_values[:, 1]
shift2 = ada_values[:, 2]
scale2 = ada_values[:, 3]
x = x * (1 + scale1) + shift1
x = self.act(x)
x = self.conv1(x, causal=causal)
# Block 2 with optional timestep conditioning
x = self.pixel_norm(x)
if self.timestep_conditioning and timestep_embed is not None:
x = x * (1 + scale2) + shift2
x = self.act(x)
x = self.conv2(x, causal=causal)
return x + residual
class ResBlockGroup(nn.Module):
"""Group of ResNet blocks with shared timestep embedding.
PyTorch naming: res_blocks.0, res_blocks.1, etc.
"""
def __init__(
self,
channels: int,
num_layers: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = False,
):
super().__init__()
self.timestep_conditioning = timestep_conditioning
# Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=channels * 4
)
self.res_blocks = [
ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
)
for _ in range(num_layers)
]
def __call__(
self,
x: mx.array,
causal: bool = False,
timestep: Optional[mx.array] = None,
) -> mx.array:
timestep_embed = None
if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0]
timestep_embed = self.time_embedder(
timestep.flatten(),
hidden_dtype=x.dtype
)
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
for res_block in self.res_blocks:
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
return x
class LTX2VideoDecoder(nn.Module):
"""LTX-2 Video VAE Decoder with timestep conditioning.
Architecture:
- conv_in: 128 -> 1024
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
- up_blocks.1: Upsampler 1024 -> 512
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
- up_blocks.3: Upsampler 512 -> 256
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
- up_blocks.5: Upsampler 256 -> 128
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
"""
def __init__(
self,
in_channels: int = 128,
out_channels: int = 3,
patch_size: int = 4,
num_layers_per_block: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.timestep_conditioning = timestep_conditioning
# Decode parameters (configurable via constructor)
self.decode_noise_scale = 0.025 # Set to 0.0 to disable noise
self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights)
self.latents_mean = mx.zeros((in_channels,))
self.latents_std = mx.ones((in_channels,))
# Initial conv: 128 -> 1024
class ConvInWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_channels,
out_channels=1024,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
self.up_blocks = [
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
dims=3,
in_channels=1024,
stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config!
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
dims=3,
in_channels=512,
stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config!
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config!
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
]
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=128,
out_channels=final_out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper()
self.act = nn.SiLU()
if timestep_conditioning:
self.timestep_scale_multiplier = mx.array(1000.0)
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=128 * 2 # 256, matches (2, 128) table
)
self.last_scale_shift_table = mx.zeros((2, 128))
def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics."""
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
std = self.latents_std.reshape(1, -1, 1, 1, 1)
return x * std + mean
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
def __call__(
self,
sample: mx.array,
causal: bool = False,
timestep: Optional[mx.array] = None,
debug: bool = False,
) -> mx.array:
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
batch_size = sample.shape[0]
if debug:
debug_stats("Input", sample)
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
if debug:
debug_stats("After noise", sample)
if debug:
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
sample = self.denormalize(sample)
if debug:
debug_stats("After denormalize", sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
scaled_timestep = None
if self.timestep_conditioning and timestep is not None:
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
if debug:
debug_stats("After conv_in", x)
for i, block in enumerate(self.up_blocks):
if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep)
else:
x = block(x, causal=causal)
if debug:
block_type = type(block).__name__
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
x = self.pixel_norm(x)
if debug:
debug_stats("After pixel_norm", x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(),
hidden_dtype=x.dtype
)
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
ada_values = ada_values + ts_reshaped
shift = ada_values[:, 0] # (B, 128, 1, 1, 1)
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
if debug:
debug_stats("After timestep modulation", x)
x = self.act(x)
if debug:
debug_stats("After activation", x)
x = self.conv_out(x, causal=causal)
if debug:
debug_stats("After conv_out", x)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
if debug:
debug_stats("After unpatchify", x)
return x
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
from pathlib import Path
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.decoder."
stats_prefix = "vae.per_channel_statistics."
elif has_decoder_prefix:
prefix = "decoder."
stats_prefix = ""
else:
prefix = ""
stats_prefix = ""
# Load per-channel statistics for denormalization
# Note: use std-of-means (not mean-of-stds) for proper denormalization
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
if mean_key in weights:
decoder.latents_mean = weights[mean_key]
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
if std_key in weights:
decoder.latents_std = weights[std_key]
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
# Build decoder weights dict with key remapping
decoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
decoder_weights[new_key] = value
print(f" Found {len(decoder_weights)} decoder weights")
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
print(f" Found {len(ts_keys)} timestep conditioning weights")
# Load weights
decoder.load_weights(list(decoder_weights.items()), strict=False)
print("VAE decoder loaded successfully")
return decoder

View File

@@ -0,0 +1,120 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert video to patches.
Moves spatial pixels from H, W dimensions to channel dimension.
Args:
x: Input tensor of shape (B, C, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
"""
b, c, f, h, w = x.shape
# Check dimensions are divisible
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
assert f % patch_size_t == 0
# New dimensions
new_h = h // patch_size_hw
new_w = w // patch_size_hw
new_f = f // patch_size_t
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
return x
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
"""Convert patches back to video.
Inverse of patchify - moves pixels from channel dimension back to spatial.
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
Args:
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
patch_size_hw: Spatial patch size
patch_size_t: Temporal patch size
Returns:
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
"""
b, c_packed, f, h, w = x.shape
# Calculate original channel count
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
# Permute to interleave patches with spatial dims:
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
return x
class PerChannelStatistics(nn.Module):
def __init__(self, latent_channels: int = 128):
super().__init__()
self.latent_channels = latent_channels
# Learnable per-channel mean and std
self.mean = mx.zeros((latent_channels,))
self.std = mx.ones((latent_channels,))
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latents using per-channel statistics.
Args:
x: Input tensor of shape (B, C, ...)
Returns:
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
return (x - mean) / std
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics.
Args:
x: Normalized tensor of shape (B, C, ...)
Returns:
Denormalized tensor
"""
mean = self.mean.reshape(1, -1, 1, 1, 1)
std = self.std.reshape(1, -1, 1, 1, 1)
return x * std + mean

View File

@@ -0,0 +1,171 @@
"""ResNet blocks for Video VAE."""
from enum import Enum
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.utils import PixelNorm
class NormLayerType(Enum):
GROUP_NORM = "group_norm"
PIXEL_NORM = "pixel_norm"
def get_norm_layer(
norm_type: NormLayerType,
num_channels: int,
num_groups: int = 32,
eps: float = 1e-6,
) -> nn.Module:
if norm_type == NormLayerType.GROUP_NORM:
return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
elif norm_type == NormLayerType.PIXEL_NORM:
return PixelNorm(eps=eps)
else:
raise ValueError(f"Unknown norm type: {norm_type}")
class ResnetBlock3D(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
out_channels: Optional[int] = None,
eps: float = 1e-6,
groups: int = 32,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.inject_noise = inject_noise
# First normalization and convolution
self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
self.conv1 = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
# Second normalization and convolution
self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
self.conv2 = CausalConv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
# Shortcut connection if channels change
if in_channels != out_channels:
self.shortcut = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
spatial_padding_mode=spatial_padding_mode,
)
else:
self.shortcut = None
# Activation
self.act = nn.SiLU()
def __call__(
self,
x: mx.array,
causal: bool = True,
generator: Optional[int] = None,
) -> mx.array:
residual = x
# First block
x = self.norm1(x)
x = self.act(x)
x = self.conv1(x, causal=causal)
# Inject noise if enabled
if self.inject_noise and generator is not None:
noise = mx.random.normal(x.shape)
x = x + noise * 0.01
# Second block
x = self.norm2(x)
x = self.act(x)
x = self.conv2(x, causal=causal)
# Shortcut
if self.shortcut is not None:
residual = self.shortcut(residual, causal=causal)
return x + residual
class UNetMidBlock3D(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
inject_noise: bool = False,
timestep_conditioning: bool = False,
attention_head_dim: Optional[int] = None,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.num_layers = num_layers
# Create ResNet blocks
self.resnets = [
ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
def __call__(
self,
x: mx.array,
causal: bool = True,
timestep: Optional[mx.array] = None,
generator: Optional[int] = None,
) -> mx.array:
for resnet in self.resnets:
x = resnet(x, causal=causal, generator=generator)
return x

View File

@@ -0,0 +1,173 @@
"""Sampling operations for Video VAE (upsampling/downsampling)."""
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
class SpaceToDepthDownsample(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]],
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
stride = (stride, stride, stride)
self.stride = stride
self.dims = dims
# Calculate the multiplier for channels
multiplier = stride[0] * stride[1] * stride[2]
intermediate_channels = in_channels * multiplier
# 1x1x1 convolution to adjust channels
self.conv = CausalConv3d(
in_channels=intermediate_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Pad if necessary to make dimensions divisible by stride
pad_d = (st - d % st) % st
pad_h = (sh - h % sh) % sh
pad_w = (sw - w % sw) % sw
if pad_d > 0 or pad_h > 0 or pad_w > 0:
# For causal, pad at the end of temporal dimension
if causal:
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
else:
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
b, c, d, h, w = x.shape
# Reshape to group spatial elements
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Permute to move stride elements to channel dim
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
# Apply 1x1 conv to adjust channels
x = self.conv(x, causal=causal)
return x
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims: int,
in_channels: int,
stride: Union[int, Tuple[int, int, int]],
residual: bool = False,
out_channels_reduction_factor: int = 1,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
stride = (stride, stride, stride)
self.stride = stride
self.dims = dims
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
# Calculate output channels
multiplier = stride[0] * stride[1] * stride[2]
out_channels = in_channels // out_channels_reduction_factor
self.out_channels = out_channels
# 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
self.conv = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels * multiplier,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def _depth_to_space(self, x: mx.array) -> mx.array:
b, c_packed, d, h, w = x.shape
st, sh, sw = self.stride
c = c_packed // (st * sh * sw)
# (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
x = mx.reshape(x, (b, c, st, sh, sw, d, h, w))
# (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
# (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
x = mx.reshape(x, (b, c, d * st, h * sh, w * sw))
return x
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Compute residual path if enabled
x_residual = None
if self.residual:
# Reshape input: treat channels as spatial factors
# "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
x_residual = self._depth_to_space(x)
# Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
# num_repeat = prod(stride) / out_channels_reduction_factor
num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
# Remove first temporal frame if temporal upsampling
if st > 1:
x_residual = x_residual[:, :, 1:, :, :]
# Apply conv
x = self.conv(x, causal=causal)
# Depth to space rearrangement
x = self._depth_to_space(x)
# Remove first frame for causal temporal upsampling
if st > 1:
x = x[:, :, 1:, :, :]
# Add residual
if self.residual and x_residual is not None:
x = x + x_residual
return x

View File

@@ -0,0 +1,528 @@
"""Video VAE Encoder and Decoder for LTX-2."""
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx.video_vae.sampling import (
DepthToSpaceUpsample,
SpaceToDepthDownsample,
)
from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
"""Log variance mode for VAE."""
PER_CHANNEL = "per_channel"
UNIFORM = "uniform"
CONSTANT = "constant"
NONE = "none"
def _make_encoder_block(
block_name: str,
block_config: Dict[str, Any],
in_channels: int,
convolution_dimensions: int,
norm_layer: NormLayerType,
norm_num_groups: int,
spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
"""Create an encoder block.
Args:
block_name: Type of block
block_config: Block configuration
in_channels: Input channels
convolution_dimensions: Number of dimensions
norm_layer: Normalization layer type
norm_num_groups: Number of groups for group norm
spatial_padding_mode: Padding mode
Returns:
Tuple of (block, output_channels)
"""
out_channels = in_channels
if block_name == "res_x":
block = UNetMidBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
num_layers=block_config["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
out_channels = in_channels * block_config.get("multiplier", 2)
block = ResnetBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 1, 1),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(1, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
out_channels = in_channels * block_config.get("multiplier", 2)
block = CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=(2, 2, 2),
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
out_channels = in_channels * block_config.get("multiplier", 2)
block = SpaceToDepthDownsample(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unknown encoder block: {block_name}")
return block, out_channels
def _make_decoder_block(
block_name: str,
block_config: Dict[str, Any],
in_channels: int,
convolution_dimensions: int,
norm_layer: NormLayerType,
timestep_conditioning: bool,
norm_num_groups: int,
spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
"""Create a decoder block."""
out_channels = in_channels
if block_name == "res_x":
block = UNetMidBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
num_layers=block_config["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_config.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
out_channels = in_channels // block_config.get("multiplier", 2)
block = ResnetBlock3D(
dims=convolution_dimensions,
in_channels=in_channels,
out_channels=out_channels,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_config.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
out_channels = in_channels // block_config.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=convolution_dimensions,
in_channels=in_channels,
stride=(2, 2, 2),
residual=block_config.get("residual", False),
out_channels_reduction_factor=block_config.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unknown decoder block: {block_name}")
return block, out_channels
class VideoEncoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 3,
out_channels: int = 128,
encoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize VideoEncoder.
Args:
convolution_dimensions: Number of dimensions (3 for video)
in_channels: Input channels (3 for RGB)
out_channels: Output latent channels
encoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
latent_log_var: Log variance mode
encoder_spatial_padding_mode: Padding mode
"""
super().__init__()
if encoder_blocks is None:
encoder_blocks = []
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
# After patchify, channels increase by patch_size^2
in_channels = in_channels * patch_size ** 2
feature_channels = out_channels
# Initial convolution
self.conv_in = CausalConv3d(
in_channels=in_channels,
out_channels=feature_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks
self.down_blocks = []
for block_name, block_params in encoder_blocks:
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks.append(block)
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
# Calculate output convolution channels
conv_out_channels = out_channels
if latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
in_channels=feature_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=encoder_spatial_padding_mode,
)
def __call__(self, sample: mx.array) -> mx.array:
"""Encode video to latent representation.
Args:
sample: Input video of shape (B, C, F, H, W).
F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)
Returns:
Normalized latent means of shape (B, 128, F', H', W')
"""
# Validate frame count
frames_count = sample.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError(
"Invalid number of frames: Encode input must have 1 + 8 * x frames "
f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
)
# Initial patchify
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks:
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else:
sample = down_block(sample, causal=True)
# Output processing
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=True)
# Handle log variance modes
if self.latent_log_var == LogVarianceType.UNIFORM:
means = sample[:, :-1, ...]
logvar = sample[:, -1:, ...]
num_channels = means.shape[1]
repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
sample = mx.concatenate([means, repeated_logvar], axis=1)
elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...]
approx_ln_0 = -30
sample = mx.concatenate([
sample,
mx.full_like(sample, approx_ln_0),
], axis=1)
# Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
class VideoDecoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 128,
out_channels: int = 3,
decoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
causal: bool = False,
timestep_conditioning: bool = False,
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
):
"""Initialize VideoDecoder.
Args:
convolution_dimensions: Number of dimensions
in_channels: Input latent channels
out_channels: Output channels (3 for RGB)
decoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
causal: Whether to use causal convolutions
timestep_conditioning: Whether to use timestep conditioning
decoder_spatial_padding_mode: Padding mode
"""
super().__init__()
if decoder_blocks is None:
decoder_blocks = []
self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
# Per-channel statistics for denormalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Noise and timestep parameters
self.decode_noise_scale = 0.025
self.decode_timestep = 0.05
# Compute initial feature channels
feature_channels = in_channels
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
feature_channels = feature_channels * block_config.get("multiplier", 2)
if block_name == "compress_all":
feature_channels = feature_channels * block_config.get("multiplier", 1)
# Initial convolution
self.conv_in = CausalConv3d(
in_channels=in_channels,
out_channels=feature_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
# Build decoder blocks (reversed order)
self.up_blocks = []
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_decoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
timestep_conditioning=timestep_conditioning,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks.append(block)
# Output normalization
if norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(
in_channels=feature_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
spatial_padding_mode=decoder_spatial_padding_mode,
)
def __call__(
self,
sample: mx.array,
timestep: Optional[mx.array] = None,
) -> mx.array:
"""Decode latent to video.
Args:
sample: Latent tensor of shape (B, 128, F', H', W')
timestep: Optional timestep for conditioning
Returns:
Decoded video of shape (B, 3, F, H, W)
"""
batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
# Denormalize latents
sample = self.per_channel_statistics.un_normalize(sample)
# Use default timestep if not provided
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
# Initial convolution
sample = self.conv_in(sample, causal=self.causal)
# Process through decoder blocks
for up_block in self.up_blocks:
if isinstance(up_block, UNetMidBlock3D):
sample = up_block(sample, causal=self.causal)
elif isinstance(up_block, ResnetBlock3D):
sample = up_block(sample, causal=self.causal)
else:
sample = up_block(sample, causal=self.causal)
# Output processing
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
# Unpatchify to restore spatial resolution
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample

165
mlx_video/postprocess.py Normal file
View File

@@ -0,0 +1,165 @@
import numpy as np
from typing import Optional
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
Args:
image: Input image as uint8 numpy array (H, W, C)
d: Diameter of each pixel neighborhood
sigma_color: Filter sigma in the color space
sigma_space: Filter sigma in the coordinate space
Returns:
Filtered image
"""
try:
import cv2
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
except ImportError:
# Fallback to simple Gaussian blur if cv2 not available
return gaussian_blur(image, kernel_size=3)
def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
"""Apply Gaussian blur.
Args:
image: Input image as uint8 numpy array (H, W, C)
kernel_size: Size of the Gaussian kernel (must be odd)
Returns:
Blurred image
"""
try:
import cv2
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
except ImportError:
# Simple box blur fallback
from scipy.ndimage import uniform_filter
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
"""Apply unsharp masking to enhance edges after blur.
Args:
image: Input image as uint8 numpy array
kernel_size: Size of the Gaussian kernel
sigma: Gaussian sigma
amount: Strength of sharpening
Returns:
Sharpened image
"""
try:
import cv2
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
return np.clip(sharpened, 0, 255).astype(np.uint8)
except ImportError:
return image
def reduce_grid_artifacts(
video: np.ndarray,
method: str = "bilateral",
strength: float = 1.0,
) -> np.ndarray:
"""Reduce grid artifacts in video frames.
Args:
video: Video as numpy array (F, H, W, C) uint8
method: "bilateral", "gaussian", or "frequency"
strength: How strong to apply the filter (0-1)
Returns:
Processed video
"""
if method == "bilateral":
d = max(3, int(5 * strength))
sigma = 50 + 50 * strength
processed = np.stack([
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
for frame in video
])
elif method == "gaussian":
kernel_size = max(3, int(3 + 4 * strength))
if kernel_size % 2 == 0:
kernel_size += 1
processed = np.stack([
gaussian_blur(frame, kernel_size=kernel_size)
for frame in video
])
elif method == "frequency":
processed = np.stack([
remove_grid_frequency(frame, grid_size=8)
for frame in video
])
else:
raise ValueError(f"Unknown method: {method}")
# Optionally sharpen to recover some detail
if strength < 1.0:
# Blend with original based on strength
alpha = strength
processed = (alpha * processed + (1 - alpha) * video).astype(np.uint8)
return processed
def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
"""Remove grid-frequency components using FFT.
Args:
frame: Input frame (H, W, C) uint8
grid_size: Expected grid periodicity in pixels
Returns:
Filtered frame
"""
result = np.zeros_like(frame)
for c in range(frame.shape[2]):
channel = frame[:, :, c].astype(np.float32)
h, w = channel.shape
# FFT
fft = np.fft.fft2(channel)
fft_shifted = np.fft.fftshift(fft)
# Create notch filter at grid frequencies
cy, cx = h // 2, w // 2
mask = np.ones((h, w), dtype=np.float32)
# Attenuate frequencies at grid periodicity
freq_y = h // grid_size
freq_x = w // grid_size
for fy in range(-2, 3):
for fx in range(-2, 3):
if fy == 0 and fx == 0:
continue
y_pos = cy + fy * freq_y
x_pos = cx + fx * freq_x
if 0 <= y_pos < h and 0 <= x_pos < w:
# Gaussian attenuation around the frequency
for dy in range(-2, 3):
for dx in range(-2, 3):
yy, xx = y_pos + dy, x_pos + dx
if 0 <= yy < h and 0 <= xx < w:
dist = np.sqrt(dy**2 + dx**2)
mask[yy, xx] *= min(1.0, dist / 3.0)
# Apply mask and inverse FFT
fft_filtered = fft_shifted * mask
channel_filtered = np.fft.ifft2(np.fft.ifftshift(fft_filtered)).real
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
return result

View File

@@ -0,0 +1,26 @@
import mlx.core as mx
import mlx.nn as nn
class PixArtAlphaTextProjection(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear1(x)
x = self.act(x)
x = self.linear2(x)
return x

127
mlx_video/utils.py Normal file
View File

@@ -0,0 +1,127 @@
"""Utility functions for MLX Video."""
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from functools import partial
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
@partial(mx.compile, shapeless=True)
def to_denoised(
noisy: mx.array,
velocity: mx.array,
sigma: mx.array | float
) -> mx.array:
"""Convert velocity prediction to denoised output.
Given noisy input x_t and velocity prediction v, compute denoised x_0:
x_0 = x_t - sigma * v
Args:
noisy: Noisy input tensor x_t
velocity: Velocity prediction v
sigma: Noise level (scalar or per-sample)
Returns:
Denoised tensor x_0
"""
if isinstance(sigma, (int, float)):
return noisy - sigma * velocity
else:
# sigma is per-sample
while sigma.ndim < velocity.ndim:
sigma = mx.expand_dims(sigma, axis=-1)
return noisy - sigma * velocity
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array:
"""Repeat elements of tensor along an axis, similar to torch.repeat_interleave.
Args:
x: Input tensor
repeats: Number of repetitions for each element
axis: The axis along which to repeat values
Returns:
Tensor with repeated values
"""
# Handle negative axis
if axis < 0:
axis = x.ndim + axis
# Get shape
shape = list(x.shape)
# Expand dims, repeat, then reshape
x = mx.expand_dims(x, axis=axis + 1)
# Create tile pattern
tile_pattern = [1] * x.ndim
tile_pattern[axis + 1] = repeats
x = mx.tile(x, tile_pattern)
# Reshape to merge the repeated dimension
new_shape = shape.copy()
new_shape[axis] *= repeats
return mx.reshape(x, new_shape)
class PixelNorm(nn.Module):
def __init__(self, eps: float = 1e-6):
super().__init__()
self.eps = eps
def __call__(self, x: mx.array) -> mx.array:
return x / mx.sqrt(mx.mean(x * x, axis=1, keepdims=True) + self.eps)
def get_timestep_embedding(
timesteps: mx.array,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1.0,
scale: float = 1.0,
max_period: int = 10000,
) -> mx.array:
"""Create sinusoidal timestep embeddings.
Args:
timesteps: 1D tensor of timesteps
embedding_dim: Dimension of the embeddings to create
flip_sin_to_cos: If True, flip sin and cos ordering
downscale_freq_shift: Frequency shift factor
scale: Scale factor for timesteps
max_period: Maximum period for the sinusoids
Returns:
Tensor of shape (len(timesteps), embedding_dim)
"""
assert timesteps.ndim == 1, "Timesteps should be 1D"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = mx.exp(exponent)
emb = (timesteps[:, None].astype(mx.float32) * scale) * emb[None, :]
# Compute sin and cos embeddings
if flip_sin_to_cos:
emb = mx.concatenate([mx.cos(emb), mx.sin(emb)], axis=-1)
else:
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
# Zero pad if odd embedding dimension
if embedding_dim % 2 == 1:
emb = mx.pad(emb, [(0, 0), (0, 1)])
return emb

26
pyproject.toml Normal file
View File

@@ -0,0 +1,26 @@
[project]
name = "mlx-video"
version = "0.0.1"
description = "MLX-Video is the best package for inference and finetuning of Image-Video-Audio generation models on your Mac using MLX."
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"mlx>=0.22.0",
"numpy",
"safetensors",
"huggingface_hub",
"tqdm",
]
license = {text="MIT"}
authors = [
{name = "Prince Canuma", email = "prince.gdt@gmail.com"}
]
[project.optional-dependencies]
dev = [
"pytest",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

479
uv.lock generated Normal file
View File

@@ -0,0 +1,479 @@
version = 1
revision = 3
requires-python = ">=3.11"
[[package]]
name = "anyio"
version = "4.12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "idna" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" },
]
[[package]]
name = "certifi"
version = "2026.1.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e0/2d/a891ca51311197f6ad14a7ef42e2399f36cf2f9bd44752b3dc4eab60fdc5/certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120", size = 154268, upload-time = "2026-01-04T02:42:41.825Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c", size = 152900, upload-time = "2026-01-04T02:42:40.15Z" },
]
[[package]]
name = "click"
version = "8.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" },
]
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "filelock"
version = "3.20.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c1/e0/a75dbe4bca1e7d41307323dad5ea2efdd95408f74ab2de8bd7dba9b51a1a/filelock-3.20.2.tar.gz", hash = "sha256:a2241ff4ddde2a7cebddf78e39832509cb045d18ec1a09d7248d6bfc6bfbbe64", size = 19510, upload-time = "2026-01-02T15:33:32.582Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9a/30/ab407e2ec752aa541704ed8f93c11e2a5d92c168b8a755d818b74a3c5c2d/filelock-3.20.2-py3-none-any.whl", hash = "sha256:fbba7237d6ea277175a32c54bb71ef814a8546d8601269e1bfc388de333974e8", size = 16697, upload-time = "2026-01-02T15:33:31.133Z" },
]
[[package]]
name = "fsspec"
version = "2025.12.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b6/27/954057b0d1f53f086f681755207dda6de6c660ce133c829158e8e8fe7895/fsspec-2025.12.0.tar.gz", hash = "sha256:c505de011584597b1060ff778bb664c1bc022e87921b0e4f10cc9c44f9635973", size = 309748, upload-time = "2025-12-03T15:23:42.687Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl", hash = "sha256:8bf1fe301b7d8acfa6e8571e3b1c3d158f909666642431cc78a1b7b4dbc5ec5b", size = 201422, upload-time = "2025-12-03T15:23:41.434Z" },
]
[[package]]
name = "h11"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
]
[[package]]
name = "hf-xet"
version = "1.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, upload-time = "2025-10-24T19:04:32.129Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/a5/85ef910a0aa034a2abcfadc360ab5ac6f6bc4e9112349bd40ca97551cff0/hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649", size = 2861870, upload-time = "2025-10-24T19:04:11.422Z" },
{ url = "https://files.pythonhosted.org/packages/ea/40/e2e0a7eb9a51fe8828ba2d47fe22a7e74914ea8a0db68a18c3aa7449c767/hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813", size = 2717584, upload-time = "2025-10-24T19:04:09.586Z" },
{ url = "https://files.pythonhosted.org/packages/a5/7d/daf7f8bc4594fdd59a8a596f9e3886133fdc68e675292218a5e4c1b7e834/hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc", size = 3315004, upload-time = "2025-10-24T19:04:00.314Z" },
{ url = "https://files.pythonhosted.org/packages/b1/ba/45ea2f605fbf6d81c8b21e4d970b168b18a53515923010c312c06cd83164/hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5", size = 3222636, upload-time = "2025-10-24T19:03:58.111Z" },
{ url = "https://files.pythonhosted.org/packages/4a/1d/04513e3cab8f29ab8c109d309ddd21a2705afab9d52f2ba1151e0c14f086/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f", size = 3408448, upload-time = "2025-10-24T19:04:20.951Z" },
{ url = "https://files.pythonhosted.org/packages/f0/7c/60a2756d7feec7387db3a1176c632357632fbe7849fce576c5559d4520c7/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832", size = 3503401, upload-time = "2025-10-24T19:04:22.549Z" },
{ url = "https://files.pythonhosted.org/packages/4e/64/48fffbd67fb418ab07451e4ce641a70de1c40c10a13e25325e24858ebe5a/hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382", size = 2900866, upload-time = "2025-10-24T19:04:33.461Z" },
{ url = "https://files.pythonhosted.org/packages/e2/51/f7e2caae42f80af886db414d4e9885fac959330509089f97cccb339c6b87/hf_xet-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:10bfab528b968c70e062607f663e21e34e2bba349e8038db546646875495179e", size = 2861861, upload-time = "2025-10-24T19:04:19.01Z" },
{ url = "https://files.pythonhosted.org/packages/6e/1d/a641a88b69994f9371bd347f1dd35e5d1e2e2460a2e350c8d5165fc62005/hf_xet-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2a212e842647b02eb6a911187dc878e79c4aa0aa397e88dd3b26761676e8c1f8", size = 2717699, upload-time = "2025-10-24T19:04:17.306Z" },
{ url = "https://files.pythonhosted.org/packages/df/e0/e5e9bba7d15f0318955f7ec3f4af13f92e773fbb368c0b8008a5acbcb12f/hf_xet-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e06daccb3a7d4c065f34fc26c14c74f4653069bb2b194e7f18f17cbe9939c0", size = 3314885, upload-time = "2025-10-24T19:04:07.642Z" },
{ url = "https://files.pythonhosted.org/packages/21/90/b7fe5ff6f2b7b8cbdf1bd56145f863c90a5807d9758a549bf3d916aa4dec/hf_xet-1.2.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:29c8fc913a529ec0a91867ce3d119ac1aac966e098cf49501800c870328cc090", size = 3221550, upload-time = "2025-10-24T19:04:05.55Z" },
{ url = "https://files.pythonhosted.org/packages/6f/cb/73f276f0a7ce46cc6a6ec7d6c7d61cbfe5f2e107123d9bbd0193c355f106/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e159cbfcfbb29f920db2c09ed8b660eb894640d284f102ada929b6e3dc410a", size = 3408010, upload-time = "2025-10-24T19:04:28.598Z" },
{ url = "https://files.pythonhosted.org/packages/b8/1e/d642a12caa78171f4be64f7cd9c40e3ca5279d055d0873188a58c0f5fbb9/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c91d5ae931510107f148874e9e2de8a16052b6f1b3ca3c1b12f15ccb491390f", size = 3503264, upload-time = "2025-10-24T19:04:30.397Z" },
{ url = "https://files.pythonhosted.org/packages/17/b5/33764714923fa1ff922770f7ed18c2daae034d21ae6e10dbf4347c854154/hf_xet-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:210d577732b519ac6ede149d2f2f34049d44e8622bf14eb3d63bbcd2d4b332dc", size = 2901071, upload-time = "2025-10-24T19:04:37.463Z" },
{ url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" },
{ url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" },
{ url = "https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" },
{ url = "https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" },
{ url = "https://files.pythonhosted.org/packages/0b/dd/7ac658d54b9fb7999a0ccb07ad863b413cbaf5cf172f48ebcd9497ec7263/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737", size = 3413812, upload-time = "2025-10-24T19:04:24.585Z" },
{ url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" },
{ url = "https://files.pythonhosted.org/packages/cb/44/870d44b30e1dcfb6a65932e3e1506c103a8a5aea9103c337e7a53180322c/hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69", size = 2905735, upload-time = "2025-10-24T19:04:35.928Z" },
]
[[package]]
name = "httpcore"
version = "1.0.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi" },
{ name = "h11" },
]
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
]
[[package]]
name = "httpx"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "certifi" },
{ name = "httpcore" },
{ name = "idna" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
]
[[package]]
name = "huggingface-hub"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "fsspec" },
{ name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
{ name = "httpx" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "shellingham" },
{ name = "tqdm" },
{ name = "typer-slim" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dd/dd/1cc985c5dda36298b152f75e82a1c81f52243b78fb7e9cad637a29561ad1/huggingface_hub-1.3.1.tar.gz", hash = "sha256:e80e0cfb4a75557c51ab20d575bdea6bb6106c2f97b7c75d8490642f1efb6df5", size = 622356, upload-time = "2026-01-09T14:08:16.888Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/90/fb/cb8fe5f71d5622427f20bcab9e06a696a5aaf21bfe7bd0a8a0c63c88abf5/huggingface_hub-1.3.1-py3-none-any.whl", hash = "sha256:efbc7f3153cb84e2bb69b62ed90985e21ecc9343d15647a419fc0ee4b85f0ac3", size = 533351, upload-time = "2026-01-09T14:08:14.519Z" },
]
[[package]]
name = "idna"
version = "3.11"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]]
name = "mlx"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/07/14/74acbd677ececd17a44dafda1b472aebacef54f60ff9a41a801f711de9a7/mlx-0.30.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:acfd7d1b8e5b9fa1b7e9fab4cc5ba6a492c559fbb1c5aeab16c1d7a148ab4f1b", size = 593048, upload-time = "2025-12-18T01:55:34.9Z" },
{ url = "https://files.pythonhosted.org/packages/58/8c/5309848afb9c53d363f59b88ae5811de248e2817e91aeadf007e2ac8d22b/mlx-0.30.1-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:b62030471272d1835b8137164bd43d863cc93ff1d67ec4f1f87bb4c8613dd5a6", size = 593043, upload-time = "2025-12-18T01:55:36.839Z" },
{ url = "https://files.pythonhosted.org/packages/e8/5a/0039815a930f0193e2cffb27c57dc6971004bce0086c2bbbdb10395c272c/mlx-0.30.1-cp311-cp311-macosx_26_0_arm64.whl", hash = "sha256:0489cd340f2d262cb3aaad4368e40e84b152e182e4cea37ba018e56c72e1d020", size = 567014, upload-time = "2025-12-18T00:15:51.731Z" },
{ url = "https://files.pythonhosted.org/packages/de/c7/6bdb5497c1f5ed3e33afa7785761ad87fd3436c071805d9a93c905943f04/mlx-0.30.1-cp311-cp311-manylinux_2_35_aarch64.whl", hash = "sha256:fbdcfc3ed556a7e701a8eb67da299e2a25f52615193833ca6374decca3be5bf4", size = 658930, upload-time = "2025-12-18T01:55:38.441Z" },
{ url = "https://files.pythonhosted.org/packages/91/02/2d86a1c116e951eb4d88fe313c321e23628ce7404712e1258cacf925a8b8/mlx-0.30.1-cp311-cp311-manylinux_2_35_x86_64.whl", hash = "sha256:68ec854e7b5f89454e67d6c2fa7bb416b8afb148003ccd775904ec6ec4744818", size = 692484, upload-time = "2025-12-18T01:55:40.254Z" },
{ url = "https://files.pythonhosted.org/packages/3a/4b/ad57b2f0ede3f0d009c0e3e1270c219bd18f9025388855ee149680cffa20/mlx-0.30.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:deaef3ecd2f99930867a29de748e3bffa9cc7e4dfa834f2501c37ed29aece1cc", size = 593397, upload-time = "2025-12-18T01:55:41.814Z" },
{ url = "https://files.pythonhosted.org/packages/ef/14/7fa03a0f66ac3cfb2fd6752178a1488f13c7233fff26eed0f832d961db35/mlx-0.30.1-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:86ccdcda0b5ea4768b87da25beae5b83ac7cc802506116b6845cea6f450e2377", size = 593397, upload-time = "2025-12-18T01:55:43Z" },
{ url = "https://files.pythonhosted.org/packages/9c/c8/9f1343dbe2381f9653df4e0a62dc8bf38f575a2553dc2aa6916de32d2a85/mlx-0.30.1-cp312-cp312-macosx_26_0_arm64.whl", hash = "sha256:a625cb434b2acc5674fe10683374641dab9671fb354ae7c2c67a1fb0405eeb37", size = 567576, upload-time = "2025-12-18T00:15:55.114Z" },
{ url = "https://files.pythonhosted.org/packages/15/ff/485ed9c99c18ef89ac987178c0a526cb4148ba38b14838d315311d9d76a8/mlx-0.30.1-cp312-cp312-manylinux_2_35_aarch64.whl", hash = "sha256:ccc1ff3aca8fb1073c7dcd1274cebe48ae75f852d14b16c7db8228fbbad594dd", size = 643654, upload-time = "2025-12-18T01:55:44.165Z" },
{ url = "https://files.pythonhosted.org/packages/8a/d3/54d3bf5e404c3b6424b49c505dc8b3c06c6bb498fe720195b1fafbd69b5e/mlx-0.30.1-cp312-cp312-manylinux_2_35_x86_64.whl", hash = "sha256:55ed7fc4b389d6e49dac6d34a97b41e61cbe3662ac29c3d29cf612e6b2ed9827", size = 687305, upload-time = "2025-12-18T01:55:45.526Z" },
{ url = "https://files.pythonhosted.org/packages/f9/fd/c6f56cd87d48763ed63655ace627c06db9819eae7d43d132f40d4965947a/mlx-0.30.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743520758bc8261b2ed8f3b3dc96e4e9236769dd8f61fb17877c5e44037e2058", size = 593366, upload-time = "2025-12-18T01:55:46.786Z" },
{ url = "https://files.pythonhosted.org/packages/dc/53/96d8c48b21f91c4216b6d2ef6dfc10862e5fb0b811a2aaf02c96c78601de/mlx-0.30.1-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:fc9745bc1860ca60128e3a6d36157da06d936e2b4007a4dcba990b40202f598f", size = 593368, upload-time = "2025-12-18T01:55:48.363Z" },
{ url = "https://files.pythonhosted.org/packages/70/ce/476c3b7d3a4153bd0e1c5af1f1b6c09a804b652bbed34072404b322c22e0/mlx-0.30.1-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:a1480399c67bb327a66c5527b73915132e3fcaae3bce9634e5c81ccad9f43229", size = 567561, upload-time = "2025-12-18T00:15:56.153Z" },
{ url = "https://files.pythonhosted.org/packages/33/41/7ad1e639fd7dd1cf01a62c1c5b051024a859888c27504996e9d8380e6754/mlx-0.30.1-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:8e19850a4236a8e174f851f5789b8b62a8eb74f5a8fa49ad8ba286c5ddb5f9bf", size = 643122, upload-time = "2025-12-18T01:55:49.607Z" },
{ url = "https://files.pythonhosted.org/packages/d0/dc/72d3737c5b0662eb5e785d353dbc5e34d793d27b09b99e39993ee051bd19/mlx-0.30.1-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:1c8ed5bcd9f1910fca209e95859ac737e60b3e1954181b820fa269158f81049a", size = 687254, upload-time = "2025-12-18T01:55:51.239Z" },
{ url = "https://files.pythonhosted.org/packages/9b/cc/523448996247bb05d9d68e23bccf3dafdda660befb9330f6bd5fa13361e8/mlx-0.30.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:d34cc2c25b0ee41c1349f14650db760e282685339858e305453f62405c12bc1b", size = 596006, upload-time = "2025-12-18T01:55:52.463Z" },
{ url = "https://files.pythonhosted.org/packages/23/0e/f9f2f9659c34c87be8f4167f6a1d6ed7e826f4889d20eecd4c0d8122f0e9/mlx-0.30.1-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:4e47d301e9095b87f0bda8827bfd6ffe744223aba5cee8f28e25894d647f5823", size = 596008, upload-time = "2025-12-18T01:55:54.02Z" },
{ url = "https://files.pythonhosted.org/packages/56/a7/49e41fb141de95b6a376091a963c737839c9cda04e423c67f57460a50458/mlx-0.30.1-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:cfba13e2a52255d663a1ad62f0f83eb3991e42147edf9a8d38cdd224e48ca49b", size = 570406, upload-time = "2025-12-18T00:15:57.177Z" },
{ url = "https://files.pythonhosted.org/packages/73/99/a43cb112167cf865c069f5e108ae42f5314663930ff3dd86c2d23d984191/mlx-0.30.1-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:bebfec377208eb29cc88aa86c897c7446aa0984838669e138f273f9225d627ff", size = 646461, upload-time = "2025-12-18T01:55:55.285Z" },
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]
[[package]]
name = "mlx-metal"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/09/3f/0be35ddad7e13d8ecd33a9185895f9739bb00b96ef0cce36cf0405d4aec0/mlx_metal-0.30.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:e7e92c6bdbd7ac8083f528a4c6640552ae106a57bb3d99856ac10a32e93a4b5e", size = 36864966, upload-time = "2025-12-18T01:55:31.473Z" },
{ url = "https://files.pythonhosted.org/packages/1e/1f/c0bddd0d5bf3871411aabe32121e09e1b7cdbece8917a49d5a442310e3e5/mlx_metal-0.30.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:bb50f57418af7fc3c42a2da2c4bde0e7ab7ac0b997de1f6f642a6680ac65d626", size = 36859011, upload-time = "2025-12-18T01:55:34.541Z" },
{ url = "https://files.pythonhosted.org/packages/67/b3/73cc2f584ac612a476096d35a61eed75ee7ed8b4e320b0c36cf60a14d4eb/mlx_metal-0.30.1-py3-none-macosx_26_0_arm64.whl", hash = "sha256:e0b151a0053ac00b4226710bfb6dbf54b87283fb01e10fb3877f9ea969f680aa", size = 44981160, upload-time = "2025-12-18T00:15:47.518Z" },
]
[[package]]
name = "mlx-video"
version = "0.0.1"
source = { editable = "." }
dependencies = [
{ name = "huggingface-hub" },
{ name = "mlx" },
{ name = "numpy" },
{ name = "safetensors" },
{ name = "tqdm" },
]
[package.optional-dependencies]
dev = [
{ name = "pytest" },
]
[package.metadata]
requires-dist = [
{ name = "huggingface-hub" },
{ name = "mlx", specifier = ">=0.22.0" },
{ name = "numpy" },
{ name = "pytest", marker = "extra == 'dev'" },
{ name = "safetensors" },
{ name = "tqdm" },
]
provides-extras = ["dev"]
[[package]]
name = "numpy"
version = "2.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a4/7a/6a3d14e205d292b738db449d0de649b373a59edb0d0b4493821d0a3e8718/numpy-2.4.0.tar.gz", hash = "sha256:6e504f7b16118198f138ef31ba24d985b124c2c469fe8467007cf30fd992f934", size = 20685720, upload-time = "2025-12-20T16:18:19.023Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/26/7e/7bae7cbcc2f8132271967aa03e03954fc1e48aa1f3bf32b29ca95fbef352/numpy-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:316b2f2584682318539f0bcaca5a496ce9ca78c88066579ebd11fd06f8e4741e", size = 16940166, upload-time = "2025-12-20T16:15:43.434Z" },
{ url = "https://files.pythonhosted.org/packages/0f/27/6c13f5b46776d6246ec884ac5817452672156a506d08a1f2abb39961930a/numpy-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a2718c1de8504121714234b6f8241d0019450353276c88b9453c9c3d92e101db", size = 12641781, upload-time = "2025-12-20T16:15:45.701Z" },
{ url = "https://files.pythonhosted.org/packages/14/1c/83b4998d4860d15283241d9e5215f28b40ac31f497c04b12fa7f428ff370/numpy-2.4.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:21555da4ec4a0c942520ead42c3b0dc9477441e085c42b0fbdd6a084869a6f6b", size = 5470247, upload-time = "2025-12-20T16:15:47.943Z" },
{ url = "https://files.pythonhosted.org/packages/54/08/cbce72c835d937795571b0464b52069f869c9e78b0c076d416c5269d2718/numpy-2.4.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:413aa561266a4be2d06cd2b9665e89d9f54c543f418773076a76adcf2af08bc7", size = 6799807, upload-time = "2025-12-20T16:15:49.795Z" },
{ url = "https://files.pythonhosted.org/packages/ff/be/2e647961cd8c980591d75cdcd9e8f647d69fbe05e2a25613dc0a2ea5fb1a/numpy-2.4.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0feafc9e03128074689183031181fac0897ff169692d8492066e949041096548", size = 14701992, upload-time = "2025-12-20T16:15:51.615Z" },
{ url = "https://files.pythonhosted.org/packages/a2/fb/e1652fb8b6fd91ce6ed429143fe2e01ce714711e03e5b762615e7b36172c/numpy-2.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8fdfed3deaf1928fb7667d96e0567cdf58c2b370ea2ee7e586aa383ec2cb346", size = 16646871, upload-time = "2025-12-20T16:15:54.129Z" },
{ url = "https://files.pythonhosted.org/packages/62/23/d841207e63c4322842f7cd042ae981cffe715c73376dcad8235fb31debf1/numpy-2.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e06a922a469cae9a57100864caf4f8a97a1026513793969f8ba5b63137a35d25", size = 16487190, upload-time = "2025-12-20T16:15:56.147Z" },
{ url = "https://files.pythonhosted.org/packages/bc/a0/6a842c8421ebfdec0a230e65f61e0dabda6edbef443d999d79b87c273965/numpy-2.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:927ccf5cd17c48f801f4ed43a7e5673a2724bd2171460be3e3894e6e332ef83a", size = 18580762, upload-time = "2025-12-20T16:15:58.524Z" },
{ url = "https://files.pythonhosted.org/packages/0a/d1/c79e0046641186f2134dde05e6181825b911f8bdcef31b19ddd16e232847/numpy-2.4.0-cp311-cp311-win32.whl", hash = "sha256:882567b7ae57c1b1a0250208cc21a7976d8cbcc49d5a322e607e6f09c9e0bd53", size = 6233359, upload-time = "2025-12-20T16:16:00.938Z" },
{ url = "https://files.pythonhosted.org/packages/fc/f0/74965001d231f28184d6305b8cdc1b6fcd4bf23033f6cb039cfe76c9fca7/numpy-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:8b986403023c8f3bf8f487c2e6186afda156174d31c175f747d8934dfddf3479", size = 12601132, upload-time = "2025-12-20T16:16:02.484Z" },
{ url = "https://files.pythonhosted.org/packages/65/32/55408d0f46dfebce38017f5bd931affa7256ad6beac1a92a012e1fbc67a7/numpy-2.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:3f3096405acc48887458bbf9f6814d43785ac7ba2a57ea6442b581dedbc60ce6", size = 10573977, upload-time = "2025-12-20T16:16:04.77Z" },
{ url = "https://files.pythonhosted.org/packages/8b/ff/f6400ffec95de41c74b8e73df32e3fff1830633193a7b1e409be7fb1bb8c/numpy-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2a8b6bb8369abefb8bd1801b054ad50e02b3275c8614dc6e5b0373c305291037", size = 16653117, upload-time = "2025-12-20T16:16:06.709Z" },
{ url = "https://files.pythonhosted.org/packages/fd/28/6c23e97450035072e8d830a3c411bf1abd1f42c611ff9d29e3d8f55c6252/numpy-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e284ca13d5a8367e43734148622caf0b261b275673823593e3e3634a6490f83", size = 12369711, upload-time = "2025-12-20T16:16:08.758Z" },
{ url = "https://files.pythonhosted.org/packages/bc/af/acbef97b630ab1bb45e6a7d01d1452e4251aa88ce680ac36e56c272120ec/numpy-2.4.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:49ff32b09f5aa0cd30a20c2b39db3e669c845589f2b7fc910365210887e39344", size = 5198355, upload-time = "2025-12-20T16:16:10.902Z" },
{ url = "https://files.pythonhosted.org/packages/c1/c8/4e0d436b66b826f2e53330adaa6311f5cac9871a5b5c31ad773b27f25a74/numpy-2.4.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:36cbfb13c152b1c7c184ddac43765db8ad672567e7bafff2cc755a09917ed2e6", size = 6545298, upload-time = "2025-12-20T16:16:12.607Z" },
{ url = "https://files.pythonhosted.org/packages/ef/27/e1f5d144ab54eac34875e79037011d511ac57b21b220063310cb96c80fbc/numpy-2.4.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:35ddc8f4914466e6fc954c76527aa91aa763682a4f6d73249ef20b418fe6effb", size = 14398387, upload-time = "2025-12-20T16:16:14.257Z" },
{ url = "https://files.pythonhosted.org/packages/67/64/4cb909dd5ab09a9a5d086eff9586e69e827b88a5585517386879474f4cf7/numpy-2.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc578891de1db95b2a35001b695451767b580bb45753717498213c5ff3c41d63", size = 16363091, upload-time = "2025-12-20T16:16:17.32Z" },
{ url = "https://files.pythonhosted.org/packages/9d/9c/8efe24577523ec6809261859737cf117b0eb6fdb655abdfdc81b2e468ce4/numpy-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98e81648e0b36e325ab67e46b5400a7a6d4a22b8a7c8e8bbfe20e7db7906bf95", size = 16176394, upload-time = "2025-12-20T16:16:19.524Z" },
{ url = "https://files.pythonhosted.org/packages/61/f0/1687441ece7b47a62e45a1f82015352c240765c707928edd8aef875d5951/numpy-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d57b5046c120561ba8fa8e4030fbb8b822f3063910fa901ffadf16e2b7128ad6", size = 18287378, upload-time = "2025-12-20T16:16:22.866Z" },
{ url = "https://files.pythonhosted.org/packages/d3/6f/f868765d44e6fc466467ed810ba9d8d6db1add7d4a748abfa2a4c99a3194/numpy-2.4.0-cp312-cp312-win32.whl", hash = "sha256:92190db305a6f48734d3982f2c60fa30d6b5ee9bff10f2887b930d7b40119f4c", size = 5955432, upload-time = "2025-12-20T16:16:25.06Z" },
{ url = "https://files.pythonhosted.org/packages/d4/b5/94c1e79fcbab38d1ca15e13777477b2914dd2d559b410f96949d6637b085/numpy-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:680060061adb2d74ce352628cb798cfdec399068aa7f07ba9fb818b2b3305f98", size = 12306201, upload-time = "2025-12-20T16:16:26.979Z" },
{ url = "https://files.pythonhosted.org/packages/70/09/c39dadf0b13bb0768cd29d6a3aaff1fb7c6905ac40e9aaeca26b1c086e06/numpy-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:39699233bc72dd482da1415dcb06076e32f60eddc796a796c5fb6c5efce94667", size = 10308234, upload-time = "2025-12-20T16:16:29.417Z" },
{ url = "https://files.pythonhosted.org/packages/a7/0d/853fd96372eda07c824d24adf02e8bc92bb3731b43a9b2a39161c3667cc4/numpy-2.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a152d86a3ae00ba5f47b3acf3b827509fd0b6cb7d3259665e63dafbad22a75ea", size = 16649088, upload-time = "2025-12-20T16:16:31.421Z" },
{ url = "https://files.pythonhosted.org/packages/e3/37/cc636f1f2a9f585434e20a3e6e63422f70bfe4f7f6698e941db52ea1ac9a/numpy-2.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:39b19251dec4de8ff8496cd0806cbe27bf0684f765abb1f4809554de93785f2d", size = 12364065, upload-time = "2025-12-20T16:16:33.491Z" },
{ url = "https://files.pythonhosted.org/packages/ed/69/0b78f37ca3690969beee54103ce5f6021709134e8020767e93ba691a72f1/numpy-2.4.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:009bd0ea12d3c784b6639a8457537016ce5172109e585338e11334f6a7bb88ee", size = 5192640, upload-time = "2025-12-20T16:16:35.636Z" },
{ url = "https://files.pythonhosted.org/packages/1d/2a/08569f8252abf590294dbb09a430543ec8f8cc710383abfb3e75cc73aeda/numpy-2.4.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5fe44e277225fd3dff6882d86d3d447205d43532c3627313d17e754fb3905a0e", size = 6541556, upload-time = "2025-12-20T16:16:37.276Z" },
{ url = "https://files.pythonhosted.org/packages/93/e9/a949885a4e177493d61519377952186b6cbfdf1d6002764c664ba28349b5/numpy-2.4.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f935c4493eda9069851058fa0d9e39dbf6286be690066509305e52912714dbb2", size = 14396562, upload-time = "2025-12-20T16:16:38.953Z" },
{ url = "https://files.pythonhosted.org/packages/99/98/9d4ad53b0e9ef901c2ef1d550d2136f5ac42d3fd2988390a6def32e23e48/numpy-2.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8cfa5f29a695cb7438965e6c3e8d06e0416060cf0d709c1b1c1653a939bf5c2a", size = 16351719, upload-time = "2025-12-20T16:16:41.503Z" },
{ url = "https://files.pythonhosted.org/packages/28/de/5f3711a38341d6e8dd619f6353251a0cdd07f3d6d101a8fd46f4ef87f895/numpy-2.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba0cb30acd3ef11c94dc27fbfba68940652492bc107075e7ffe23057f9425681", size = 16176053, upload-time = "2025-12-20T16:16:44.552Z" },
{ url = "https://files.pythonhosted.org/packages/2a/5b/2a3753dc43916501b4183532e7ace862e13211042bceafa253afb5c71272/numpy-2.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60e8c196cd82cbbd4f130b5290007e13e6de3eca79f0d4d38014769d96a7c475", size = 18277859, upload-time = "2025-12-20T16:16:47.174Z" },
{ url = "https://files.pythonhosted.org/packages/2c/c5/a18bcdd07a941db3076ef489d036ab16d2bfc2eae0cf27e5a26e29189434/numpy-2.4.0-cp313-cp313-win32.whl", hash = "sha256:5f48cb3e88fbc294dc90e215d86fbaf1c852c63dbdb6c3a3e63f45c4b57f7344", size = 5953849, upload-time = "2025-12-20T16:16:49.554Z" },
{ url = "https://files.pythonhosted.org/packages/4f/f1/719010ff8061da6e8a26e1980cf090412d4f5f8060b31f0c45d77dd67a01/numpy-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:a899699294f28f7be8992853c0c60741f16ff199205e2e6cdca155762cbaa59d", size = 12302840, upload-time = "2025-12-20T16:16:51.227Z" },
{ url = "https://files.pythonhosted.org/packages/f5/5a/b3d259083ed8b4d335270c76966cb6cf14a5d1b69e1a608994ac57a659e6/numpy-2.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:9198f447e1dc5647d07c9a6bbe2063cc0132728cc7175b39dbc796da5b54920d", size = 10308509, upload-time = "2025-12-20T16:16:53.313Z" },
{ url = "https://files.pythonhosted.org/packages/31/01/95edcffd1bb6c0633df4e808130545c4f07383ab629ac7e316fb44fff677/numpy-2.4.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74623f2ab5cc3f7c886add4f735d1031a1d2be4a4ae63c0546cfd74e7a31ddf6", size = 12491815, upload-time = "2025-12-20T16:16:55.496Z" },
{ url = "https://files.pythonhosted.org/packages/59/ea/5644b8baa92cc1c7163b4b4458c8679852733fa74ca49c942cfa82ded4e0/numpy-2.4.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0804a8e4ab070d1d35496e65ffd3cf8114c136a2b81f61dfab0de4b218aacfd5", size = 5320321, upload-time = "2025-12-20T16:16:57.468Z" },
{ url = "https://files.pythonhosted.org/packages/26/4e/e10938106d70bc21319bd6a86ae726da37edc802ce35a3a71ecdf1fdfe7f/numpy-2.4.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:02a2038eb27f9443a8b266a66911e926566b5a6ffd1a689b588f7f35b81e7dc3", size = 6641635, upload-time = "2025-12-20T16:16:59.379Z" },
{ url = "https://files.pythonhosted.org/packages/b3/8d/a8828e3eaf5c0b4ab116924df82f24ce3416fa38d0674d8f708ddc6c8aac/numpy-2.4.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1889b3a3f47a7b5bee16bc25a2145bd7cb91897f815ce3499db64c7458b6d91d", size = 14456053, upload-time = "2025-12-20T16:17:01.768Z" },
{ url = "https://files.pythonhosted.org/packages/68/a1/17d97609d87d4520aa5ae2dcfb32305654550ac6a35effb946d303e594ce/numpy-2.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85eef4cb5625c47ee6425c58a3502555e10f45ee973da878ac8248ad58c136f3", size = 16401702, upload-time = "2025-12-20T16:17:04.235Z" },
{ url = "https://files.pythonhosted.org/packages/18/32/0f13c1b2d22bea1118356b8b963195446f3af124ed7a5adfa8fdecb1b6ca/numpy-2.4.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6dc8b7e2f4eb184b37655195f421836cfae6f58197b67e3ffc501f1333d993fa", size = 16242493, upload-time = "2025-12-20T16:17:06.856Z" },
{ url = "https://files.pythonhosted.org/packages/ae/23/48f21e3d309fbc137c068a1475358cbd3a901b3987dcfc97a029ab3068e2/numpy-2.4.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:44aba2f0cafd287871a495fb3163408b0bd25bbce135c6f621534a07f4f7875c", size = 18324222, upload-time = "2025-12-20T16:17:09.392Z" },
{ url = "https://files.pythonhosted.org/packages/ac/52/41f3d71296a3dcaa4f456aaa3c6fc8e745b43d0552b6bde56571bb4b4a0f/numpy-2.4.0-cp313-cp313t-win32.whl", hash = "sha256:20c115517513831860c573996e395707aa9fb691eb179200125c250e895fcd93", size = 6076216, upload-time = "2025-12-20T16:17:11.437Z" },
{ url = "https://files.pythonhosted.org/packages/35/ff/46fbfe60ab0710d2a2b16995f708750307d30eccbb4c38371ea9e986866e/numpy-2.4.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b48e35f4ab6f6a7597c46e301126ceba4c44cd3280e3750f85db48b082624fa4", size = 12444263, upload-time = "2025-12-20T16:17:13.182Z" },
{ url = "https://files.pythonhosted.org/packages/a3/e3/9189ab319c01d2ed556c932ccf55064c5d75bb5850d1df7a482ce0badead/numpy-2.4.0-cp313-cp313t-win_arm64.whl", hash = "sha256:4d1cfce39e511069b11e67cd0bd78ceff31443b7c9e5c04db73c7a19f572967c", size = 10378265, upload-time = "2025-12-20T16:17:15.211Z" },
{ url = "https://files.pythonhosted.org/packages/ab/ed/52eac27de39d5e5a6c9aadabe672bc06f55e24a3d9010cd1183948055d76/numpy-2.4.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c95eb6db2884917d86cde0b4d4cf31adf485c8ec36bf8696dd66fa70de96f36b", size = 16647476, upload-time = "2025-12-20T16:17:17.671Z" },
{ url = "https://files.pythonhosted.org/packages/77/c0/990ce1b7fcd4e09aeaa574e2a0a839589e4b08b2ca68070f1acb1fea6736/numpy-2.4.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:65167da969cd1ec3a1df31cb221ca3a19a8aaa25370ecb17d428415e93c1935e", size = 12374563, upload-time = "2025-12-20T16:17:20.216Z" },
{ url = "https://files.pythonhosted.org/packages/37/7c/8c5e389c6ae8f5fd2277a988600d79e9625db3fff011a2d87ac80b881a4c/numpy-2.4.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:3de19cfecd1465d0dcf8a5b5ea8b3155b42ed0b639dba4b71e323d74f2a3be5e", size = 5203107, upload-time = "2025-12-20T16:17:22.47Z" },
{ url = "https://files.pythonhosted.org/packages/e6/94/ca5b3bd6a8a70a5eec9a0b8dd7f980c1eff4b8a54970a9a7fef248ef564f/numpy-2.4.0-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:6c05483c3136ac4c91b4e81903cb53a8707d316f488124d0398499a4f8e8ef51", size = 6538067, upload-time = "2025-12-20T16:17:24.001Z" },
{ url = "https://files.pythonhosted.org/packages/79/43/993eb7bb5be6761dde2b3a3a594d689cec83398e3f58f4758010f3b85727/numpy-2.4.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36667db4d6c1cea79c8930ab72fadfb4060feb4bfe724141cd4bd064d2e5f8ce", size = 14411926, upload-time = "2025-12-20T16:17:25.822Z" },
{ url = "https://files.pythonhosted.org/packages/03/75/d4c43b61de473912496317a854dac54f1efec3eeb158438da6884b70bb90/numpy-2.4.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9a818668b674047fd88c4cddada7ab8f1c298812783e8328e956b78dc4807f9f", size = 16354295, upload-time = "2025-12-20T16:17:28.308Z" },
{ url = "https://files.pythonhosted.org/packages/b8/0a/b54615b47ee8736a6461a4bb6749128dd3435c5a759d5663f11f0e9af4ac/numpy-2.4.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:1ee32359fb7543b7b7bd0b2f46294db27e29e7bbdf70541e81b190836cd83ded", size = 16190242, upload-time = "2025-12-20T16:17:30.993Z" },
{ url = "https://files.pythonhosted.org/packages/98/ce/ea207769aacad6246525ec6c6bbd66a2bf56c72443dc10e2f90feed29290/numpy-2.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e493962256a38f58283de033d8af176c5c91c084ea30f15834f7545451c42059", size = 18280875, upload-time = "2025-12-20T16:17:33.327Z" },
{ url = "https://files.pythonhosted.org/packages/17/ef/ec409437aa962ea372ed601c519a2b141701683ff028f894b7466f0ab42b/numpy-2.4.0-cp314-cp314-win32.whl", hash = "sha256:6bbaebf0d11567fa8926215ae731e1d58e6ec28a8a25235b8a47405d301332db", size = 6002530, upload-time = "2025-12-20T16:17:35.729Z" },
{ url = "https://files.pythonhosted.org/packages/5f/4a/5cb94c787a3ed1ac65e1271b968686521169a7b3ec0b6544bb3ca32960b0/numpy-2.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:3d857f55e7fdf7c38ab96c4558c95b97d1c685be6b05c249f5fdafcbd6f9899e", size = 12435890, upload-time = "2025-12-20T16:17:37.599Z" },
{ url = "https://files.pythonhosted.org/packages/48/a0/04b89db963af9de1104975e2544f30de89adbf75b9e75f7dd2599be12c79/numpy-2.4.0-cp314-cp314-win_arm64.whl", hash = "sha256:bb50ce5fb202a26fd5404620e7ef820ad1ab3558b444cb0b55beb7ef66cd2d63", size = 10591892, upload-time = "2025-12-20T16:17:39.649Z" },
{ url = "https://files.pythonhosted.org/packages/53/e5/d74b5ccf6712c06c7a545025a6a71bfa03bdc7e0568b405b0d655232fd92/numpy-2.4.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:355354388cba60f2132df297e2d53053d4063f79077b67b481d21276d61fc4df", size = 12494312, upload-time = "2025-12-20T16:17:41.714Z" },
{ url = "https://files.pythonhosted.org/packages/c2/08/3ca9cc2ddf54dfee7ae9a6479c071092a228c68aef08252aa08dac2af002/numpy-2.4.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:1d8f9fde5f6dc1b6fc34df8162f3b3079365468703fee7f31d4e0cc8c63baed9", size = 5322862, upload-time = "2025-12-20T16:17:44.145Z" },
{ url = "https://files.pythonhosted.org/packages/87/74/0bb63a68394c0c1e52670cfff2e309afa41edbe11b3327d9af29e4383f34/numpy-2.4.0-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:e0434aa22c821f44eeb4c650b81c7fbdd8c0122c6c4b5a576a76d5a35625ecd9", size = 6644986, upload-time = "2025-12-20T16:17:46.203Z" },
{ url = "https://files.pythonhosted.org/packages/06/8f/9264d9bdbcf8236af2823623fe2f3981d740fc3461e2787e231d97c38c28/numpy-2.4.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:40483b2f2d3ba7aad426443767ff5632ec3156ef09742b96913787d13c336471", size = 14457958, upload-time = "2025-12-20T16:17:48.017Z" },
{ url = "https://files.pythonhosted.org/packages/8c/d9/f9a69ae564bbc7236a35aa883319364ef5fd41f72aa320cc1cbe66148fe2/numpy-2.4.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9e6a7664ddd9746e20b7325351fe1a8408d0a2bf9c63b5e898290ddc8f09544", size = 16398394, upload-time = "2025-12-20T16:17:50.409Z" },
{ url = "https://files.pythonhosted.org/packages/34/c7/39241501408dde7f885d241a98caba5421061a2c6d2b2197ac5e3aa842d8/numpy-2.4.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ecb0019d44f4cdb50b676c5d0cb4b1eae8e15d1ed3d3e6639f986fc92b2ec52c", size = 16241044, upload-time = "2025-12-20T16:17:52.661Z" },
{ url = "https://files.pythonhosted.org/packages/7c/95/cae7effd90e065a95e59fe710eeee05d7328ed169776dfdd9f789e032125/numpy-2.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d0ffd9e2e4441c96a9c91ec1783285d80bf835b677853fc2770a89d50c1e48ac", size = 18321772, upload-time = "2025-12-20T16:17:54.947Z" },
{ url = "https://files.pythonhosted.org/packages/96/df/3c6c279accd2bfb968a76298e5b276310bd55d243df4fa8ac5816d79347d/numpy-2.4.0-cp314-cp314t-win32.whl", hash = "sha256:77f0d13fa87036d7553bf81f0e1fe3ce68d14c9976c9851744e4d3e91127e95f", size = 6148320, upload-time = "2025-12-20T16:17:57.249Z" },
{ url = "https://files.pythonhosted.org/packages/92/8d/f23033cce252e7a75cae853d17f582e86534c46404dea1c8ee094a9d6d84/numpy-2.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b1f5b45829ac1848893f0ddf5cb326110604d6df96cdc255b0bf9edd154104d4", size = 12623460, upload-time = "2025-12-20T16:17:58.963Z" },
{ url = "https://files.pythonhosted.org/packages/a4/4f/1f8475907d1a7c4ef9020edf7f39ea2422ec896849245f00688e4b268a71/numpy-2.4.0-cp314-cp314t-win_arm64.whl", hash = "sha256:23a3e9d1a6f360267e8fbb38ba5db355a6a7e9be71d7fce7ab3125e88bb646c8", size = 10661799, upload-time = "2025-12-20T16:18:01.078Z" },
{ url = "https://files.pythonhosted.org/packages/4b/ef/088e7c7342f300aaf3ee5f2c821c4b9996a1bef2aaf6a49cc8ab4883758e/numpy-2.4.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b54c83f1c0c0f1d748dca0af516062b8829d53d1f0c402be24b4257a9c48ada6", size = 16819003, upload-time = "2025-12-20T16:18:03.41Z" },
{ url = "https://files.pythonhosted.org/packages/ff/ce/a53017b5443b4b84517182d463fc7bcc2adb4faa8b20813f8e5f5aeb5faa/numpy-2.4.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:aabb081ca0ec5d39591fc33018cd4b3f96e1a2dd6756282029986d00a785fba4", size = 12567105, upload-time = "2025-12-20T16:18:05.594Z" },
{ url = "https://files.pythonhosted.org/packages/77/58/5ff91b161f2ec650c88a626c3905d938c89aaadabd0431e6d9c1330c83e2/numpy-2.4.0-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:8eafe7c36c8430b7794edeab3087dec7bf31d634d92f2af9949434b9d1964cba", size = 5395590, upload-time = "2025-12-20T16:18:08.031Z" },
{ url = "https://files.pythonhosted.org/packages/1d/4e/f1a084106df8c2df8132fc437e56987308e0524836aa7733721c8429d4fe/numpy-2.4.0-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:2f585f52b2baf07ff3356158d9268ea095e221371f1074fadea2f42544d58b4d", size = 6709947, upload-time = "2025-12-20T16:18:09.836Z" },
{ url = "https://files.pythonhosted.org/packages/63/09/3d8aeb809c0332c3f642da812ac2e3d74fc9252b3021f8c30c82e99e3f3d/numpy-2.4.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:32ed06d0fe9cae27d8fb5f400c63ccee72370599c75e683a6358dd3a4fb50aaf", size = 14535119, upload-time = "2025-12-20T16:18:12.105Z" },
{ url = "https://files.pythonhosted.org/packages/fd/7f/68f0fc43a2cbdc6bb239160c754d87c922f60fbaa0fa3cd3d312b8a7f5ee/numpy-2.4.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:57c540ed8fb1f05cb997c6761cd56db72395b0d6985e90571ff660452ade4f98", size = 16475815, upload-time = "2025-12-20T16:18:14.433Z" },
{ url = "https://files.pythonhosted.org/packages/11/73/edeacba3167b1ca66d51b1a5a14697c2c40098b5ffa01811c67b1785a5ab/numpy-2.4.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a39fb973a726e63223287adc6dafe444ce75af952d711e400f3bf2b36ef55a7b", size = 12489376, upload-time = "2025-12-20T16:18:16.524Z" },
]
[[package]]
name = "packaging"
version = "25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
[[package]]
name = "pytest"
version = "9.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
]
[[package]]
name = "pyyaml"
version = "6.0.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6d/16/a95b6757765b7b031c9374925bb718d55e0a9ba8a1b6a12d25962ea44347/pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e", size = 185826, upload-time = "2025-09-25T21:31:58.655Z" },
{ url = "https://files.pythonhosted.org/packages/16/19/13de8e4377ed53079ee996e1ab0a9c33ec2faf808a4647b7b4c0d46dd239/pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824", size = 175577, upload-time = "2025-09-25T21:32:00.088Z" },
{ url = "https://files.pythonhosted.org/packages/0c/62/d2eb46264d4b157dae1275b573017abec435397aa59cbcdab6fc978a8af4/pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c", size = 775556, upload-time = "2025-09-25T21:32:01.31Z" },
{ url = "https://files.pythonhosted.org/packages/10/cb/16c3f2cf3266edd25aaa00d6c4350381c8b012ed6f5276675b9eba8d9ff4/pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00", size = 882114, upload-time = "2025-09-25T21:32:03.376Z" },
{ url = "https://files.pythonhosted.org/packages/71/60/917329f640924b18ff085ab889a11c763e0b573da888e8404ff486657602/pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d", size = 806638, upload-time = "2025-09-25T21:32:04.553Z" },
{ url = "https://files.pythonhosted.org/packages/dd/6f/529b0f316a9fd167281a6c3826b5583e6192dba792dd55e3203d3f8e655a/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a", size = 767463, upload-time = "2025-09-25T21:32:06.152Z" },
{ url = "https://files.pythonhosted.org/packages/f2/6a/b627b4e0c1dd03718543519ffb2f1deea4a1e6d42fbab8021936a4d22589/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4", size = 794986, upload-time = "2025-09-25T21:32:07.367Z" },
{ url = "https://files.pythonhosted.org/packages/45/91/47a6e1c42d9ee337c4839208f30d9f09caa9f720ec7582917b264defc875/pyyaml-6.0.3-cp311-cp311-win32.whl", hash = "sha256:8098f252adfa6c80ab48096053f512f2321f0b998f98150cea9bd23d83e1467b", size = 142543, upload-time = "2025-09-25T21:32:08.95Z" },
{ url = "https://files.pythonhosted.org/packages/da/e3/ea007450a105ae919a72393cb06f122f288ef60bba2dc64b26e2646fa315/pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf", size = 158763, upload-time = "2025-09-25T21:32:09.96Z" },
{ url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" },
{ url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" },
{ url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" },
{ url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" },
{ url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" },
{ url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" },
{ url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" },
{ url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" },
{ url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" },
{ url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" },
{ url = "https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" },
{ url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" },
{ url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" },
{ url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" },
{ url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" },
{ url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" },
{ url = "https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" },
{ url = "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" },
{ url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" },
{ url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" },
{ url = "https://files.pythonhosted.org/packages/9d/8c/f4bd7f6465179953d3ac9bc44ac1a8a3e6122cf8ada906b4f96c60172d43/pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac", size = 181814, upload-time = "2025-09-25T21:32:35.712Z" },
{ url = "https://files.pythonhosted.org/packages/bd/9c/4d95bb87eb2063d20db7b60faa3840c1b18025517ae857371c4dd55a6b3a/pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310", size = 173809, upload-time = "2025-09-25T21:32:36.789Z" },
{ url = "https://files.pythonhosted.org/packages/92/b5/47e807c2623074914e29dabd16cbbdd4bf5e9b2db9f8090fa64411fc5382/pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7", size = 766454, upload-time = "2025-09-25T21:32:37.966Z" },
{ url = "https://files.pythonhosted.org/packages/02/9e/e5e9b168be58564121efb3de6859c452fccde0ab093d8438905899a3a483/pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788", size = 836355, upload-time = "2025-09-25T21:32:39.178Z" },
{ url = "https://files.pythonhosted.org/packages/88/f9/16491d7ed2a919954993e48aa941b200f38040928474c9e85ea9e64222c3/pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5", size = 794175, upload-time = "2025-09-25T21:32:40.865Z" },
{ url = "https://files.pythonhosted.org/packages/dd/3f/5989debef34dc6397317802b527dbbafb2b4760878a53d4166579111411e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764", size = 755228, upload-time = "2025-09-25T21:32:42.084Z" },
{ url = "https://files.pythonhosted.org/packages/d7/ce/af88a49043cd2e265be63d083fc75b27b6ed062f5f9fd6cdc223ad62f03e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35", size = 789194, upload-time = "2025-09-25T21:32:43.362Z" },
{ url = "https://files.pythonhosted.org/packages/23/20/bb6982b26a40bb43951265ba29d4c246ef0ff59c9fdcdf0ed04e0687de4d/pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac", size = 156429, upload-time = "2025-09-25T21:32:57.844Z" },
{ url = "https://files.pythonhosted.org/packages/f4/f4/a4541072bb9422c8a883ab55255f918fa378ecf083f5b85e87fc2b4eda1b/pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3", size = 143912, upload-time = "2025-09-25T21:32:59.247Z" },
{ url = "https://files.pythonhosted.org/packages/7c/f9/07dd09ae774e4616edf6cda684ee78f97777bdd15847253637a6f052a62f/pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3", size = 189108, upload-time = "2025-09-25T21:32:44.377Z" },
{ url = "https://files.pythonhosted.org/packages/4e/78/8d08c9fb7ce09ad8c38ad533c1191cf27f7ae1effe5bb9400a46d9437fcf/pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba", size = 183641, upload-time = "2025-09-25T21:32:45.407Z" },
{ url = "https://files.pythonhosted.org/packages/7b/5b/3babb19104a46945cf816d047db2788bcaf8c94527a805610b0289a01c6b/pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c", size = 831901, upload-time = "2025-09-25T21:32:48.83Z" },
{ url = "https://files.pythonhosted.org/packages/8b/cc/dff0684d8dc44da4d22a13f35f073d558c268780ce3c6ba1b87055bb0b87/pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702", size = 861132, upload-time = "2025-09-25T21:32:50.149Z" },
{ url = "https://files.pythonhosted.org/packages/b1/5e/f77dc6b9036943e285ba76b49e118d9ea929885becb0a29ba8a7c75e29fe/pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c", size = 839261, upload-time = "2025-09-25T21:32:51.808Z" },
{ url = "https://files.pythonhosted.org/packages/ce/88/a9db1376aa2a228197c58b37302f284b5617f56a5d959fd1763fb1675ce6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065", size = 805272, upload-time = "2025-09-25T21:32:52.941Z" },
{ url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" },
{ url = "https://files.pythonhosted.org/packages/f0/7a/1c7270340330e575b92f397352af856a8c06f230aa3e76f86b39d01b416a/pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9", size = 174062, upload-time = "2025-09-25T21:32:55.767Z" },
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
]
[[package]]
name = "safetensors"
version = "0.7.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" },
{ url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" },
{ url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" },
{ url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" },
{ url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" },
{ url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" },
{ url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" },
{ url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" },
{ url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" },
{ url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" },
{ url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" },
{ url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" },
{ url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" },
{ url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" },
]
[[package]]
name = "shellingham"
version = "1.5.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
]
[[package]]
name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" },
]
[[package]]
name = "typer-slim"
version = "0.21.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/17/d4/064570dec6358aa9049d4708e4a10407d74c99258f8b2136bb8702303f1a/typer_slim-0.21.1.tar.gz", hash = "sha256:73495dd08c2d0940d611c5a8c04e91c2a0a98600cbd4ee19192255a233b6dbfd", size = 110478, upload-time = "2026-01-06T11:21:11.176Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/0a/4aca634faf693e33004796b6cee0ae2e1dba375a800c16ab8d3eff4bb800/typer_slim-0.21.1-py3-none-any.whl", hash = "sha256:6e6c31047f171ac93cc5a973c9e617dbc5ab2bddc4d0a3135dc161b4e2020e0d", size = 47444, upload-time = "2026-01-06T11:21:12.441Z" },
]
[[package]]
name = "typing-extensions"
version = "4.15.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
]