initial commit (LTX-2)
This commit is contained in:
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
.env
|
||||||
|
claude.md
|
||||||
|
.DS_Store
|
||||||
|
**.pyc
|
||||||
|
__pycache__/*
|
||||||
394
main.py
Normal file
394
main.py
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||||
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
|
from mlx_video.models.ltx.transformer import Modality
|
||||||
|
from mlx_video.convert import sanitize_transformer_weights
|
||||||
|
from mlx_video.generate import create_position_grid
|
||||||
|
from mlx_video.utils import to_denoised
|
||||||
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
|
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
LTX2_REPO = "Lightricks/LTX-2"
|
||||||
|
|
||||||
|
def get_ltx2_cache_dir():
|
||||||
|
# Try to get local cache (local_only), will not download files
|
||||||
|
try:
|
||||||
|
ref_path = snapshot_download(
|
||||||
|
repo_id=LTX2_REPO,
|
||||||
|
local_files_only=True,
|
||||||
|
allow_patterns=["*"],
|
||||||
|
ignore_patterns=[],
|
||||||
|
# leave as default revision and cache_dir, only local
|
||||||
|
)
|
||||||
|
return ref_path
|
||||||
|
except Exception:
|
||||||
|
# If not present locally, download from hub
|
||||||
|
return snapshot_download(
|
||||||
|
repo_id=LTX2_REPO,
|
||||||
|
local_files_only=False,
|
||||||
|
resume_download=True,
|
||||||
|
allow_patterns=["*.safetensors", "*.json"],
|
||||||
|
ignore_patterns=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
LTX2_PATH = Path(get_ltx2_cache_dir())
|
||||||
|
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
|
||||||
|
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
|
||||||
|
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
|
||||||
|
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
|
||||||
|
|
||||||
|
# Distilled sigma schedules (from PyTorch)
|
||||||
|
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
|
||||||
|
|
||||||
|
|
||||||
|
def denoise_loop(
|
||||||
|
latents: mx.array,
|
||||||
|
positions: mx.array,
|
||||||
|
text_embeddings: mx.array,
|
||||||
|
transformer: LTXModel,
|
||||||
|
sigma_schedule: list,
|
||||||
|
stage_name: str = "Stage",
|
||||||
|
negative_embeddings: mx.array = None,
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Run denoising loop for given sigma schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents: Noisy latent tensor
|
||||||
|
positions: Position embeddings
|
||||||
|
text_embeddings: Positive prompt embeddings
|
||||||
|
transformer: The transformer model
|
||||||
|
sigma_schedule: List of sigma values for each step
|
||||||
|
stage_name: Name for logging
|
||||||
|
negative_embeddings: Negative prompt embeddings for CFG (optional)
|
||||||
|
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
|
||||||
|
"""
|
||||||
|
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
|
||||||
|
|
||||||
|
for i in range(len(sigma_schedule) - 1):
|
||||||
|
sigma = sigma_schedule[i]
|
||||||
|
sigma_next = sigma_schedule[i + 1]
|
||||||
|
|
||||||
|
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
|
||||||
|
|
||||||
|
b, c, f, h, w = latents.shape
|
||||||
|
latents_flat = mx.reshape(latents, (b, c, -1))
|
||||||
|
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
|
||||||
|
|
||||||
|
timesteps = mx.full((1,), sigma)
|
||||||
|
|
||||||
|
# Positive (conditioned) prediction
|
||||||
|
video_modality = Modality(
|
||||||
|
latent=latents_flat,
|
||||||
|
timesteps=timesteps,
|
||||||
|
positions=positions,
|
||||||
|
context=text_embeddings,
|
||||||
|
context_mask=None,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
vx_cond, _ = transformer(video=video_modality, audio=None)
|
||||||
|
mx.eval(vx_cond)
|
||||||
|
|
||||||
|
if use_cfg:
|
||||||
|
# Negative (unconditioned) prediction
|
||||||
|
video_modality_neg = Modality(
|
||||||
|
latent=latents_flat,
|
||||||
|
timesteps=timesteps,
|
||||||
|
positions=positions,
|
||||||
|
context=negative_embeddings,
|
||||||
|
context_mask=None,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
|
||||||
|
mx.eval(vx_uncond)
|
||||||
|
|
||||||
|
# CFG: output = uncond + cfg_scale * (cond - uncond)
|
||||||
|
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
|
||||||
|
else:
|
||||||
|
vx = vx_cond
|
||||||
|
|
||||||
|
vx_reshaped = mx.transpose(vx, (0, 2, 1))
|
||||||
|
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
|
||||||
|
|
||||||
|
# Debug: Print velocity stats
|
||||||
|
vx_np = np.array(vx_reshaped)
|
||||||
|
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
|
||||||
|
|
||||||
|
# Get denoised prediction: x_0 = x_t - sigma * velocity
|
||||||
|
denoised = to_denoised(latents, vx_reshaped, sigma)
|
||||||
|
mx.eval(denoised)
|
||||||
|
|
||||||
|
# Debug: Print denoised stats
|
||||||
|
denoised_np = np.array(denoised)
|
||||||
|
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
|
||||||
|
|
||||||
|
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
|
||||||
|
if sigma_next > 0:
|
||||||
|
velocity = (latents - denoised) / sigma
|
||||||
|
latents = denoised + sigma_next * velocity
|
||||||
|
else:
|
||||||
|
latents = denoised
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
|
# Debug: Print latents after step
|
||||||
|
latents_np = np.array(latents)
|
||||||
|
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("="*60)
|
||||||
|
print("MLX LTX-2 Video Generation (Two-Stage)")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Config - same as PyTorch reference
|
||||||
|
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
|
||||||
|
negative_prompt = "" # PyTorch script doesn't use negative prompt
|
||||||
|
cfg_scale = 1.0 # No CFG in the distilled pipeline
|
||||||
|
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage
|
||||||
|
seed = 123
|
||||||
|
|
||||||
|
# Stage 1: Half resolution
|
||||||
|
stage1_height = height // 2
|
||||||
|
stage1_width = width // 2
|
||||||
|
stage1_latent_height = stage1_height // 32
|
||||||
|
stage1_latent_width = stage1_width // 32
|
||||||
|
latent_frames = 1 + (num_frames - 1) // 8
|
||||||
|
|
||||||
|
# Stage 2: Full resolution
|
||||||
|
latent_height = height // 32
|
||||||
|
latent_width = width // 32
|
||||||
|
|
||||||
|
print(f"\nConfig:")
|
||||||
|
print(f" Prompt: {prompt}")
|
||||||
|
print(f" Negative prompt: '{negative_prompt}'")
|
||||||
|
print(f" CFG scale: {cfg_scale}")
|
||||||
|
print(f" Final resolution: {width}x{height}, {num_frames} frames")
|
||||||
|
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
|
||||||
|
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
|
||||||
|
print(f" Seed: {seed}")
|
||||||
|
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
# Load text encoder
|
||||||
|
print("\nLoading text encoder...")
|
||||||
|
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||||
|
|
||||||
|
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
|
||||||
|
text_encoder.load(str(LTX2_PATH))
|
||||||
|
mx.eval(text_encoder.parameters())
|
||||||
|
|
||||||
|
# Encode positive prompt
|
||||||
|
print("Encoding text...")
|
||||||
|
text_embeddings, attention_mask = text_encoder(prompt)
|
||||||
|
mx.eval(text_embeddings)
|
||||||
|
print(f" Positive embeddings: {text_embeddings.shape}")
|
||||||
|
|
||||||
|
# Encode negative prompt for CFG
|
||||||
|
negative_embeddings, _ = text_encoder(negative_prompt)
|
||||||
|
mx.eval(negative_embeddings)
|
||||||
|
print(f" Negative embeddings: {negative_embeddings.shape}")
|
||||||
|
|
||||||
|
# Free text encoder memory
|
||||||
|
del text_encoder
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Load transformer
|
||||||
|
print("\nLoading transformer...")
|
||||||
|
raw_weights = mx.load(MODEL_PATH)
|
||||||
|
sanitized = sanitize_transformer_weights(raw_weights)
|
||||||
|
|
||||||
|
config = LTXModelConfig(
|
||||||
|
model_type=LTXModelType.VideoOnly,
|
||||||
|
num_attention_heads=32,
|
||||||
|
attention_head_dim=128,
|
||||||
|
in_channels=128,
|
||||||
|
out_channels=128,
|
||||||
|
num_layers=48,
|
||||||
|
cross_attention_dim=4096,
|
||||||
|
caption_channels=3840,
|
||||||
|
rope_type=LTXRopeType.SPLIT,
|
||||||
|
double_precision_rope=True,
|
||||||
|
positional_embedding_theta=10000.0,
|
||||||
|
positional_embedding_max_pos=[20, 2048, 2048],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
timestep_scale_multiplier=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
transformer = LTXModel(config)
|
||||||
|
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||||
|
mx.eval(transformer.parameters())
|
||||||
|
print(" Transformer loaded!")
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Stage 1: Generate at half resolution
|
||||||
|
# ========================================
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Stage 1: Generating at half resolution")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
mx.random.seed(seed)
|
||||||
|
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
|
||||||
|
mx.eval(latents)
|
||||||
|
print(f" Initial latents: {latents.shape}")
|
||||||
|
|
||||||
|
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
|
||||||
|
mx.eval(positions)
|
||||||
|
|
||||||
|
latents = denoise_loop(
|
||||||
|
latents=latents,
|
||||||
|
positions=positions,
|
||||||
|
text_embeddings=text_embeddings,
|
||||||
|
transformer=transformer,
|
||||||
|
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
|
||||||
|
stage_name="Stage 1",
|
||||||
|
negative_embeddings=negative_embeddings,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nStage 1 latents: {latents.shape}")
|
||||||
|
latents_np = np.array(latents)
|
||||||
|
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Upsample latents 2x
|
||||||
|
# ========================================
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Upsampling latents 2x")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Load upsampler
|
||||||
|
print(" Loading spatial upsampler...")
|
||||||
|
upsampler = load_upsampler(UPSAMPLER_PATH)
|
||||||
|
mx.eval(upsampler.parameters())
|
||||||
|
|
||||||
|
# Load latent statistics for normalization
|
||||||
|
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
|
||||||
|
# EXPERIMENT: Disable VAE decode noise for sharper output
|
||||||
|
# vae_decoder.decode_noise_scale = 0.0
|
||||||
|
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
|
||||||
|
latent_mean = vae_decoder.latents_mean
|
||||||
|
latent_std = vae_decoder.latents_std
|
||||||
|
|
||||||
|
# Upsample
|
||||||
|
print(" Upsampling...")
|
||||||
|
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
|
||||||
|
mx.eval(latents)
|
||||||
|
print(f" Upsampled latents: {latents.shape}")
|
||||||
|
|
||||||
|
# Free upsampler memory
|
||||||
|
del upsampler
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Stage 2: Refine at full resolution
|
||||||
|
# ========================================
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Stage 2: Refining at full resolution")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Debug: Print upsampled latent stats before adding noise
|
||||||
|
latents_np = np.array(latents)
|
||||||
|
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||||
|
|
||||||
|
# Create new position grid for full resolution
|
||||||
|
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
|
||||||
|
mx.eval(positions)
|
||||||
|
|
||||||
|
# Add noise at initial sigma for stage 2
|
||||||
|
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
|
||||||
|
# NOT addition: noisy = clean + scale * noise
|
||||||
|
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
|
||||||
|
noise = mx.random.normal(latents.shape)
|
||||||
|
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
|
# Debug: Print latents after adding noise
|
||||||
|
latents_np = np.array(latents)
|
||||||
|
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||||
|
|
||||||
|
latents = denoise_loop(
|
||||||
|
latents=latents,
|
||||||
|
positions=positions,
|
||||||
|
text_embeddings=text_embeddings,
|
||||||
|
transformer=transformer,
|
||||||
|
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
|
||||||
|
stage_name="Stage 2",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nFinal latents: {latents.shape}")
|
||||||
|
latents_np = np.array(latents)
|
||||||
|
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||||
|
|
||||||
|
# Save latents for PyTorch comparison
|
||||||
|
np.save("mlx_final_latents.npy", latents_np)
|
||||||
|
print(" Saved latents to mlx_final_latents.npy")
|
||||||
|
|
||||||
|
# Free transformer memory
|
||||||
|
del transformer
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# ========================================
|
||||||
|
# Decode to video
|
||||||
|
# ========================================
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Decoding with VAE")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# Decode latents to video
|
||||||
|
video = vae_decoder(latents, debug=True)
|
||||||
|
mx.eval(video)
|
||||||
|
print(f" Video shape: {video.shape}")
|
||||||
|
|
||||||
|
# Convert to frames
|
||||||
|
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||||
|
|
||||||
|
# Debug: check raw RGB values before conversion
|
||||||
|
video_raw = np.array(video)
|
||||||
|
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
|
||||||
|
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
|
||||||
|
|
||||||
|
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||||
|
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1]
|
||||||
|
video = mx.clip(video, 0.0, 1.0)
|
||||||
|
video = (video * 255).astype(mx.uint8)
|
||||||
|
video_np = np.array(video)
|
||||||
|
|
||||||
|
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")
|
||||||
|
|
||||||
|
# Save first frame
|
||||||
|
output_path = Path("mlx_output_frame0_2.png")
|
||||||
|
Image.fromarray(video_np[0]).save(output_path)
|
||||||
|
print(f"\nSaved first frame to {output_path}")
|
||||||
|
|
||||||
|
# Save video
|
||||||
|
try:
|
||||||
|
import imageio
|
||||||
|
video_path = "mlx_output_video_2.mp4"
|
||||||
|
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
|
||||||
|
print(f"Saved video to {video_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not save video: {e}")
|
||||||
|
|
||||||
|
print("\nDone!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
main()
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Time taken: {end_time - start_time} seconds")
|
||||||
13
mlx_video/__init__.py
Normal file
13
mlx_video/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
|
||||||
|
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||||
|
from mlx_video.generate import LTXVideoPipeline, GenerationConfig
|
||||||
|
from mlx_video.convert import load_transformer_weights, load_vae_weights
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LTXModel",
|
||||||
|
"LTXModelConfig",
|
||||||
|
"LTXVideoPipeline",
|
||||||
|
"GenerationConfig",
|
||||||
|
"load_transformer_weights",
|
||||||
|
"load_vae_weights",
|
||||||
|
]
|
||||||
457
mlx_video/convert.py
Normal file
457
mlx_video/convert.py
Normal file
@@ -0,0 +1,457 @@
|
|||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||||
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(
|
||||||
|
path_or_hf_repo: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
) -> Path:
|
||||||
|
"""Get local path to model, downloading if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_or_hf_repo: Local path or HuggingFace repo ID
|
||||||
|
revision: Git revision for HF repo
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to model directory
|
||||||
|
"""
|
||||||
|
model_path = Path(path_or_hf_repo)
|
||||||
|
|
||||||
|
if model_path.exists():
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
# Download from HuggingFace
|
||||||
|
model_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=path_or_hf_repo,
|
||||||
|
revision=revision,
|
||||||
|
allow_patterns=[
|
||||||
|
"*.safetensors",
|
||||||
|
"*.json",
|
||||||
|
"config.json",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_safetensors(path: Path) -> Dict[str, mx.array]:
|
||||||
|
"""Load weights from safetensors file(s) using MLX.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to model directory or single safetensors file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of weights
|
||||||
|
"""
|
||||||
|
weights = {}
|
||||||
|
|
||||||
|
if path.is_file():
|
||||||
|
# Single file - use mx.load directly (handles bfloat16)
|
||||||
|
return mx.load(str(path))
|
||||||
|
else:
|
||||||
|
# Directory - load all safetensors files
|
||||||
|
safetensor_files = list(path.glob("*.safetensors"))
|
||||||
|
for sf_path in safetensor_files:
|
||||||
|
file_weights = mx.load(str(sf_path))
|
||||||
|
weights.update(file_weights)
|
||||||
|
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||||
|
"""Load transformer weights from LTX-2 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to LTX-2 model directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of transformer weights
|
||||||
|
"""
|
||||||
|
# Try distilled model first, then dev
|
||||||
|
weight_files = [
|
||||||
|
model_path / "ltx-2-19b-distilled.safetensors",
|
||||||
|
model_path / "ltx-2-19b-dev.safetensors",
|
||||||
|
]
|
||||||
|
|
||||||
|
for weight_file in weight_files:
|
||||||
|
if weight_file.exists():
|
||||||
|
print(f"Loading transformer weights from {weight_file.name}...")
|
||||||
|
return mx.load(str(weight_file))
|
||||||
|
|
||||||
|
raise FileNotFoundError(f"No transformer weights found in {model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||||
|
"""Load VAE weights from LTX-2 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to LTX-2 model directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of VAE weights
|
||||||
|
"""
|
||||||
|
vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||||
|
if vae_path.exists():
|
||||||
|
print(f"Loading VAE weights from {vae_path}...")
|
||||||
|
return mx.load(str(vae_path))
|
||||||
|
|
||||||
|
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with MLX-compatible naming for transformer
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
||||||
|
if not key.startswith("model.diffusion_model."):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove 'model.diffusion_model.' prefix
|
||||||
|
new_key = key.replace("model.diffusion_model.", "")
|
||||||
|
|
||||||
|
# Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering)
|
||||||
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
|
|
||||||
|
# Handle feed-forward net naming
|
||||||
|
# PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar)
|
||||||
|
# MLX FeedForward: uses proj_in, proj_out
|
||||||
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
|
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||||
|
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||||
|
|
||||||
|
# Handle AdaLN naming - keep emb wrapper, just fix linear naming
|
||||||
|
# PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1
|
||||||
|
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||||
|
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||||
|
|
||||||
|
# Handle caption projection (keep linear1/linear2 naming for compatibility)
|
||||||
|
# These are already mapped correctly in the sanitization
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize VAE weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with MLX-compatible naming for VAE
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Skip position_ids (not needed)
|
||||||
|
if "position_ids" in key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle Conv3d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||||
|
# MLX: (out_channels, D, H, W, in_channels)
|
||||||
|
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||||
|
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
# Handle Conv2d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, H, W)
|
||||||
|
# MLX: (out_channels, H, W, in_channels)
|
||||||
|
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
|
Generic function that handles both transformer and VAE weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with MLX-compatible naming
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Skip position_ids (not needed)
|
||||||
|
if "position_ids" in key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle transformer weights
|
||||||
|
if key.startswith("model.diffusion_model."):
|
||||||
|
new_key = key.replace("model.diffusion_model.", "")
|
||||||
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||||
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||||
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||||
|
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||||
|
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||||
|
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||||
|
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||||
|
|
||||||
|
# Handle Conv3d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||||
|
# MLX: (out_channels, D, H, W, in_channels)
|
||||||
|
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
# Handle Conv2d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, H, W)
|
||||||
|
# MLX: (out_channels, H, W, in_channels)
|
||||||
|
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(model_path: Path) -> Dict[str, Any]:
|
||||||
|
"""Load model configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model directory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configuration dictionary
|
||||||
|
"""
|
||||||
|
config_path = model_path / "config.json"
|
||||||
|
|
||||||
|
if config_path.exists():
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
# Return default config
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
|
||||||
|
"""Create model instance from configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LTXModel instance
|
||||||
|
"""
|
||||||
|
# Map config to LTXModelConfig
|
||||||
|
model_config = LTXModelConfig(
|
||||||
|
model_type=LTXModelType.AudioVideo,
|
||||||
|
num_attention_heads=config.get("num_attention_heads", 32),
|
||||||
|
attention_head_dim=config.get("attention_head_dim", 128),
|
||||||
|
in_channels=config.get("in_channels", 128),
|
||||||
|
out_channels=config.get("out_channels", 128),
|
||||||
|
num_layers=config.get("num_layers", 48),
|
||||||
|
cross_attention_dim=config.get("cross_attention_dim", 4096),
|
||||||
|
caption_channels=config.get("caption_channels", 3840),
|
||||||
|
audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
|
||||||
|
audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
|
||||||
|
audio_in_channels=config.get("audio_in_channels", 128),
|
||||||
|
audio_out_channels=config.get("audio_out_channels", 128),
|
||||||
|
audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
|
||||||
|
positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
|
||||||
|
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
|
||||||
|
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
|
||||||
|
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
|
||||||
|
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
|
||||||
|
norm_eps=config.get("norm_eps", 1e-6),
|
||||||
|
)
|
||||||
|
|
||||||
|
return LTXModel(model_config)
|
||||||
|
|
||||||
|
|
||||||
|
def convert(
|
||||||
|
hf_path: str,
|
||||||
|
mlx_path: str = "mlx_model",
|
||||||
|
dtype: Optional[str] = None,
|
||||||
|
quantize: bool = False,
|
||||||
|
q_bits: int = 4,
|
||||||
|
q_group_size: int = 64,
|
||||||
|
) -> Path:
|
||||||
|
"""Convert HuggingFace model to MLX format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hf_path: HuggingFace model path or repo ID
|
||||||
|
mlx_path: Output path for MLX model
|
||||||
|
dtype: Target dtype (float16, float32, bfloat16)
|
||||||
|
quantize: Whether to quantize the model
|
||||||
|
q_bits: Quantization bits
|
||||||
|
q_group_size: Quantization group size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to converted model
|
||||||
|
"""
|
||||||
|
print(f"Loading model from {hf_path}...")
|
||||||
|
model_path = get_model_path(hf_path)
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
config = load_config(model_path)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
print("Loading weights...")
|
||||||
|
weights = load_safetensors(model_path)
|
||||||
|
|
||||||
|
# Sanitize weights
|
||||||
|
print("Sanitizing weights...")
|
||||||
|
weights = sanitize_weights(weights)
|
||||||
|
|
||||||
|
# Convert dtype if specified
|
||||||
|
if dtype is not None:
|
||||||
|
dtype_map = {
|
||||||
|
"float16": mx.float16,
|
||||||
|
"float32": mx.float32,
|
||||||
|
"bfloat16": mx.bfloat16,
|
||||||
|
}
|
||||||
|
target_dtype = dtype_map.get(dtype, mx.float16)
|
||||||
|
print(f"Converting to {dtype}...")
|
||||||
|
weights = {
|
||||||
|
k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v
|
||||||
|
for k, v in weights.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
output_path = Path(mlx_path)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save weights
|
||||||
|
print(f"Saving weights to {output_path}...")
|
||||||
|
save_weights(output_path, weights)
|
||||||
|
|
||||||
|
# Save config
|
||||||
|
config_out_path = output_path / "config.json"
|
||||||
|
with open(config_out_path, "w") as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Model converted successfully to {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||||
|
"""Save weights in safetensors format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Output directory
|
||||||
|
weights: Dictionary of weights
|
||||||
|
"""
|
||||||
|
from safetensors.numpy import save_file
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Convert to numpy for safetensors
|
||||||
|
np_weights = {k: np.array(v) for k, v in weights.items()}
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
save_file(np_weights, path / "model.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
path_or_hf_repo: str,
|
||||||
|
lazy: bool = False,
|
||||||
|
) -> LTXModel:
|
||||||
|
"""Load LTX model from path or HuggingFace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_or_hf_repo: Path to model or HuggingFace repo ID
|
||||||
|
lazy: Whether to use lazy loading
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded LTXModel
|
||||||
|
"""
|
||||||
|
model_path = get_model_path(path_or_hf_repo)
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
config = load_config(model_path)
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
model = create_model_from_config(config)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
weights = load_safetensors(model_path)
|
||||||
|
|
||||||
|
# Sanitize if needed
|
||||||
|
weights = sanitize_weights(weights)
|
||||||
|
|
||||||
|
# Load weights into model
|
||||||
|
model.load_weights(list(weights.items()))
|
||||||
|
|
||||||
|
if not lazy:
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-path",
|
||||||
|
type=str,
|
||||||
|
default="Lightricks/LTX-2",
|
||||||
|
help="HuggingFace model path or repo ID",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mlx-path",
|
||||||
|
type=str,
|
||||||
|
default="mlx_model",
|
||||||
|
help="Output path for MLX model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["float16", "float32", "bfloat16"],
|
||||||
|
default="float16",
|
||||||
|
help="Target dtype",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize",
|
||||||
|
action="store_true",
|
||||||
|
help="Quantize the model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--q-bits",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Quantization bits",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
convert(
|
||||||
|
hf_path=args.hf_path,
|
||||||
|
mlx_path=args.mlx_path,
|
||||||
|
dtype=args.dtype,
|
||||||
|
quantize=args.quantize,
|
||||||
|
q_bits=args.q_bits,
|
||||||
|
)
|
||||||
586
mlx_video/generate.py
Normal file
586
mlx_video/generate.py
Normal file
@@ -0,0 +1,586 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple, Iterator, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||||
|
from mlx_video.models.ltx.transformer import Modality
|
||||||
|
from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder
|
||||||
|
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenerationConfig:
|
||||||
|
"""Configuration for video generation."""
|
||||||
|
# Video dimensions
|
||||||
|
height: int = 512
|
||||||
|
width: int = 512
|
||||||
|
num_frames: int = 33 # Must be 1 + 8*k
|
||||||
|
|
||||||
|
# Diffusion parameters
|
||||||
|
num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True)
|
||||||
|
guidance_scale: float = 3.0
|
||||||
|
use_distilled: bool = True # Use hardcoded sigma values for distilled model
|
||||||
|
|
||||||
|
# Latent dimensions (computed from video dimensions)
|
||||||
|
@property
|
||||||
|
def latent_height(self) -> int:
|
||||||
|
return self.height // 32
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_width(self) -> int:
|
||||||
|
return self.width // 32
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latent_frames(self) -> int:
|
||||||
|
return 1 + (self.num_frames - 1) // 8
|
||||||
|
|
||||||
|
|
||||||
|
# Hardcoded sigma values for distilled model (from LTX-2 pipeline)
|
||||||
|
# These were tuned to match the distillation process
|
||||||
|
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
|
||||||
|
# Scheduler constants for dynamic sigma computation (non-distilled models)
|
||||||
|
BASE_SHIFT_ANCHOR = 1024
|
||||||
|
MAX_SHIFT_ANCHOR = 4096
|
||||||
|
|
||||||
|
|
||||||
|
def get_sigmas(
|
||||||
|
num_steps: int,
|
||||||
|
num_tokens: int,
|
||||||
|
max_shift: float = 2.05,
|
||||||
|
base_shift: float = 0.95,
|
||||||
|
stretch: bool = True,
|
||||||
|
terminal: float = 0.1,
|
||||||
|
use_distilled: bool = True,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Get sigma schedule for diffusion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_steps: Number of diffusion steps
|
||||||
|
num_tokens: Number of latent tokens (T * H * W)
|
||||||
|
max_shift: Maximum shift for sigma schedule
|
||||||
|
base_shift: Base shift for sigma schedule
|
||||||
|
stretch: Whether to stretch sigmas to terminal value
|
||||||
|
terminal: Terminal value for stretching
|
||||||
|
use_distilled: If True, use hardcoded distilled sigma values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array of sigma values
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
# For distilled model, use hardcoded sigma values
|
||||||
|
if use_distilled:
|
||||||
|
return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)
|
||||||
|
|
||||||
|
# For non-distilled models, compute dynamically using LTX2Scheduler logic
|
||||||
|
# Linear base schedule
|
||||||
|
sigmas = mx.linspace(1.0, 0.0, num_steps + 1)
|
||||||
|
|
||||||
|
# Compute token-dependent sigma shift
|
||||||
|
x1 = BASE_SHIFT_ANCHOR
|
||||||
|
x2 = MAX_SHIFT_ANCHOR
|
||||||
|
mm = (max_shift - base_shift) / (x2 - x1)
|
||||||
|
b = base_shift - mm * x1
|
||||||
|
sigma_shift = num_tokens * mm + b
|
||||||
|
|
||||||
|
# Apply exponential transformation
|
||||||
|
# sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1)
|
||||||
|
power = 1
|
||||||
|
exp_shift = math.exp(sigma_shift)
|
||||||
|
|
||||||
|
# Convert to numpy for computation then back to mx
|
||||||
|
sigmas_np = np.array(sigmas)
|
||||||
|
result = np.zeros_like(sigmas_np)
|
||||||
|
non_zero = sigmas_np != 0
|
||||||
|
result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power)
|
||||||
|
|
||||||
|
# Stretch sigmas so final value matches terminal
|
||||||
|
if stretch:
|
||||||
|
non_zero_mask = result != 0
|
||||||
|
non_zero_sigmas = result[non_zero_mask]
|
||||||
|
one_minus_z = 1.0 - non_zero_sigmas
|
||||||
|
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
||||||
|
stretched = 1.0 - (one_minus_z / scale_factor)
|
||||||
|
result[non_zero_mask] = stretched
|
||||||
|
|
||||||
|
return mx.array(result, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def create_position_grid(
|
||||||
|
batch_size: int,
|
||||||
|
num_frames: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
temporal_scale: int = 8,
|
||||||
|
spatial_scale: int = 32,
|
||||||
|
fps: float = 24.0,
|
||||||
|
causal_fix: bool = True,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Create position grid for RoPE in pixel space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Batch size
|
||||||
|
num_frames: Number of frames (latent)
|
||||||
|
height: Height (latent)
|
||||||
|
width: Width (latent)
|
||||||
|
temporal_scale: VAE temporal scale factor (default 8)
|
||||||
|
spatial_scale: VAE spatial scale factor (default 32)
|
||||||
|
fps: Frames per second (default 24.0)
|
||||||
|
causal_fix: Apply causal fix for first frame (default True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Position grid of shape (B, 3, num_patches, 2) in pixel space
|
||||||
|
where dim 2 is [start, end) bounds for each patch
|
||||||
|
"""
|
||||||
|
# Patch size is (1, 1, 1) for LTX-2 - no spatial patching
|
||||||
|
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
|
||||||
|
|
||||||
|
# Generate grid coordinates for each dimension (frame, height, width)
|
||||||
|
# These are the starting coordinates for each patch in latent space
|
||||||
|
t_coords = np.arange(0, num_frames, patch_size_t)
|
||||||
|
h_coords = np.arange(0, height, patch_size_h)
|
||||||
|
w_coords = np.arange(0, width, patch_size_w)
|
||||||
|
|
||||||
|
# Create meshgrid with indexing='ij' for (frame, height, width) order
|
||||||
|
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||||
|
|
||||||
|
# Stack to get shape (3, grid_t, grid_h, grid_w)
|
||||||
|
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||||
|
|
||||||
|
# Calculate end coordinates (start + patch_size)
|
||||||
|
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
|
||||||
|
patch_ends = patch_starts + patch_size_delta
|
||||||
|
|
||||||
|
# Stack start and end: shape (3, grid_t, grid_h, grid_w, 2)
|
||||||
|
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
|
||||||
|
|
||||||
|
# Flatten spatial/temporal dims: (3, num_patches, 2)
|
||||||
|
num_patches = num_frames * height * width
|
||||||
|
latent_coords = latent_coords.reshape(3, num_patches, 2)
|
||||||
|
|
||||||
|
# Broadcast to batch: (batch, 3, num_patches, 2)
|
||||||
|
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
|
||||||
|
|
||||||
|
# Convert latent coords to pixel coords by scaling with VAE factors
|
||||||
|
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
|
||||||
|
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
|
||||||
|
|
||||||
|
# Apply causal fix for first frame temporal axis
|
||||||
|
if causal_fix:
|
||||||
|
# VAE temporal stride for first frame is 1 instead of temporal_scale
|
||||||
|
# Shift and clamp to keep first-frame timestamps non-negative
|
||||||
|
pixel_coords[:, 0, :, :] = np.clip(
|
||||||
|
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
|
||||||
|
a_min=0,
|
||||||
|
a_max=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert temporal to time in seconds by dividing by fps
|
||||||
|
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
||||||
|
|
||||||
|
return mx.array(pixel_coords, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVideoPipeline:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transformer: LTXModel,
|
||||||
|
text_encoder: Optional[LTX2TextEncoder] = None,
|
||||||
|
tokenizer: Optional[any] = None,
|
||||||
|
vae_encoder: Optional[VideoEncoder] = None,
|
||||||
|
vae_decoder: Optional[VideoDecoder] = None,
|
||||||
|
):
|
||||||
|
"""Initialize pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transformer: LTX transformer model
|
||||||
|
text_encoder: Optional LTX text encoder
|
||||||
|
tokenizer: Optional tokenizer for text encoding
|
||||||
|
vae_encoder: Optional VAE encoder
|
||||||
|
vae_decoder: Optional VAE decoder
|
||||||
|
"""
|
||||||
|
self.transformer = transformer
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.vae_encoder = vae_encoder
|
||||||
|
self.vae_decoder = vae_decoder
|
||||||
|
self.x0_model = X0Model(transformer)
|
||||||
|
|
||||||
|
def prepare_latents(
|
||||||
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
num_frames: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
dtype: mx.Dtype = mx.float16,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Prepare initial noise latents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: Batch size
|
||||||
|
num_frames: Number of latent frames
|
||||||
|
height: Latent height
|
||||||
|
width: Latent width
|
||||||
|
dtype: Data type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Random latent noise
|
||||||
|
"""
|
||||||
|
# Use in_channels from transformer config
|
||||||
|
in_channels = self.transformer.config.in_channels
|
||||||
|
shape = (batch_size, in_channels, num_frames, height, width)
|
||||||
|
latents = mx.random.normal(shape).astype(dtype)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def prepare_text_embeddings(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
batch_size: int,
|
||||||
|
max_length: int = 1024,
|
||||||
|
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||||
|
"""Prepare text embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Text prompt or list of prompts
|
||||||
|
batch_size: Batch size
|
||||||
|
max_length: Maximum sequence length for tokenization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (text_embeddings, attention_mask)
|
||||||
|
"""
|
||||||
|
# If text encoder is available, use it
|
||||||
|
if self.text_encoder is not None and self.tokenizer is not None:
|
||||||
|
# Handle single or multiple prompts
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompts = [prompt] * batch_size
|
||||||
|
else:
|
||||||
|
prompts = prompt
|
||||||
|
|
||||||
|
# Tokenize
|
||||||
|
tokens = self.tokenizer(
|
||||||
|
prompts,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = mx.array(tokens["input_ids"])
|
||||||
|
attention_mask = mx.array(tokens["attention_mask"])
|
||||||
|
|
||||||
|
# Encode
|
||||||
|
embeddings = self.text_encoder(input_ids, attention_mask)
|
||||||
|
mx.eval(embeddings)
|
||||||
|
|
||||||
|
return embeddings, None # Connector handles masking internally
|
||||||
|
|
||||||
|
# Fallback: random embeddings (for testing without text encoder)
|
||||||
|
print("Warning: No text encoder provided, using random embeddings")
|
||||||
|
seq_len = max_length + 128 # Account for learnable registers
|
||||||
|
embed_dim = self.transformer.config.caption_channels
|
||||||
|
|
||||||
|
embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
|
||||||
|
mask = mx.ones((batch_size, seq_len))
|
||||||
|
|
||||||
|
return embeddings, mask
|
||||||
|
|
||||||
|
def denoise_step(
|
||||||
|
self,
|
||||||
|
latents: mx.array,
|
||||||
|
sigma: float,
|
||||||
|
sigma_next: float,
|
||||||
|
text_embeddings: mx.array,
|
||||||
|
positions: mx.array,
|
||||||
|
text_mask: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Perform one denoising step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents: Current noisy latents
|
||||||
|
sigma: Current noise level
|
||||||
|
sigma_next: Next noise level
|
||||||
|
text_embeddings: Text conditioning
|
||||||
|
positions: Position grid for RoPE
|
||||||
|
text_mask: Optional attention mask for text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Denoised latents
|
||||||
|
"""
|
||||||
|
batch_size = latents.shape[0]
|
||||||
|
|
||||||
|
# Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
|
||||||
|
b, c, f, h, w = latents.shape
|
||||||
|
latents_flat = mx.reshape(latents, (b, c, -1))
|
||||||
|
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
|
||||||
|
|
||||||
|
# Create timestep tensor
|
||||||
|
timesteps = mx.full((batch_size,), sigma)
|
||||||
|
|
||||||
|
# Create video modality input
|
||||||
|
video_modality = Modality(
|
||||||
|
latent=latents_flat,
|
||||||
|
timesteps=timesteps,
|
||||||
|
positions=positions,
|
||||||
|
context=text_embeddings,
|
||||||
|
context_mask=text_mask,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run denoising
|
||||||
|
denoised_video, _ = self.x0_model(video=video_modality, audio=None)
|
||||||
|
|
||||||
|
# Reshape back: (B, F*H*W, C) -> (B, C, F, H, W)
|
||||||
|
denoised_video = mx.transpose(denoised_video, (0, 2, 1))
|
||||||
|
denoised_video = mx.reshape(denoised_video, (b, c, f, h, w))
|
||||||
|
|
||||||
|
# Euler step
|
||||||
|
if sigma_next > 0:
|
||||||
|
# x_next = x0 + sigma_next * (x - x0) / sigma
|
||||||
|
noise = (latents - denoised_video) / sigma
|
||||||
|
latents = denoised_video + sigma_next * noise
|
||||||
|
else:
|
||||||
|
latents = denoised_video
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
config: Optional[GenerationConfig] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Generate video from text prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Text prompt
|
||||||
|
config: Generation configuration
|
||||||
|
seed: Random seed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated video tensor of shape (B, C, F, H, W)
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = GenerationConfig()
|
||||||
|
|
||||||
|
if seed is not None:
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
batch_size = 1
|
||||||
|
|
||||||
|
# Prepare text embeddings
|
||||||
|
text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)
|
||||||
|
|
||||||
|
# Prepare initial latents
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_frames=config.latent_frames,
|
||||||
|
height=config.latent_height,
|
||||||
|
width=config.latent_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare position grid
|
||||||
|
positions = create_position_grid(
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_frames=config.latent_frames,
|
||||||
|
height=config.latent_height,
|
||||||
|
width=config.latent_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sigma schedule
|
||||||
|
num_tokens = config.latent_frames * config.latent_height * config.latent_width
|
||||||
|
sigmas = get_sigmas(
|
||||||
|
config.num_inference_steps,
|
||||||
|
num_tokens,
|
||||||
|
use_distilled=config.use_distilled,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Denoising loop
|
||||||
|
for i in range(len(sigmas) - 1):
|
||||||
|
sigma = float(sigmas[i])
|
||||||
|
sigma_next = float(sigmas[i + 1])
|
||||||
|
|
||||||
|
latents = self.denoise_step(
|
||||||
|
latents=latents,
|
||||||
|
sigma=sigma,
|
||||||
|
sigma_next=sigma_next,
|
||||||
|
text_embeddings=text_embeddings,
|
||||||
|
positions=positions,
|
||||||
|
text_mask=text_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
|
# Decode latents to video
|
||||||
|
if self.vae_decoder is not None:
|
||||||
|
video = self.vae_decoder(latents)
|
||||||
|
else:
|
||||||
|
video = latents
|
||||||
|
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def generate_video(
|
||||||
|
prompt: str,
|
||||||
|
transformer: LTXModel,
|
||||||
|
text_encoder: Optional[LTX2TextEncoder] = None,
|
||||||
|
tokenizer: Optional[any] = None,
|
||||||
|
vae_decoder: Optional[VideoDecoder] = None,
|
||||||
|
config: Optional[GenerationConfig] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Generate video from text prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Text prompt
|
||||||
|
transformer: LTX transformer model
|
||||||
|
text_encoder: Optional text encoder
|
||||||
|
tokenizer: Optional tokenizer
|
||||||
|
vae_decoder: Optional VAE decoder
|
||||||
|
config: Generation configuration
|
||||||
|
seed: Random seed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated video tensor
|
||||||
|
"""
|
||||||
|
pipeline = LTXVideoPipeline(
|
||||||
|
transformer=transformer,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
vae_decoder=vae_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pipeline(prompt, config, seed)
|
||||||
|
|
||||||
|
|
||||||
|
def load_pipeline(
|
||||||
|
model_path: str,
|
||||||
|
text_encoder_path: Optional[str] = None,
|
||||||
|
tokenizer_path: Optional[str] = None,
|
||||||
|
load_text_encoder_weights: bool = True,
|
||||||
|
) -> LTXVideoPipeline:
|
||||||
|
"""Load complete LTX-2 video generation pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to LTX-2 model weights (safetensors)
|
||||||
|
text_encoder_path: Path to text encoder weights directory
|
||||||
|
tokenizer_path: Path to tokenizer directory
|
||||||
|
load_text_encoder_weights: Whether to load text encoder weights
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured LTXVideoPipeline
|
||||||
|
"""
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||||
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
|
from mlx_video.convert import sanitize_transformer_weights
|
||||||
|
|
||||||
|
print("Loading LTX-2 pipeline...")
|
||||||
|
|
||||||
|
# Load transformer
|
||||||
|
print(" Loading transformer...")
|
||||||
|
raw_weights = mx.load(model_path)
|
||||||
|
sanitized = sanitize_transformer_weights(raw_weights)
|
||||||
|
|
||||||
|
config = LTXModelConfig(
|
||||||
|
model_type=LTXModelType.VideoOnly,
|
||||||
|
num_attention_heads=32,
|
||||||
|
attention_head_dim=128,
|
||||||
|
in_channels=128,
|
||||||
|
out_channels=128,
|
||||||
|
num_layers=48,
|
||||||
|
cross_attention_dim=4096,
|
||||||
|
caption_channels=3840,
|
||||||
|
)
|
||||||
|
transformer = LTXModel(config)
|
||||||
|
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||||
|
print(" Transformer loaded")
|
||||||
|
|
||||||
|
# Load VAE decoder
|
||||||
|
print(" Loading VAE decoder...")
|
||||||
|
vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True)
|
||||||
|
print(" VAE decoder loaded")
|
||||||
|
|
||||||
|
# Load text encoder if paths provided
|
||||||
|
text_encoder = None
|
||||||
|
tokenizer = None
|
||||||
|
|
||||||
|
if load_text_encoder_weights and text_encoder_path is not None:
|
||||||
|
print(" Loading text encoder...")
|
||||||
|
text_encoder = load_text_encoder(model_path, text_encoder_path)
|
||||||
|
print(" Text encoder loaded")
|
||||||
|
|
||||||
|
if tokenizer_path is not None:
|
||||||
|
print(" Loading tokenizer...")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
print(" Tokenizer loaded")
|
||||||
|
|
||||||
|
print("Pipeline ready!")
|
||||||
|
|
||||||
|
return LTXVideoPipeline(
|
||||||
|
transformer=transformer,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
vae_decoder=vae_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_numpy(video: mx.array) -> np.ndarray:
|
||||||
|
"""Convert video tensor to numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video: Video tensor of shape (B, C, F, H, W) in range [-1, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Numpy array of shape (B, F, H, W, C) in range [0, 255]
|
||||||
|
"""
|
||||||
|
# Clamp to [-1, 1]
|
||||||
|
video = mx.clip(video, -1.0, 1.0)
|
||||||
|
|
||||||
|
# Scale to [0, 255]
|
||||||
|
video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8)
|
||||||
|
|
||||||
|
# Rearrange: (B, C, F, H, W) -> (B, F, H, W, C)
|
||||||
|
video = mx.transpose(video, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
return np.array(video)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example usage
|
||||||
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||||
|
|
||||||
|
# Create a small test config
|
||||||
|
config = LTXModelConfig(
|
||||||
|
model_type=LTXModelType.VideoOnly,
|
||||||
|
num_layers=2, # Reduced for testing
|
||||||
|
num_attention_heads=4,
|
||||||
|
attention_head_dim=32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
model = LTXModel(config)
|
||||||
|
|
||||||
|
# Generate video
|
||||||
|
gen_config = GenerationConfig(
|
||||||
|
height=256,
|
||||||
|
width=256,
|
||||||
|
num_frames=9,
|
||||||
|
num_inference_steps=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Testing generation pipeline...")
|
||||||
|
pipeline = LTXVideoPipeline(transformer=model)
|
||||||
|
|
||||||
|
# This would require proper text embeddings in practice
|
||||||
|
# video = pipeline("A cat walking", gen_config, seed=42)
|
||||||
|
# print(f"Generated video shape: {video.shape}")
|
||||||
|
|
||||||
|
print("Pipeline initialized successfully!")
|
||||||
2
mlx_video/models/__init__.py
Normal file
2
mlx_video/models/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
|
||||||
|
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||||
7
mlx_video/models/ltx/__init__.py
Normal file
7
mlx_video/models/ltx/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
|
||||||
|
from mlx_video.models.ltx.config import (
|
||||||
|
LTXModelConfig,
|
||||||
|
TransformerConfig,
|
||||||
|
LTXModelType,
|
||||||
|
)
|
||||||
|
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||||
161
mlx_video/models/ltx/adaln.py
Normal file
161
mlx_video/models/ltx/adaln.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.utils import get_timestep_embedding
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNormSingle(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
embedding_coefficient: int = 6,
|
||||||
|
use_additional_conditions: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
size_emb_dim=0 if not use_additional_conditions else embedding_dim // 3,
|
||||||
|
use_additional_conditions=use_additional_conditions,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
timestep: mx.array,
|
||||||
|
added_cond_kwargs: dict | None = None,
|
||||||
|
batch_size: int | None = None,
|
||||||
|
hidden_dtype: mx.Dtype | None = None,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
|
added_cond_kwargs = added_cond_kwargs or {}
|
||||||
|
|
||||||
|
embedded_timestep = self.emb(
|
||||||
|
timestep,
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
**added_cond_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_shift_params = self.linear(self.silu(embedded_timestep))
|
||||||
|
return scale_shift_params, embedded_timestep
|
||||||
|
|
||||||
|
|
||||||
|
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
size_emb_dim: int = 0,
|
||||||
|
use_additional_conditions: bool = False,
|
||||||
|
timestep_proj_dim: int = 256,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.size_emb_dim = size_emb_dim
|
||||||
|
self.use_additional_conditions = use_additional_conditions
|
||||||
|
|
||||||
|
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||||
|
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
|
||||||
|
|
||||||
|
if use_additional_conditions and size_emb_dim > 0:
|
||||||
|
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
timestep: mx.array,
|
||||||
|
resolution: mx.array | None = None,
|
||||||
|
aspect_ratio: mx.array | None = None,
|
||||||
|
batch_size: int | None = None,
|
||||||
|
hidden_dtype: mx.Dtype | None = None,
|
||||||
|
) -> mx.array:
|
||||||
|
# Project timestep
|
||||||
|
timesteps_proj = self.time_proj(timestep)
|
||||||
|
if hidden_dtype is not None:
|
||||||
|
timesteps_proj = timesteps_proj.astype(hidden_dtype)
|
||||||
|
|
||||||
|
timesteps_emb = self.timestep_embedder(timesteps_proj)
|
||||||
|
|
||||||
|
# Add additional conditions if enabled
|
||||||
|
if self.use_additional_conditions and self.size_emb_dim > 0:
|
||||||
|
if resolution is not None and aspect_ratio is not None:
|
||||||
|
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
|
||||||
|
timesteps_emb = timesteps_emb + additional_embeds
|
||||||
|
|
||||||
|
return timesteps_emb
|
||||||
|
|
||||||
|
|
||||||
|
class Timesteps(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_channels: int,
|
||||||
|
flip_sin_to_cos: bool = False,
|
||||||
|
downscale_freq_shift: float = 1.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.flip_sin_to_cos = flip_sin_to_cos
|
||||||
|
self.downscale_freq_shift = downscale_freq_shift
|
||||||
|
|
||||||
|
def __call__(self, timesteps: mx.array) -> mx.array:
|
||||||
|
return get_timestep_embedding(
|
||||||
|
timesteps,
|
||||||
|
self.num_channels,
|
||||||
|
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||||
|
downscale_freq_shift=self.downscale_freq_shift,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
time_embed_dim: int,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
out_dim: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
out_dim = out_dim or time_embed_dim
|
||||||
|
self.linear1 = nn.Linear(in_channels, time_embed_dim)
|
||||||
|
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
|
||||||
|
self.linear2 = nn.Linear(time_embed_dim, out_dim)
|
||||||
|
|
||||||
|
def __call__(self, sample: mx.array) -> mx.array:
|
||||||
|
sample = self.linear1(sample)
|
||||||
|
sample = self.act(sample)
|
||||||
|
sample = self.linear2(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionEmbedding(nn.Module):
|
||||||
|
def __init__(self, size_emb_dim: int, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.resolution_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
|
||||||
|
self.aspect_ratio_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
resolution: mx.array,
|
||||||
|
aspect_ratio: mx.array,
|
||||||
|
hidden_dtype: mx.Dtype | None = None,
|
||||||
|
) -> mx.array:
|
||||||
|
resolution_emb = self.resolution_embedder(resolution)
|
||||||
|
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio)
|
||||||
|
|
||||||
|
if hidden_dtype is not None:
|
||||||
|
resolution_emb = resolution_emb.astype(hidden_dtype)
|
||||||
|
aspect_ratio_emb = aspect_ratio_emb.astype(hidden_dtype)
|
||||||
|
|
||||||
|
return resolution_emb + aspect_ratio_emb
|
||||||
142
mlx_video/models/ltx/attention.py
Normal file
142
mlx_video/models/ltx/attention.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Attention module for LTX-2."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import LTXRopeType
|
||||||
|
from mlx_video.models.ltx.rope import apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(
|
||||||
|
q: mx.array,
|
||||||
|
k: mx.array,
|
||||||
|
v: mx.array,
|
||||||
|
heads: int,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
b, q_seq_len, dim = q.shape
|
||||||
|
_, kv_seq_len, _ = k.shape
|
||||||
|
dim_head = dim // heads
|
||||||
|
|
||||||
|
# Reshape to (B, seq_len, heads, dim_head)
|
||||||
|
q = mx.reshape(q, (b, q_seq_len, heads, dim_head))
|
||||||
|
k = mx.reshape(k, (b, kv_seq_len, heads, dim_head))
|
||||||
|
v = mx.reshape(v, (b, kv_seq_len, heads, dim_head))
|
||||||
|
|
||||||
|
# Transpose to (B, heads, seq_len, dim_head)
|
||||||
|
q = mx.swapaxes(q, 1, 2)
|
||||||
|
k = mx.swapaxes(k, 1, 2)
|
||||||
|
v = mx.swapaxes(v, 1, 2)
|
||||||
|
|
||||||
|
# Handle mask dimensions
|
||||||
|
if mask is not None:
|
||||||
|
# Add batch dimension if needed
|
||||||
|
if mask.ndim == 2:
|
||||||
|
mask = mx.expand_dims(mask, axis=0)
|
||||||
|
# Add heads dimension if needed
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mx.expand_dims(mask, axis=1)
|
||||||
|
|
||||||
|
# Compute scaled dot-product attention
|
||||||
|
scale = 1.0 / math.sqrt(dim_head)
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
# Reshape back to (B, q_seq_len, heads * dim_head)
|
||||||
|
out = mx.swapaxes(out, 1, 2)
|
||||||
|
out = mx.reshape(out, (b, q_seq_len, heads * dim_head))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""Multi-head attention with rotary position embeddings.
|
||||||
|
|
||||||
|
Supports both self-attention and cross-attention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
query_dim: int,
|
||||||
|
context_dim: Optional[int] = None,
|
||||||
|
heads: int = 8,
|
||||||
|
dim_head: int = 64,
|
||||||
|
norm_eps: float = 1e-6,
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||||
|
):
|
||||||
|
"""Initialize attention module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_dim: Dimension of query input
|
||||||
|
context_dim: Dimension of context (key/value) input. If None, same as query_dim
|
||||||
|
heads: Number of attention heads
|
||||||
|
dim_head: Dimension per head
|
||||||
|
norm_eps: Epsilon for RMS normalization
|
||||||
|
rope_type: Type of rotary position embedding
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.rope_type = rope_type
|
||||||
|
self.heads = heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
context_dim = query_dim if context_dim is None else context_dim
|
||||||
|
|
||||||
|
# Q, K, V projections
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=True)
|
||||||
|
self.to_k = nn.Linear(context_dim, inner_dim, bias=True)
|
||||||
|
self.to_v = nn.Linear(context_dim, inner_dim, bias=True)
|
||||||
|
|
||||||
|
# Q and K normalization
|
||||||
|
self.q_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
|
||||||
|
self.k_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
|
||||||
|
|
||||||
|
# Output projection
|
||||||
|
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
context: Optional[mx.array] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
|
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Query input of shape (B, seq_len, query_dim)
|
||||||
|
context: Context for cross-attention. If None, uses x (self-attention)
|
||||||
|
mask: Attention mask
|
||||||
|
pe: Position embeddings for query (and key if k_pe is None)
|
||||||
|
k_pe: Position embeddings for key (optional, uses pe if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output of shape (B, seq_len, query_dim)
|
||||||
|
"""
|
||||||
|
# Compute Q, K, V
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = x if context is None else context
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
# Apply normalization
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
# Apply rotary position embeddings
|
||||||
|
if pe is not None:
|
||||||
|
q = apply_rotary_emb(q, pe, self.rope_type)
|
||||||
|
k_pe_to_use = pe if k_pe is None else k_pe
|
||||||
|
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
|
||||||
|
|
||||||
|
# Compute attention
|
||||||
|
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||||
|
|
||||||
|
# Project output
|
||||||
|
return self.to_out(out)
|
||||||
181
mlx_video/models/ltx/config.py
Normal file
181
mlx_video/models/ltx/config.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
|
||||||
|
import inspect
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class LTXModelType(Enum):
|
||||||
|
AudioVideo = "ltx av model"
|
||||||
|
VideoOnly = "ltx video only model"
|
||||||
|
AudioOnly = "ltx audio only model"
|
||||||
|
|
||||||
|
def is_video_enabled(self) -> bool:
|
||||||
|
return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
|
||||||
|
|
||||||
|
def is_audio_enabled(self) -> bool:
|
||||||
|
return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXRopeType(Enum):
|
||||||
|
INTERLEAVED = "interleaved"
|
||||||
|
SPLIT = "split"
|
||||||
|
TWO_D = "2d"
|
||||||
|
|
||||||
|
class AttentionType(Enum):
|
||||||
|
DEFAULT = "default"
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseModelConfig:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
|
||||||
|
"""Create config from dictionary, filtering only valid parameters."""
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in params.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Export config to dictionary."""
|
||||||
|
result = {}
|
||||||
|
for k, v in self.__dict__.items():
|
||||||
|
if v is not None:
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
result[k] = v.value
|
||||||
|
elif hasattr(v, 'to_dict'):
|
||||||
|
result[k] = v.to_dict()
|
||||||
|
else:
|
||||||
|
result[k] = v
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TransformerConfig(BaseModelConfig):
|
||||||
|
dim: int
|
||||||
|
heads: int
|
||||||
|
d_head: int
|
||||||
|
context_dim: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoVAEConfig(BaseModelConfig):
|
||||||
|
convolution_dimensions: int = 3
|
||||||
|
in_channels: int = 3
|
||||||
|
out_channels: int = 128
|
||||||
|
latent_channels: int = 128
|
||||||
|
patch_size: int = 4
|
||||||
|
encoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||||
|
("res_x", {"num_layers": 4}),
|
||||||
|
("compress_space_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 6}),
|
||||||
|
("compress_time_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 6}),
|
||||||
|
("compress_all_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 2}),
|
||||||
|
("compress_all_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 2}),
|
||||||
|
])
|
||||||
|
decoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||||
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
|
("compress_all", {"residual": True, "multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
|
("compress_all", {"residual": True, "multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
|
("compress_all", {"residual": True, "multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LTXModelConfig(BaseModelConfig):
|
||||||
|
|
||||||
|
# Model type
|
||||||
|
model_type: LTXModelType = LTXModelType.AudioVideo
|
||||||
|
|
||||||
|
# Video transformer config
|
||||||
|
num_attention_heads: int = 32
|
||||||
|
attention_head_dim: int = 128
|
||||||
|
in_channels: int = 128
|
||||||
|
out_channels: int = 128
|
||||||
|
num_layers: int = 48
|
||||||
|
cross_attention_dim: int = 4096
|
||||||
|
caption_channels: int = 3840
|
||||||
|
|
||||||
|
# Audio transformer config
|
||||||
|
audio_num_attention_heads: int = 32
|
||||||
|
audio_attention_head_dim: int = 64
|
||||||
|
audio_in_channels: int = 128
|
||||||
|
audio_out_channels: int = 128
|
||||||
|
audio_cross_attention_dim: int = 2048
|
||||||
|
|
||||||
|
# Positional embedding config
|
||||||
|
positional_embedding_theta: float = 10000.0
|
||||||
|
positional_embedding_max_pos: Optional[List[int]] = None
|
||||||
|
audio_positional_embedding_max_pos: Optional[List[int]] = None
|
||||||
|
use_middle_indices_grid: bool = True
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
|
||||||
|
double_precision_rope: bool = False
|
||||||
|
|
||||||
|
# Timestep config
|
||||||
|
timestep_scale_multiplier: int = 1000
|
||||||
|
av_ca_timestep_scale_multiplier: int = 1
|
||||||
|
|
||||||
|
# Normalization
|
||||||
|
norm_eps: float = 1e-6
|
||||||
|
|
||||||
|
# Attention type
|
||||||
|
attention_type: AttentionType = AttentionType.DEFAULT
|
||||||
|
|
||||||
|
# VAE config
|
||||||
|
vae_config: Optional[VideoVAEConfig] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Set default values after initialization."""
|
||||||
|
if self.positional_embedding_max_pos is None:
|
||||||
|
self.positional_embedding_max_pos = [20, 2048, 2048]
|
||||||
|
if self.audio_positional_embedding_max_pos is None:
|
||||||
|
self.audio_positional_embedding_max_pos = [20]
|
||||||
|
|
||||||
|
# Convert string enum values if loading from dict
|
||||||
|
if isinstance(self.model_type, str):
|
||||||
|
self.model_type = LTXModelType(self.model_type)
|
||||||
|
if isinstance(self.rope_type, str):
|
||||||
|
self.rope_type = LTXRopeType(self.rope_type)
|
||||||
|
if isinstance(self.attention_type, str):
|
||||||
|
self.attention_type = AttentionType(self.attention_type)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_dim(self) -> int:
|
||||||
|
"""Video inner dimension."""
|
||||||
|
return self.num_attention_heads * self.attention_head_dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_inner_dim(self) -> int:
|
||||||
|
"""Audio inner dimension."""
|
||||||
|
return self.audio_num_attention_heads * self.audio_attention_head_dim
|
||||||
|
|
||||||
|
def get_video_config(self) -> Optional[TransformerConfig]:
|
||||||
|
"""Get video transformer configuration."""
|
||||||
|
if not self.model_type.is_video_enabled():
|
||||||
|
return None
|
||||||
|
return TransformerConfig(
|
||||||
|
dim=self.inner_dim,
|
||||||
|
heads=self.num_attention_heads,
|
||||||
|
d_head=self.attention_head_dim,
|
||||||
|
context_dim=self.cross_attention_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_audio_config(self) -> Optional[TransformerConfig]:
|
||||||
|
"""Get audio transformer configuration."""
|
||||||
|
if not self.model_type.is_audio_enabled():
|
||||||
|
return None
|
||||||
|
return TransformerConfig(
|
||||||
|
dim=self.audio_inner_dim,
|
||||||
|
heads=self.audio_num_attention_heads,
|
||||||
|
d_head=self.audio_attention_head_dim,
|
||||||
|
context_dim=self.audio_cross_attention_dim,
|
||||||
|
)
|
||||||
40
mlx_video/models/ltx/feed_forward.py
Normal file
40
mlx_video/models/ltx/feed_forward.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
def __init__(self, approximate: str = "tanh"):
|
||||||
|
super().__init__()
|
||||||
|
self.approximate = approximate
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
if self.approximate == "tanh":
|
||||||
|
return nn.gelu_approx(x)
|
||||||
|
else:
|
||||||
|
return nn.gelu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
dim_out: int | None = None,
|
||||||
|
mult: int = 4,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim_out = dim_out or dim
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
|
||||||
|
self.proj_in = nn.Linear(dim, inner_dim, bias=bias)
|
||||||
|
self.act = GELU(approximate="tanh")
|
||||||
|
self.proj_out = nn.Linear(inner_dim, dim_out, bias=bias)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
|
||||||
|
x = self.proj_in(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.proj_out(x)
|
||||||
|
return x
|
||||||
518
mlx_video/models/ltx/ltx.py
Normal file
518
mlx_video/models/ltx/ltx.py
Normal file
@@ -0,0 +1,518 @@
|
|||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import (
|
||||||
|
LTXModelConfig,
|
||||||
|
LTXModelType,
|
||||||
|
LTXRopeType,
|
||||||
|
TransformerConfig,
|
||||||
|
)
|
||||||
|
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
|
||||||
|
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||||
|
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
|
||||||
|
from mlx_video.models.ltx.transformer import (
|
||||||
|
BasicAVTransformerBlock,
|
||||||
|
Modality,
|
||||||
|
TransformerArgs,
|
||||||
|
)
|
||||||
|
from mlx_video.utils import to_denoised
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerArgsPreprocessor:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patchify_proj: nn.Linear,
|
||||||
|
adaln: AdaLayerNormSingle,
|
||||||
|
caption_projection: PixArtAlphaTextProjection,
|
||||||
|
inner_dim: int,
|
||||||
|
max_pos: List[int],
|
||||||
|
num_attention_heads: int,
|
||||||
|
use_middle_indices_grid: bool,
|
||||||
|
timestep_scale_multiplier: int,
|
||||||
|
positional_embedding_theta: float,
|
||||||
|
rope_type: LTXRopeType,
|
||||||
|
double_precision_rope: bool = False,
|
||||||
|
):
|
||||||
|
self.patchify_proj = patchify_proj
|
||||||
|
self.adaln = adaln
|
||||||
|
self.caption_projection = caption_projection
|
||||||
|
self.inner_dim = inner_dim
|
||||||
|
self.max_pos = max_pos
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.use_middle_indices_grid = use_middle_indices_grid
|
||||||
|
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.rope_type = rope_type
|
||||||
|
self.double_precision_rope = double_precision_rope
|
||||||
|
|
||||||
|
def _prepare_timestep(
|
||||||
|
self,
|
||||||
|
timestep: mx.array,
|
||||||
|
batch_size: int,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
|
||||||
|
|
||||||
|
# Reshape to (batch, tokens, dim)
|
||||||
|
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||||
|
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||||
|
|
||||||
|
return timestep_emb, embedded_timestep
|
||||||
|
|
||||||
|
def _prepare_context(
|
||||||
|
self,
|
||||||
|
context: mx.array,
|
||||||
|
x: mx.array,
|
||||||
|
attention_mask: Optional[mx.array] = None,
|
||||||
|
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
context = self.caption_projection(context)
|
||||||
|
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||||
|
return context, attention_mask
|
||||||
|
|
||||||
|
def _prepare_attention_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: Optional[mx.array],
|
||||||
|
x_dtype: mx.Dtype,
|
||||||
|
) -> Optional[mx.array]:
|
||||||
|
if attention_mask is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if already float
|
||||||
|
if attention_mask.dtype in [mx.float16, mx.float32, mx.bfloat16]:
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
# Convert boolean/int mask to float mask
|
||||||
|
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
||||||
|
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
||||||
|
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def _prepare_positional_embeddings(
|
||||||
|
self,
|
||||||
|
positions: mx.array,
|
||||||
|
inner_dim: int,
|
||||||
|
max_pos: List[int],
|
||||||
|
use_middle_indices_grid: bool,
|
||||||
|
num_attention_heads: int,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
pe = precompute_freqs_cis(
|
||||||
|
positions,
|
||||||
|
dim=inner_dim,
|
||||||
|
theta=self.positional_embedding_theta,
|
||||||
|
max_pos=max_pos,
|
||||||
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
rope_type=self.rope_type,
|
||||||
|
double_precision=self.double_precision_rope,
|
||||||
|
)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||||
|
x = self.patchify_proj(modality.latent)
|
||||||
|
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
|
||||||
|
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||||
|
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||||
|
pe = self._prepare_positional_embeddings(
|
||||||
|
positions=modality.positions,
|
||||||
|
inner_dim=self.inner_dim,
|
||||||
|
max_pos=self.max_pos,
|
||||||
|
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TransformerArgs(
|
||||||
|
x=x,
|
||||||
|
context=context,
|
||||||
|
context_mask=attention_mask,
|
||||||
|
timesteps=timestep,
|
||||||
|
embedded_timestep=embedded_timestep,
|
||||||
|
positional_embeddings=pe,
|
||||||
|
cross_positional_embeddings=None,
|
||||||
|
cross_scale_shift_timestep=None,
|
||||||
|
cross_gate_timestep=None,
|
||||||
|
enabled=modality.enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalTransformerArgsPreprocessor:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patchify_proj: nn.Linear,
|
||||||
|
adaln: AdaLayerNormSingle,
|
||||||
|
caption_projection: PixArtAlphaTextProjection,
|
||||||
|
cross_scale_shift_adaln: AdaLayerNormSingle,
|
||||||
|
cross_gate_adaln: AdaLayerNormSingle,
|
||||||
|
inner_dim: int,
|
||||||
|
max_pos: List[int],
|
||||||
|
num_attention_heads: int,
|
||||||
|
cross_pe_max_pos: int,
|
||||||
|
use_middle_indices_grid: bool,
|
||||||
|
audio_cross_attention_dim: int,
|
||||||
|
timestep_scale_multiplier: int,
|
||||||
|
positional_embedding_theta: float,
|
||||||
|
rope_type: LTXRopeType,
|
||||||
|
av_ca_timestep_scale_multiplier: int,
|
||||||
|
double_precision_rope: bool = False,
|
||||||
|
):
|
||||||
|
self.simple_preprocessor = TransformerArgsPreprocessor(
|
||||||
|
patchify_proj=patchify_proj,
|
||||||
|
adaln=adaln,
|
||||||
|
caption_projection=caption_projection,
|
||||||
|
inner_dim=inner_dim,
|
||||||
|
max_pos=max_pos,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
positional_embedding_theta=positional_embedding_theta,
|
||||||
|
rope_type=rope_type,
|
||||||
|
double_precision_rope=double_precision_rope,
|
||||||
|
)
|
||||||
|
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
||||||
|
self.cross_gate_adaln = cross_gate_adaln
|
||||||
|
self.cross_pe_max_pos = cross_pe_max_pos
|
||||||
|
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||||
|
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||||
|
|
||||||
|
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
transformer_args = self.simple_preprocessor.prepare(modality)
|
||||||
|
|
||||||
|
# Prepare cross-modal positional embeddings
|
||||||
|
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
|
||||||
|
positions=modality.positions[:, 0:1, :],
|
||||||
|
inner_dim=self.audio_cross_attention_dim,
|
||||||
|
max_pos=[self.cross_pe_max_pos],
|
||||||
|
use_middle_indices_grid=True,
|
||||||
|
num_attention_heads=self.simple_preprocessor.num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare cross-attention timestep embeddings
|
||||||
|
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
||||||
|
timestep=modality.timesteps,
|
||||||
|
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||||
|
batch_size=transformer_args.x.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
return replace(
|
||||||
|
transformer_args,
|
||||||
|
cross_positional_embeddings=cross_pe,
|
||||||
|
cross_scale_shift_timestep=cross_scale_shift_timestep,
|
||||||
|
cross_gate_timestep=cross_gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_cross_attention_timestep(
|
||||||
|
self,
|
||||||
|
timestep: mx.array,
|
||||||
|
timestep_scale_multiplier: int,
|
||||||
|
batch_size: int,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
timestep = timestep * timestep_scale_multiplier
|
||||||
|
|
||||||
|
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||||
|
|
||||||
|
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
|
||||||
|
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||||
|
|
||||||
|
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
|
||||||
|
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||||
|
|
||||||
|
return scale_shift_timestep, gate_timestep
|
||||||
|
|
||||||
|
|
||||||
|
class LTXModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: LTXModelConfig):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.model_type = config.model_type
|
||||||
|
self.use_middle_indices_grid = config.use_middle_indices_grid
|
||||||
|
self.rope_type = config.rope_type
|
||||||
|
self.timestep_scale_multiplier = config.timestep_scale_multiplier
|
||||||
|
self.positional_embedding_theta = config.positional_embedding_theta
|
||||||
|
|
||||||
|
cross_pe_max_pos = None
|
||||||
|
|
||||||
|
if config.model_type.is_video_enabled():
|
||||||
|
self.positional_embedding_max_pos = config.positional_embedding_max_pos
|
||||||
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.inner_dim = config.inner_dim
|
||||||
|
self._init_video(config)
|
||||||
|
|
||||||
|
if config.model_type.is_audio_enabled():
|
||||||
|
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
|
||||||
|
self.audio_num_attention_heads = config.audio_num_attention_heads
|
||||||
|
self.audio_inner_dim = config.audio_inner_dim
|
||||||
|
self._init_audio(config)
|
||||||
|
|
||||||
|
# Initialize cross-modal components
|
||||||
|
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||||
|
cross_pe_max_pos = max(
|
||||||
|
config.positional_embedding_max_pos[0],
|
||||||
|
config.audio_positional_embedding_max_pos[0],
|
||||||
|
)
|
||||||
|
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
|
||||||
|
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
||||||
|
self._init_audio_video(config)
|
||||||
|
|
||||||
|
self._init_preprocessors(config, cross_pe_max_pos)
|
||||||
|
|
||||||
|
self._init_transformer_blocks(config)
|
||||||
|
|
||||||
|
def _init_video(self, config: LTXModelConfig) -> None:
|
||||||
|
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||||
|
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
|
||||||
|
self.caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=config.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scale_shift_table = mx.zeros((2, self.inner_dim))
|
||||||
|
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
|
||||||
|
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
||||||
|
|
||||||
|
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||||
|
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||||
|
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
||||||
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=config.caption_channels,
|
||||||
|
hidden_size=self.audio_inner_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output components
|
||||||
|
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||||
|
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
|
||||||
|
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
||||||
|
|
||||||
|
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
||||||
|
num_scale_shift_values = 4
|
||||||
|
|
||||||
|
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim,
|
||||||
|
embedding_coefficient=num_scale_shift_values,
|
||||||
|
)
|
||||||
|
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
embedding_coefficient=num_scale_shift_values,
|
||||||
|
)
|
||||||
|
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.inner_dim,
|
||||||
|
embedding_coefficient=1,
|
||||||
|
)
|
||||||
|
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
embedding_coefficient=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
|
||||||
|
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||||
|
# Multi-modal preprocessors
|
||||||
|
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||||
|
patchify_proj=self.patchify_proj,
|
||||||
|
adaln=self.adaln_single,
|
||||||
|
caption_projection=self.caption_projection,
|
||||||
|
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
|
||||||
|
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
|
||||||
|
inner_dim=self.inner_dim,
|
||||||
|
max_pos=config.positional_embedding_max_pos,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
cross_pe_max_pos=cross_pe_max_pos,
|
||||||
|
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||||
|
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||||
|
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||||
|
positional_embedding_theta=config.positional_embedding_theta,
|
||||||
|
rope_type=config.rope_type,
|
||||||
|
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||||
|
double_precision_rope=config.double_precision_rope,
|
||||||
|
)
|
||||||
|
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||||
|
patchify_proj=self.audio_patchify_proj,
|
||||||
|
adaln=self.audio_adaln_single,
|
||||||
|
caption_projection=self.audio_caption_projection,
|
||||||
|
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
|
||||||
|
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
|
||||||
|
inner_dim=self.audio_inner_dim,
|
||||||
|
max_pos=config.audio_positional_embedding_max_pos,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
cross_pe_max_pos=cross_pe_max_pos,
|
||||||
|
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||||
|
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||||
|
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||||
|
positional_embedding_theta=config.positional_embedding_theta,
|
||||||
|
rope_type=config.rope_type,
|
||||||
|
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||||
|
double_precision_rope=config.double_precision_rope,
|
||||||
|
)
|
||||||
|
elif config.model_type.is_video_enabled():
|
||||||
|
self.video_args_preprocessor = TransformerArgsPreprocessor(
|
||||||
|
patchify_proj=self.patchify_proj,
|
||||||
|
adaln=self.adaln_single,
|
||||||
|
caption_projection=self.caption_projection,
|
||||||
|
inner_dim=self.inner_dim,
|
||||||
|
max_pos=config.positional_embedding_max_pos,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||||
|
positional_embedding_theta=config.positional_embedding_theta,
|
||||||
|
rope_type=config.rope_type,
|
||||||
|
double_precision_rope=config.double_precision_rope,
|
||||||
|
)
|
||||||
|
elif config.model_type.is_audio_enabled():
|
||||||
|
self.audio_args_preprocessor = TransformerArgsPreprocessor(
|
||||||
|
patchify_proj=self.audio_patchify_proj,
|
||||||
|
adaln=self.audio_adaln_single,
|
||||||
|
caption_projection=self.audio_caption_projection,
|
||||||
|
inner_dim=self.audio_inner_dim,
|
||||||
|
max_pos=config.audio_positional_embedding_max_pos,
|
||||||
|
num_attention_heads=self.audio_num_attention_heads,
|
||||||
|
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||||
|
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||||
|
positional_embedding_theta=config.positional_embedding_theta,
|
||||||
|
rope_type=config.rope_type,
|
||||||
|
double_precision_rope=config.double_precision_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
|
||||||
|
video_config = config.get_video_config()
|
||||||
|
audio_config = config.get_audio_config()
|
||||||
|
|
||||||
|
self.transformer_blocks = [
|
||||||
|
BasicAVTransformerBlock(
|
||||||
|
idx=idx,
|
||||||
|
video=video_config,
|
||||||
|
audio=audio_config,
|
||||||
|
rope_type=config.rope_type,
|
||||||
|
norm_eps=config.norm_eps,
|
||||||
|
)
|
||||||
|
for idx in range(config.num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _process_transformer_blocks(
|
||||||
|
self,
|
||||||
|
video: Optional[TransformerArgs],
|
||||||
|
audio: Optional[TransformerArgs],
|
||||||
|
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||||
|
"""Process through all transformer blocks."""
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
video, audio = block(video=video, audio=audio)
|
||||||
|
return video, audio
|
||||||
|
|
||||||
|
def _process_output(
|
||||||
|
self,
|
||||||
|
scale_shift_table: mx.array,
|
||||||
|
norm_out: nn.LayerNorm,
|
||||||
|
proj_out: nn.Linear,
|
||||||
|
x: mx.array,
|
||||||
|
embedded_timestep: mx.array,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
|
||||||
|
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
|
||||||
|
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
|
||||||
|
timestep_expanded = embedded_timestep[:, :, None, :] # (B, 1, 1, dim)
|
||||||
|
|
||||||
|
# Combine: (1, 1, 2, dim) + (B, 1, 1, dim) broadcasts to (B, 1, 2, dim)
|
||||||
|
scale_shift_values = table_expanded + timestep_expanded
|
||||||
|
|
||||||
|
# Extract shift and scale (first index is shift, second is scale)
|
||||||
|
shift = scale_shift_values[:, :, 0, :] # (B, 1, dim)
|
||||||
|
scale = scale_shift_values[:, :, 1, :] # (B, 1, dim)
|
||||||
|
|
||||||
|
x = norm_out(x)
|
||||||
|
x = x * (1 + scale) + shift # Broadcasts (B, 1, dim) to (B, seq, dim)
|
||||||
|
x = proj_out(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
video: Optional[Modality] = None,
|
||||||
|
audio: Optional[Modality] = None,
|
||||||
|
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||||
|
|
||||||
|
# Validate inputs
|
||||||
|
if not self.model_type.is_video_enabled() and video is not None:
|
||||||
|
raise ValueError("Video is not enabled for this model")
|
||||||
|
if not self.model_type.is_audio_enabled() and audio is not None:
|
||||||
|
raise ValueError("Audio is not enabled for this model")
|
||||||
|
|
||||||
|
# Preprocess arguments
|
||||||
|
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||||
|
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||||
|
|
||||||
|
# Process transformer blocks
|
||||||
|
video_out, audio_out = self._process_transformer_blocks(
|
||||||
|
video=video_args,
|
||||||
|
audio=audio_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process outputs
|
||||||
|
vx = (
|
||||||
|
self._process_output(
|
||||||
|
self.scale_shift_table,
|
||||||
|
self.norm_out,
|
||||||
|
self.proj_out,
|
||||||
|
video_out.x,
|
||||||
|
video_out.embedded_timestep,
|
||||||
|
)
|
||||||
|
if video_out is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
ax = (
|
||||||
|
self._process_output(
|
||||||
|
self.audio_scale_shift_table,
|
||||||
|
self.audio_norm_out,
|
||||||
|
self.audio_proj_out,
|
||||||
|
audio_out.x,
|
||||||
|
audio_out.embedded_timestep,
|
||||||
|
)
|
||||||
|
if audio_out is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return vx, ax
|
||||||
|
|
||||||
|
def sanitize(self, weights: dict) -> dict:
|
||||||
|
sanitized = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Handle common remappings
|
||||||
|
# transformer_blocks.X -> transformer_blocks[X]
|
||||||
|
if "transformer_blocks." in new_key:
|
||||||
|
# Keep as-is for now, MLX handles this
|
||||||
|
pass
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
class X0Model(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, velocity_model: LTXModel):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.velocity_model = velocity_model
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
video: Optional[Modality] = None,
|
||||||
|
audio: Optional[Modality] = None,
|
||||||
|
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||||
|
|
||||||
|
vx, ax = self.velocity_model(video, audio)
|
||||||
|
|
||||||
|
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||||
|
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||||
|
|
||||||
|
return denoised_video, denoised_audio
|
||||||
508
mlx_video/models/ltx/rope.py
Normal file
508
mlx_video/models/ltx/rope.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import LTXRopeType
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
input_tensor: mx.array,
|
||||||
|
freqs_cis: Tuple[mx.array, mx.array],
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Apply rotary position embeddings to input tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: Input tensor to apply RoPE to
|
||||||
|
freqs_cis: Tuple of (cos_freqs, sin_freqs)
|
||||||
|
rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with rotary embeddings applied
|
||||||
|
"""
|
||||||
|
if rope_type == LTXRopeType.INTERLEAVED:
|
||||||
|
return apply_interleaved_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
|
||||||
|
elif rope_type == LTXRopeType.SPLIT:
|
||||||
|
return apply_split_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid rope type: {rope_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_interleaved_rotary_emb(
|
||||||
|
input_tensor: mx.array,
|
||||||
|
cos_freqs: mx.array,
|
||||||
|
sin_freqs: mx.array,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Apply interleaved rotary embeddings.
|
||||||
|
|
||||||
|
Pairs adjacent dimensions and applies rotation.
|
||||||
|
Pattern: [x0, x1, x2, x3, ...] -> rotate pairs (x0,x1), (x2,x3), ...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: Input tensor of shape (..., dim)
|
||||||
|
cos_freqs: Cosine frequencies
|
||||||
|
sin_freqs: Sine frequencies
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with interleaved rotary embeddings applied
|
||||||
|
"""
|
||||||
|
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
||||||
|
shape = input_tensor.shape
|
||||||
|
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
||||||
|
|
||||||
|
# Extract pairs
|
||||||
|
t1 = input_tensor[..., 0] # Even indices
|
||||||
|
t2 = input_tensor[..., 1] # Odd indices
|
||||||
|
|
||||||
|
# Apply rotation: (-t2, t1) pattern
|
||||||
|
t_rot = mx.stack([-t2, t1], axis=-1)
|
||||||
|
|
||||||
|
# Flatten back: (..., dim/2, 2) -> (..., dim)
|
||||||
|
input_tensor = mx.reshape(input_tensor, shape)
|
||||||
|
t_rot = mx.reshape(t_rot, shape)
|
||||||
|
|
||||||
|
# Apply rotary embeddings
|
||||||
|
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||||
|
"""Rotate for interleaved RoPE: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
|
||||||
|
|
||||||
|
PyTorch equivalent:
|
||||||
|
t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||||
|
t1, t2 = t_dup.unbind(dim=-1)
|
||||||
|
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||||
|
return rearrange(t_dup, "... d r -> ... (d r)")
|
||||||
|
"""
|
||||||
|
# x: (..., dim) where dim is even
|
||||||
|
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
||||||
|
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||||
|
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
||||||
|
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
||||||
|
return mx.reshape(rotated, x.shape)
|
||||||
|
|
||||||
|
def apply_rotary_emb_1d(
|
||||||
|
q: mx.array,
|
||||||
|
k: mx.array,
|
||||||
|
freqs_cis: mx.array,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Apply 1D rotary embeddings using precomputed frequencies (interleaved)."""
|
||||||
|
# freqs_cis: (1, seq_len, num_heads, head_dim, 2) where [..., 0] = cos, [..., 1] = sin
|
||||||
|
cos = freqs_cis[..., 0] # (1, seq_len, num_heads, head_dim)
|
||||||
|
sin = freqs_cis[..., 1]
|
||||||
|
|
||||||
|
# q, k: (batch, seq_len, num_heads, head_dim)
|
||||||
|
# Interleaved RoPE: pairs of adjacent dims rotate together
|
||||||
|
q_r = q * cos + rotate_half_interleaved(q) * sin
|
||||||
|
k_r = k * cos + rotate_half_interleaved(k) * sin
|
||||||
|
|
||||||
|
return q_r, k_r
|
||||||
|
|
||||||
|
|
||||||
|
def apply_split_rotary_emb(
|
||||||
|
input_tensor: mx.array,
|
||||||
|
cos_freqs: mx.array,
|
||||||
|
sin_freqs: mx.array,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Apply split rotary embeddings.
|
||||||
|
|
||||||
|
Splits dimensions into two halves and applies rotation.
|
||||||
|
Pattern: split into first half and second half
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tensor: Input tensor
|
||||||
|
cos_freqs: Cosine frequencies of shape (B, H, T, D//2)
|
||||||
|
sin_freqs: Sine frequencies of shape (B, H, T, D//2)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with split rotary embeddings applied
|
||||||
|
"""
|
||||||
|
needs_reshape = False
|
||||||
|
original_shape = input_tensor.shape
|
||||||
|
|
||||||
|
# Handle dimension mismatch
|
||||||
|
if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
|
||||||
|
b, h, t, _ = cos_freqs.shape
|
||||||
|
# Reshape from (B, T, H*D) to (B, H, T, D)
|
||||||
|
input_tensor = mx.reshape(input_tensor, (b, t, h, -1))
|
||||||
|
input_tensor = mx.swapaxes(input_tensor, 1, 2)
|
||||||
|
needs_reshape = True
|
||||||
|
|
||||||
|
# Split into two halves: (..., dim) -> (..., 2, dim//2)
|
||||||
|
dim = input_tensor.shape[-1]
|
||||||
|
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
|
||||||
|
|
||||||
|
# Get first and second halves
|
||||||
|
first_half = split_input[..., 0, :] # (..., dim//2)
|
||||||
|
second_half = split_input[..., 1, :] # (..., dim//2)
|
||||||
|
|
||||||
|
# Apply cosine to both halves
|
||||||
|
output_first = first_half * cos_freqs
|
||||||
|
output_second = second_half * cos_freqs
|
||||||
|
|
||||||
|
# Apply sine cross-terms (addcmul pattern)
|
||||||
|
output_first = output_first - sin_freqs * second_half
|
||||||
|
output_second = output_second + sin_freqs * first_half
|
||||||
|
|
||||||
|
# Stack back together
|
||||||
|
output = mx.stack([output_first, output_second], axis=-2)
|
||||||
|
|
||||||
|
# Flatten: (..., 2, dim//2) -> (..., dim)
|
||||||
|
output = mx.reshape(output, input_tensor.shape)
|
||||||
|
|
||||||
|
if needs_reshape:
|
||||||
|
# Reshape back: (B, H, T, D) -> (B, T, H*D)
|
||||||
|
b, h, t, d = output.shape
|
||||||
|
output = mx.swapaxes(output, 1, 2)
|
||||||
|
output = mx.reshape(output, (b, t, h * d))
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def generate_freq_grid(
|
||||||
|
positional_embedding_theta: float,
|
||||||
|
positional_embedding_max_pos_count: int,
|
||||||
|
inner_dim: int,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Generate frequency grid for RoPE.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
positional_embedding_theta: Base theta value
|
||||||
|
positional_embedding_max_pos_count: Number of position dimensions
|
||||||
|
inner_dim: Inner dimension of the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Frequency indices tensor
|
||||||
|
"""
|
||||||
|
theta = positional_embedding_theta
|
||||||
|
start = 1.0
|
||||||
|
end = theta
|
||||||
|
|
||||||
|
n_elem = 2 * positional_embedding_max_pos_count
|
||||||
|
|
||||||
|
# Compute logarithmic spacing
|
||||||
|
log_start = math.log(start) / math.log(theta)
|
||||||
|
log_end = math.log(end) / math.log(theta)
|
||||||
|
|
||||||
|
num_indices = inner_dim // n_elem
|
||||||
|
if num_indices == 0:
|
||||||
|
num_indices = 1
|
||||||
|
|
||||||
|
# Create linearly spaced values in log space
|
||||||
|
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||||
|
|
||||||
|
# Compute power indices
|
||||||
|
pow_indices = mx.power(theta, lin_space)
|
||||||
|
|
||||||
|
# Scale by pi/2
|
||||||
|
return pow_indices * (math.pi / 2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fractional_positions(
|
||||||
|
indices_grid: mx.array,
|
||||||
|
max_pos: List[int],
|
||||||
|
) -> mx.array:
|
||||||
|
"""Convert indices to fractional positions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indices_grid: Grid of position indices of shape (B, n_pos_dims, ...)
|
||||||
|
max_pos: Maximum position for each dimension
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Fractional positions in range [-1, 1] after scaling
|
||||||
|
"""
|
||||||
|
n_pos_dims = indices_grid.shape[1]
|
||||||
|
assert n_pos_dims == len(max_pos), (
|
||||||
|
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Divide each dimension by its max position
|
||||||
|
fractional_positions = []
|
||||||
|
for i in range(n_pos_dims):
|
||||||
|
frac = indices_grid[:, i] / max_pos[i]
|
||||||
|
fractional_positions.append(frac)
|
||||||
|
|
||||||
|
return mx.stack(fractional_positions, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_freqs(
|
||||||
|
indices: mx.array,
|
||||||
|
indices_grid: mx.array,
|
||||||
|
max_pos: List[int],
|
||||||
|
use_middle_indices_grid: bool,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Generate frequencies from indices and position grid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indices: Frequency indices
|
||||||
|
indices_grid: Position indices grid
|
||||||
|
max_pos: Maximum positions per dimension
|
||||||
|
use_middle_indices_grid: Whether to use middle of index ranges
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Frequency tensor
|
||||||
|
"""
|
||||||
|
# Handle middle indices grid
|
||||||
|
if use_middle_indices_grid:
|
||||||
|
# indices_grid shape: (B, n_dims, T, 2) where last dim is [start, end]
|
||||||
|
assert len(indices_grid.shape) == 4
|
||||||
|
assert indices_grid.shape[-1] == 2
|
||||||
|
indices_grid_start = indices_grid[..., 0]
|
||||||
|
indices_grid_end = indices_grid[..., 1]
|
||||||
|
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||||
|
elif len(indices_grid.shape) == 4:
|
||||||
|
indices_grid = indices_grid[..., 0]
|
||||||
|
|
||||||
|
# Get fractional positions
|
||||||
|
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||||
|
|
||||||
|
# Compute frequencies
|
||||||
|
# fractional_positions: (B, T, n_dims)
|
||||||
|
# indices: (inner_dim // n_elem,)
|
||||||
|
# Result: (B, T, inner_dim // n_elem * n_dims)
|
||||||
|
|
||||||
|
# Scale fractional positions to [-1, 1]
|
||||||
|
scaled_positions = fractional_positions * 2 - 1 # (B, T, n_dims)
|
||||||
|
|
||||||
|
# Outer product with indices
|
||||||
|
# (B, T, n_dims, 1) * (1, 1, 1, n_indices) -> (B, T, n_dims, n_indices)
|
||||||
|
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.expand_dims(
|
||||||
|
mx.expand_dims(mx.expand_dims(indices, axis=0), axis=0), axis=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transpose and flatten: (B, T, n_dims, n_indices) -> (B, T, n_indices * n_dims)
|
||||||
|
freqs = mx.swapaxes(freqs, -1, -2) # (B, T, n_indices, n_dims)
|
||||||
|
freqs = mx.reshape(freqs, freqs.shape[:-2] + (-1,))
|
||||||
|
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
def split_freqs_cis(
|
||||||
|
freqs: mx.array,
|
||||||
|
pad_size: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Prepare cos/sin frequencies for split RoPE.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
freqs: Frequency tensor
|
||||||
|
pad_size: Padding size for dimension alignment
|
||||||
|
num_attention_heads: Number of attention heads
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2)
|
||||||
|
"""
|
||||||
|
cos_freq = mx.cos(freqs)
|
||||||
|
sin_freq = mx.sin(freqs)
|
||||||
|
|
||||||
|
# Add padding if needed
|
||||||
|
if pad_size != 0:
|
||||||
|
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
|
||||||
|
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
|
||||||
|
|
||||||
|
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
|
# Reshape for multi-head attention
|
||||||
|
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||||
|
|
||||||
|
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
|
||||||
|
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
|
||||||
|
|
||||||
|
# Swap axes: (B, T, H, D//2) -> (B, H, T, D//2)
|
||||||
|
cos_freq = mx.swapaxes(cos_freq, 1, 2)
|
||||||
|
sin_freq = mx.swapaxes(sin_freq, 1, 2)
|
||||||
|
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
|
|
||||||
|
def interleaved_freqs_cis(
|
||||||
|
freqs: mx.array,
|
||||||
|
pad_size: int,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Prepare cos/sin frequencies for interleaved RoPE.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
freqs: Frequency tensor of shape (B, T, dim//2)
|
||||||
|
pad_size: Padding size for dimension alignment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (cos_freq, sin_freq) with shape (B, T, dim)
|
||||||
|
"""
|
||||||
|
# Compute cos and sin
|
||||||
|
cos_freq = mx.cos(freqs)
|
||||||
|
sin_freq = mx.sin(freqs)
|
||||||
|
|
||||||
|
# Repeat interleave: each element repeated twice
|
||||||
|
# (B, T, D) -> (B, T, 2*D) with pattern [c0, c0, c1, c1, ...]
|
||||||
|
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
|
||||||
|
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
|
||||||
|
|
||||||
|
# Add padding if needed
|
||||||
|
if pad_size != 0:
|
||||||
|
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
|
||||||
|
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
|
||||||
|
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(
|
||||||
|
indices_grid: mx.array,
|
||||||
|
dim: int,
|
||||||
|
theta: float = 10000.0,
|
||||||
|
max_pos: Optional[List[int]] = None,
|
||||||
|
use_middle_indices_grid: bool = False,
|
||||||
|
num_attention_heads: int = 32,
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||||
|
double_precision: bool = False,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Precompute RoPE frequencies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
indices_grid: Position indices grid
|
||||||
|
dim: Dimension for RoPE
|
||||||
|
theta: Base theta value for frequency computation
|
||||||
|
max_pos: Maximum position per dimension
|
||||||
|
use_middle_indices_grid: Whether to use middle indices
|
||||||
|
num_attention_heads: Number of attention heads
|
||||||
|
rope_type: Type of RoPE (INTERLEAVED or SPLIT)
|
||||||
|
double_precision: If True, compute frequencies in float64 for higher precision
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (cos_freq, sin_freq) tensors
|
||||||
|
"""
|
||||||
|
if max_pos is None:
|
||||||
|
max_pos = [20, 2048, 2048]
|
||||||
|
|
||||||
|
# For double precision, compute in numpy (float64) then convert back to MLX
|
||||||
|
# MLX GPU doesn't support float64, so we use numpy for high precision computation
|
||||||
|
if double_precision:
|
||||||
|
return _precompute_freqs_cis_double_precision(
|
||||||
|
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||||
|
num_attention_heads, rope_type
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate frequency indices
|
||||||
|
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||||
|
|
||||||
|
# Generate frequencies
|
||||||
|
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||||
|
|
||||||
|
# Prepare cos/sin based on rope type
|
||||||
|
if rope_type == LTXRopeType.SPLIT:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = freqs.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||||
|
else:
|
||||||
|
# Interleaved
|
||||||
|
n_elem = 2 * indices_grid.shape[1]
|
||||||
|
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||||
|
|
||||||
|
return cos_freq, sin_freq
|
||||||
|
|
||||||
|
|
||||||
|
def _precompute_freqs_cis_double_precision(
|
||||||
|
indices_grid: mx.array,
|
||||||
|
dim: int,
|
||||||
|
theta: float,
|
||||||
|
max_pos: List[int],
|
||||||
|
use_middle_indices_grid: bool,
|
||||||
|
num_attention_heads: int,
|
||||||
|
rope_type: LTXRopeType,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Compute RoPE frequencies in double precision using numpy.
|
||||||
|
|
||||||
|
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
|
||||||
|
"""
|
||||||
|
# Convert to numpy float64
|
||||||
|
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||||
|
|
||||||
|
# Generate frequency indices in float64
|
||||||
|
n_pos_dims = indices_grid_np.shape[1]
|
||||||
|
n_elem = 2 * n_pos_dims
|
||||||
|
|
||||||
|
# Compute log-spaced frequencies
|
||||||
|
log_start = math.log(1.0) / math.log(theta)
|
||||||
|
log_end = math.log(theta) / math.log(theta)
|
||||||
|
num_indices = dim // n_elem
|
||||||
|
if num_indices == 0:
|
||||||
|
num_indices = 1
|
||||||
|
lin_space = np.linspace(log_start, log_end, num_indices)
|
||||||
|
indices_np = np.power(theta, lin_space) * (math.pi / 2)
|
||||||
|
|
||||||
|
# Handle middle indices grid
|
||||||
|
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||||
|
if use_middle_indices_grid:
|
||||||
|
assert len(indices_grid_np.shape) == 4
|
||||||
|
assert indices_grid_np.shape[-1] == 2
|
||||||
|
indices_grid_start = indices_grid_np[..., 0]
|
||||||
|
indices_grid_end = indices_grid_np[..., 1]
|
||||||
|
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
|
||||||
|
elif len(indices_grid_np.shape) == 4:
|
||||||
|
indices_grid_np = indices_grid_np[..., 0]
|
||||||
|
# After handling: indices_grid_np shape is (B, n_dims, T)
|
||||||
|
|
||||||
|
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
|
||||||
|
batch_size = indices_grid_np.shape[0]
|
||||||
|
seq_len = indices_grid_np.shape[2]
|
||||||
|
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
|
||||||
|
for i in range(n_pos_dims):
|
||||||
|
# indices_grid_np[:, i, :] has shape (B, T)
|
||||||
|
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
|
||||||
|
|
||||||
|
# Scale to [-1, 1]
|
||||||
|
scaled_positions = fractional_positions * 2 - 1
|
||||||
|
|
||||||
|
# Compute frequencies: outer product
|
||||||
|
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
|
||||||
|
freqs = np.swapaxes(freqs, -1, -2)
|
||||||
|
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
|
||||||
|
|
||||||
|
# Compute cos/sin in float64
|
||||||
|
cos_freq = np.cos(freqs)
|
||||||
|
sin_freq = np.sin(freqs)
|
||||||
|
|
||||||
|
# Prepare based on rope type
|
||||||
|
if rope_type == LTXRopeType.SPLIT:
|
||||||
|
expected_freqs = dim // 2
|
||||||
|
current_freqs = cos_freq.shape[-1]
|
||||||
|
pad_size = expected_freqs - current_freqs
|
||||||
|
|
||||||
|
# Add padding
|
||||||
|
if pad_size > 0:
|
||||||
|
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||||
|
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||||
|
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
|
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
|
||||||
|
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||||
|
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
|
||||||
|
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
|
||||||
|
cos_freq = np.swapaxes(cos_freq, 1, 2)
|
||||||
|
sin_freq = np.swapaxes(sin_freq, 1, 2)
|
||||||
|
else:
|
||||||
|
# Interleaved
|
||||||
|
cos_freq = np.repeat(cos_freq, 2, axis=-1)
|
||||||
|
sin_freq = np.repeat(sin_freq, 2, axis=-1)
|
||||||
|
|
||||||
|
pad_size = dim % n_elem
|
||||||
|
if pad_size > 0:
|
||||||
|
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||||
|
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
|
||||||
|
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
|
||||||
|
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
|
||||||
|
|
||||||
|
# Convert back to MLX (float32 for GPU compatibility)
|
||||||
|
cos_freq = mx.array(cos_freq.astype(np.float32))
|
||||||
|
sin_freq = mx.array(sin_freq.astype(np.float32))
|
||||||
|
|
||||||
|
return cos_freq, sin_freq
|
||||||
727
mlx_video/models/ltx/text_encoder.py
Normal file
727
mlx_video/models/ltx/text_encoder.py
Normal file
@@ -0,0 +1,727 @@
|
|||||||
|
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.utils import rms_norm
|
||||||
|
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Gemma3Config:
|
||||||
|
"""Configuration for Gemma 3 text model."""
|
||||||
|
hidden_size: int = 3840
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
num_key_value_heads: int = 8
|
||||||
|
head_dim: int = 256
|
||||||
|
intermediate_size: int = 15360
|
||||||
|
num_hidden_layers: int = 48
|
||||||
|
rms_norm_eps: float = 1e-6
|
||||||
|
rope_theta: float = 1000000.0
|
||||||
|
vocab_size: int = 262208
|
||||||
|
max_position_embeddings: int = 131072
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
"""RMS Normalization (Gemma style with 1+weight scaling)."""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
# Gemma initializes to ones, but uses (1+weight) scaling
|
||||||
|
# After loading weights, weight will have the actual learned values
|
||||||
|
self.weight = mx.ones((dims,))
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
|
||||||
|
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
q: mx.array,
|
||||||
|
k: mx.array,
|
||||||
|
positions: mx.array,
|
||||||
|
head_dim: int,
|
||||||
|
rope_theta: float = 1000000.0,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
"""Apply rotary position embeddings to Q and K."""
|
||||||
|
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
|
||||||
|
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
|
||||||
|
cos = mx.cos(freqs)
|
||||||
|
sin = mx.sin(freqs)
|
||||||
|
cos = cos[:, :, None, :]
|
||||||
|
sin = sin[:, :, None, :]
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return mx.concatenate([-x2, x1], axis=-1)
|
||||||
|
|
||||||
|
cos_full = mx.concatenate([cos, cos], axis=-1)
|
||||||
|
sin_full = mx.concatenate([sin, sin], axis=-1)
|
||||||
|
q_embed = q * cos_full + rotate_half(q) * sin_full
|
||||||
|
k_embed = k * cos_full + rotate_half(k) * sin_full
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3MLP(nn.Module):
|
||||||
|
"""Gemma 3 MLP with gated activation."""
|
||||||
|
|
||||||
|
def __init__(self, config: Gemma3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
gate = nn.gelu_approx(self.gate_proj(x))
|
||||||
|
up = self.up_proj(x)
|
||||||
|
return self.down_proj(gate * up)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: Gemma3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.num_kv_heads = config.num_key_value_heads
|
||||||
|
self.head_dim = config.head_dim
|
||||||
|
self.scale = 1.0 / math.sqrt(config.head_dim)
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||||
|
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
||||||
|
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
||||||
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
||||||
|
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states: mx.array,
|
||||||
|
positions: mx.array,
|
||||||
|
attention_mask: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
batch_size, seq_len, _ = hidden_states.shape
|
||||||
|
|
||||||
|
q = self.q_proj(hidden_states)
|
||||||
|
k = self.k_proj(hidden_states)
|
||||||
|
v = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
|
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
||||||
|
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
||||||
|
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
|
||||||
|
|
||||||
|
q = mx.transpose(q, (0, 2, 1, 3))
|
||||||
|
k = mx.transpose(k, (0, 2, 1, 3))
|
||||||
|
v = mx.transpose(v, (0, 2, 1, 3))
|
||||||
|
|
||||||
|
# Create causal mask (lower triangular)
|
||||||
|
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
|
||||||
|
out = mx.transpose(out, (0, 2, 1, 3))
|
||||||
|
out = mx.reshape(out, (batch_size, seq_len, -1))
|
||||||
|
|
||||||
|
return self.o_proj(out)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3DecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: Gemma3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = Gemma3Attention(config)
|
||||||
|
self.mlp = Gemma3MLP(config)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states: mx.array,
|
||||||
|
positions: mx.array,
|
||||||
|
attention_mask: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3TextModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: Gemma3Config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||||
|
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
# Gemma scales embeddings by sqrt(hidden_size)
|
||||||
|
self.embed_scale = config.hidden_size ** 0.5
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: mx.array,
|
||||||
|
attention_mask: Optional[mx.array] = None,
|
||||||
|
output_hidden_states: bool = True,
|
||||||
|
) -> Tuple[mx.array, List[mx.array]]:
|
||||||
|
|
||||||
|
batch_size, seq_len = input_ids.shape
|
||||||
|
|
||||||
|
# Gemma scales embeddings by sqrt(hidden_size)
|
||||||
|
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
|
all_hidden_states = [hidden_states] if output_hidden_states else []
|
||||||
|
|
||||||
|
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
|
||||||
|
positions = mx.broadcast_to(positions, (batch_size, seq_len))
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, positions, attention_mask)
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states.append(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, all_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int = 3840,
|
||||||
|
num_heads: int = 30,
|
||||||
|
head_dim: int = 128,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
inner_dim = num_heads * head_dim
|
||||||
|
self.scale = 1.0 / math.sqrt(head_dim)
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias=True)
|
||||||
|
self.to_k = nn.Linear(dim, inner_dim, bias=True)
|
||||||
|
self.to_v = nn.Linear(dim, inner_dim, bias=True)
|
||||||
|
self.to_out = [nn.Linear(inner_dim, dim, bias=True)]
|
||||||
|
|
||||||
|
# Standard RMSNorm (not Gemma-style) on full inner_dim
|
||||||
|
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||||
|
self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
attention_mask: Optional[mx.array] = None,
|
||||||
|
pe: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
batch_size, seq_len, _ = x.shape
|
||||||
|
|
||||||
|
# Project to Q, K, V
|
||||||
|
q = self.to_q(x) # (B, seq, inner_dim)
|
||||||
|
k = self.to_k(x)
|
||||||
|
v = self.to_v(x)
|
||||||
|
|
||||||
|
# QK normalization on full inner_dim BEFORE reshape (matches PyTorch)
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
|
|
||||||
|
if pe is not None:
|
||||||
|
# pe: (1, seq_len, num_heads, head_dim, 2)
|
||||||
|
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
|
||||||
|
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
|
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||||
|
q, k = apply_rotary_emb_1d(q, k, pe)
|
||||||
|
# Reshape back for attention computation
|
||||||
|
q = mx.reshape(q, (batch_size, seq_len, -1))
|
||||||
|
k = mx.reshape(k, (batch_size, seq_len, -1))
|
||||||
|
|
||||||
|
|
||||||
|
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
|
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
|
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
|
||||||
|
if attention_mask is not None:
|
||||||
|
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
|
||||||
|
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||||
|
|
||||||
|
return self.to_out[0](out)
|
||||||
|
|
||||||
|
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
"""GELU-gated linear unit."""
|
||||||
|
|
||||||
|
def __init__(self, in_dim: int, out_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(in_dim, out_dim, bias=True)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return nn.gelu_approx(self.proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorFeedForward(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim * mult
|
||||||
|
self.net = [
|
||||||
|
GEGLU(dim, inner_dim),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(inner_dim, dim, bias=True),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
for layer in self.net:
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorTransformerBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128):
|
||||||
|
super().__init__()
|
||||||
|
self.attn1 = ConnectorAttention(dim, num_heads, head_dim)
|
||||||
|
self.ff = ConnectorFeedForward(dim)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
attention_mask: Optional[mx.array] = None,
|
||||||
|
pe: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
# Pre-norm + attention + residual
|
||||||
|
norm_x = rms_norm(x)
|
||||||
|
if norm_x.ndim == 4:
|
||||||
|
norm_x = mx.squeeze(norm_x, axis=1)
|
||||||
|
attn_out = self.attn1(norm_x, attention_mask, pe)
|
||||||
|
x = x + attn_out
|
||||||
|
if x.ndim == 4:
|
||||||
|
x = mx.squeeze(x, axis=1)
|
||||||
|
|
||||||
|
# Pre-norm + FFN + residual
|
||||||
|
norm_x = rms_norm(x)
|
||||||
|
ff_out = self.ff(norm_x)
|
||||||
|
x = x + ff_out
|
||||||
|
if x.ndim == 4:
|
||||||
|
x = mx.squeeze(x, axis=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings1DConnector(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int = 3840,
|
||||||
|
num_heads: int = 30,
|
||||||
|
head_dim: int = 128,
|
||||||
|
num_layers: int = 2,
|
||||||
|
num_learnable_registers: int = 128,
|
||||||
|
positional_embedding_theta: float = 10000.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.num_learnable_registers = num_learnable_registers
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
|
||||||
|
self.transformer_1d_blocks = [
|
||||||
|
ConnectorTransformerBlock(dim, num_heads, head_dim)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
if num_learnable_registers > 0:
|
||||||
|
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||||
|
|
||||||
|
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
|
||||||
|
import math
|
||||||
|
|
||||||
|
dim = self.num_heads * self.head_dim
|
||||||
|
theta = self.positional_embedding_theta
|
||||||
|
n_elem = 2
|
||||||
|
|
||||||
|
|
||||||
|
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
|
||||||
|
indices = (theta ** linspace_vals) * (math.pi / 2)
|
||||||
|
|
||||||
|
positions = mx.arange(seq_len).astype(mx.float32)
|
||||||
|
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
|
||||||
|
|
||||||
|
cos = mx.cos(freqs) # (seq_len, dim//2)
|
||||||
|
sin = mx.sin(freqs)
|
||||||
|
|
||||||
|
|
||||||
|
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||||
|
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
|
||||||
|
return freqs_cis.astype(dtype)
|
||||||
|
|
||||||
|
def _replace_padded_with_registers(
|
||||||
|
self,
|
||||||
|
hidden_states: mx.array,
|
||||||
|
attention_mask: mx.array,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
batch_size, seq_len, dim = hidden_states.shape
|
||||||
|
|
||||||
|
# Binary mask: 1 for valid tokens, 0 for padded
|
||||||
|
# attention_mask is additive: 0 for valid, large negative for padded
|
||||||
|
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
||||||
|
|
||||||
|
# Tile registers to match sequence length
|
||||||
|
num_tiles = seq_len // self.num_learnable_registers
|
||||||
|
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim)
|
||||||
|
|
||||||
|
# Process each batch item (PyTorch uses advanced indexing)
|
||||||
|
result_list = []
|
||||||
|
for b in range(batch_size):
|
||||||
|
mask_b = mask_binary[b] # (seq,)
|
||||||
|
hs_b = hidden_states[b] # (seq, dim)
|
||||||
|
|
||||||
|
# Count valid tokens
|
||||||
|
num_valid = int(mx.sum(mask_b))
|
||||||
|
|
||||||
|
# Extract valid tokens (where mask is 1)
|
||||||
|
# Since we have left-padded input, valid tokens are at the end
|
||||||
|
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
|
||||||
|
|
||||||
|
# Pad with zeros on the right to get back to seq_len
|
||||||
|
pad_length = seq_len - num_valid
|
||||||
|
if pad_length > 0:
|
||||||
|
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype)
|
||||||
|
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
||||||
|
else:
|
||||||
|
adjusted = valid_tokens
|
||||||
|
|
||||||
|
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
|
||||||
|
flipped_mask = mx.concatenate([
|
||||||
|
mx.ones((num_valid,), dtype=mx.int32),
|
||||||
|
mx.zeros((pad_length,), dtype=mx.int32)
|
||||||
|
], axis=0) # (seq,)
|
||||||
|
|
||||||
|
# Combine: valid tokens at front, registers at back
|
||||||
|
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1)
|
||||||
|
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
||||||
|
|
||||||
|
result_list.append(combined)
|
||||||
|
|
||||||
|
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||||
|
|
||||||
|
# Reset attention mask to all zeros (no masking after register replacement)
|
||||||
|
attention_mask = mx.zeros_like(attention_mask)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states: mx.array,
|
||||||
|
attention_mask: Optional[mx.array] = None,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
|
# Replace padded tokens with learnable registers
|
||||||
|
if self.num_learnable_registers > 0 and attention_mask is not None:
|
||||||
|
hidden_states, attention_mask = self._replace_padded_with_registers(
|
||||||
|
hidden_states, attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute RoPE frequencies
|
||||||
|
seq_len = hidden_states.shape[1]
|
||||||
|
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
|
||||||
|
|
||||||
|
# Process through transformer blocks
|
||||||
|
for block in self.transformer_1d_blocks:
|
||||||
|
hidden_states = block(hidden_states, attention_mask, freqs_cis)
|
||||||
|
|
||||||
|
# Final RMS norm
|
||||||
|
hidden_states = rms_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def norm_and_concat_hidden_states(
|
||||||
|
hidden_states: List[mx.array],
|
||||||
|
attention_mask: mx.array,
|
||||||
|
padding_side: str = "left",
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
# Stack hidden states: (batch, seq, dim, num_layers)
|
||||||
|
stacked = mx.stack(hidden_states, axis=-1)
|
||||||
|
b, t, d, num_layers = stacked.shape
|
||||||
|
|
||||||
|
# Compute sequence lengths from attention mask
|
||||||
|
sequence_lengths = mx.sum(attention_mask, axis=-1) # (batch,)
|
||||||
|
|
||||||
|
# Build mask based on padding side
|
||||||
|
token_indices = mx.arange(t)[None, :] # (1, T)
|
||||||
|
|
||||||
|
if padding_side == "right":
|
||||||
|
mask = token_indices < sequence_lengths[:, None] # (B, T)
|
||||||
|
else: # left padding
|
||||||
|
start_indices = t - sequence_lengths[:, None] # (B, 1)
|
||||||
|
mask = token_indices >= start_indices # (B, T)
|
||||||
|
|
||||||
|
mask = mask[:, :, None, None] # (B, T, 1, 1)
|
||||||
|
eps = 1e-6
|
||||||
|
|
||||||
|
# Compute masked mean per layer
|
||||||
|
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
|
||||||
|
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
|
||||||
|
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||||
|
|
||||||
|
# Compute masked min/max per layer
|
||||||
|
large_val = 1e9
|
||||||
|
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
|
||||||
|
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
|
||||||
|
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||||
|
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||||
|
range_val = x_max - x_min
|
||||||
|
|
||||||
|
# Normalize: 8 * (x - mean) / range
|
||||||
|
normed = 8 * (stacked - mean) / (range_val + eps)
|
||||||
|
|
||||||
|
# Flatten layers into feature dimension: (B, T, D*L)
|
||||||
|
normed = mx.reshape(normed, (b, t, -1))
|
||||||
|
|
||||||
|
# Zero out padded positions
|
||||||
|
mask_flat = mx.broadcast_to(mask[:, :, :, 0], (b, t, d * num_layers))
|
||||||
|
normed = mx.where(mask_flat, normed, mx.zeros_like(normed))
|
||||||
|
|
||||||
|
return normed
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaFeaturesExtractor(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, input_dim: int = 188160, output_dim: int = 3840):
|
||||||
|
super().__init__()
|
||||||
|
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return self.aggregate_embed(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = None
|
||||||
|
|
||||||
|
if key.startswith("base_text_encoder.language_model."):
|
||||||
|
new_key = key.replace("base_text_encoder.language_model.", "")
|
||||||
|
elif key.startswith("language_model.model."):
|
||||||
|
new_key = key.replace("language_model.model.", "")
|
||||||
|
elif key.startswith("language_model."):
|
||||||
|
new_key = key.replace("language_model.", "")
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if new_key is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2TextEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str = "Lightricks/LTX-2",
|
||||||
|
hidden_dim: int = 3840,
|
||||||
|
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._model_path = model_path
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
# Gemma 3 model
|
||||||
|
self.config = Gemma3Config()
|
||||||
|
self.model = Gemma3TextModel(self.config)
|
||||||
|
|
||||||
|
# Feature extractor: 3840*49 -> 3840
|
||||||
|
self.feature_extractor = GemmaFeaturesExtractor(
|
||||||
|
input_dim=hidden_dim * num_layers,
|
||||||
|
output_dim=hidden_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Video embeddings connector: 2-layer transformer
|
||||||
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
|
dim=hidden_dim,
|
||||||
|
num_heads=30,
|
||||||
|
head_dim=128,
|
||||||
|
num_layers=2,
|
||||||
|
num_learnable_registers=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
|
def load(self, model_path: Optional[str] = None):
|
||||||
|
path = model_path or self._model_path
|
||||||
|
|
||||||
|
# Load Gemma weights from text_encoder subdirectory
|
||||||
|
if Path(path).is_dir():
|
||||||
|
text_encoder_path = Path(path) / "text_encoder"
|
||||||
|
if text_encoder_path.exists():
|
||||||
|
gemma_path = str(text_encoder_path)
|
||||||
|
else:
|
||||||
|
gemma_path = path
|
||||||
|
else:
|
||||||
|
gemma_path = path
|
||||||
|
|
||||||
|
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
|
||||||
|
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
|
||||||
|
all_weights = {}
|
||||||
|
for i, wf in enumerate(weight_files):
|
||||||
|
print(f" Loading weight file {i+1}/{len(weight_files)}...")
|
||||||
|
weights = mx.load(str(wf))
|
||||||
|
all_weights.update(weights)
|
||||||
|
|
||||||
|
# Sanitize and load Gemma weights
|
||||||
|
sanitized = sanitize_gemma3_weights(all_weights)
|
||||||
|
print(f" Sanitized Gemma weights: {len(sanitized)}")
|
||||||
|
self.model.load_weights(list(sanitized.items()), strict=False)
|
||||||
|
|
||||||
|
# Load transformer weights for feature extractor and connector
|
||||||
|
transformer_path = Path(model_path or self._model_path)
|
||||||
|
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
|
||||||
|
if transformer_files:
|
||||||
|
print(f"Loading transformer weights for text pipeline...")
|
||||||
|
transformer_weights = mx.load(str(transformer_files[0]))
|
||||||
|
|
||||||
|
# Load feature extractor (aggregate_embed)
|
||||||
|
if "text_embedding_projection.aggregate_embed.weight" in transformer_weights:
|
||||||
|
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
||||||
|
"text_embedding_projection.aggregate_embed.weight"
|
||||||
|
]
|
||||||
|
print(" Loaded aggregate_embed weights")
|
||||||
|
|
||||||
|
# Load video_embeddings_connector weights
|
||||||
|
connector_weights = {}
|
||||||
|
for key, value in transformer_weights.items():
|
||||||
|
if key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||||
|
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "")
|
||||||
|
connector_weights[new_key] = value
|
||||||
|
|
||||||
|
if connector_weights:
|
||||||
|
# Map weight names to our structure
|
||||||
|
mapped_weights = {}
|
||||||
|
for key, value in connector_weights.items():
|
||||||
|
# transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.*
|
||||||
|
# transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.*
|
||||||
|
# transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.*
|
||||||
|
mapped_weights[key] = value
|
||||||
|
|
||||||
|
self.video_embeddings_connector.load_weights(
|
||||||
|
list(mapped_weights.items()), strict=False
|
||||||
|
)
|
||||||
|
print(f" Loaded {len(connector_weights)} connector weights")
|
||||||
|
|
||||||
|
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||||
|
if "learnable_registers" in connector_weights:
|
||||||
|
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||||
|
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
|
||||||
|
if tokenizer_path.exists():
|
||||||
|
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||||
|
else:
|
||||||
|
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
|
||||||
|
# Set left padding to match official LTX-2 text encoder
|
||||||
|
self.processor.padding_side = "left"
|
||||||
|
|
||||||
|
print("Text encoder loaded successfully")
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_length: int = 1024,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
|
if self.processor is None:
|
||||||
|
raise RuntimeError("Model not loaded. Call load() first.")
|
||||||
|
|
||||||
|
# Tokenize with left padding (as in PyTorch version)
|
||||||
|
inputs = self.processor(
|
||||||
|
prompt,
|
||||||
|
return_tensors="np",
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=True,
|
||||||
|
padding="max_length",
|
||||||
|
)
|
||||||
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
|
attention_mask = mx.array(inputs["attention_mask"])
|
||||||
|
|
||||||
|
# Get all hidden states from Gemma
|
||||||
|
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
||||||
|
|
||||||
|
# Normalize and concatenate all hidden states
|
||||||
|
concat_hidden = norm_and_concat_hidden_states(
|
||||||
|
all_hidden_states, attention_mask, padding_side="left"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Project through feature extractor
|
||||||
|
features = self.feature_extractor(concat_hidden)
|
||||||
|
|
||||||
|
# Convert attention mask to additive format for connector
|
||||||
|
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||||
|
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
|
|
||||||
|
# Process through connector
|
||||||
|
# Note: connector replaces padding with learnable registers and resets mask to zeros
|
||||||
|
# This means all positions now have valid embeddings (no need for final masking)
|
||||||
|
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||||
|
|
||||||
|
# Return embeddings without zeroing - the connector's register replacement
|
||||||
|
# means all positions have meaningful values now
|
||||||
|
return embeddings, attention_mask
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_length: int = 1024,
|
||||||
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
return self.encode(prompt, max_length)
|
||||||
|
|
||||||
|
|
||||||
|
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||||
|
encoder = LTX2TextEncoder(model_path=model_path)
|
||||||
|
encoder.load()
|
||||||
|
return encoder
|
||||||
|
|
||||||
26
mlx_video/models/ltx/text_projection.py
Normal file
26
mlx_video/models/ltx/text_projection.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class PixArtAlphaTextProjection(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_size: int,
|
||||||
|
out_features: int | None = None,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
out_features = out_features or hidden_size
|
||||||
|
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||||
|
self.act = nn.GELU(approx="precise")
|
||||||
|
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
x = self.linear1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
return x
|
||||||
359
mlx_video/models/ltx/transformer.py
Normal file
359
mlx_video/models/ltx/transformer.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
|
||||||
|
from mlx_video.models.ltx.attention import Attention
|
||||||
|
from mlx_video.models.ltx.feed_forward import FeedForward
|
||||||
|
from mlx_video.utils import rms_norm
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Modality:
|
||||||
|
latent: mx.array
|
||||||
|
timesteps: mx.array
|
||||||
|
positions: mx.array
|
||||||
|
context: mx.array
|
||||||
|
enabled: bool = True
|
||||||
|
context_mask: Optional[mx.array] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TransformerArgs:
|
||||||
|
x: mx.array
|
||||||
|
context: mx.array
|
||||||
|
context_mask: Optional[mx.array]
|
||||||
|
timesteps: mx.array
|
||||||
|
embedded_timestep: mx.array
|
||||||
|
positional_embeddings: Tuple[mx.array, mx.array]
|
||||||
|
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
|
||||||
|
cross_scale_shift_timestep: Optional[mx.array]
|
||||||
|
cross_gate_timestep: Optional[mx.array]
|
||||||
|
enabled: bool
|
||||||
|
|
||||||
|
|
||||||
|
class BasicAVTransformerBlock(nn.Module):
|
||||||
|
"""Audio-Video transformer block with cross-modal attention.
|
||||||
|
|
||||||
|
Supports video-only, audio-only, or combined audio-video processing
|
||||||
|
with bidirectional cross-attention between modalities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
idx: int,
|
||||||
|
video: Optional[TransformerConfig] = None,
|
||||||
|
audio: Optional[TransformerConfig] = None,
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||||
|
norm_eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
"""Initialize transformer block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: Block index
|
||||||
|
video: Video modality configuration
|
||||||
|
audio: Audio modality configuration
|
||||||
|
rope_type: Type of rotary position embedding
|
||||||
|
norm_eps: Epsilon for normalization
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.idx = idx
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
# Video components
|
||||||
|
if video is not None:
|
||||||
|
self.attn1 = Attention(
|
||||||
|
query_dim=video.dim,
|
||||||
|
heads=video.heads,
|
||||||
|
dim_head=video.d_head,
|
||||||
|
context_dim=None, # Self-attention
|
||||||
|
rope_type=rope_type,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
self.attn2 = Attention(
|
||||||
|
query_dim=video.dim,
|
||||||
|
context_dim=video.context_dim,
|
||||||
|
heads=video.heads,
|
||||||
|
dim_head=video.d_head,
|
||||||
|
rope_type=rope_type,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
self.ff = FeedForward(video.dim, dim_out=video.dim)
|
||||||
|
# 6 scale-shift parameters: 3 for attention, 3 for MLP
|
||||||
|
self.scale_shift_table = mx.zeros((6, video.dim))
|
||||||
|
|
||||||
|
# Audio components
|
||||||
|
if audio is not None:
|
||||||
|
self.audio_attn1 = Attention(
|
||||||
|
query_dim=audio.dim,
|
||||||
|
heads=audio.heads,
|
||||||
|
dim_head=audio.d_head,
|
||||||
|
context_dim=None,
|
||||||
|
rope_type=rope_type,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
self.audio_attn2 = Attention(
|
||||||
|
query_dim=audio.dim,
|
||||||
|
context_dim=audio.context_dim,
|
||||||
|
heads=audio.heads,
|
||||||
|
dim_head=audio.d_head,
|
||||||
|
rope_type=rope_type,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
|
||||||
|
self.audio_scale_shift_table = mx.zeros((6, audio.dim))
|
||||||
|
|
||||||
|
# Cross-modal attention (when both video and audio are enabled)
|
||||||
|
if audio is not None and video is not None:
|
||||||
|
# Audio-to-Video: Q from video, K/V from audio
|
||||||
|
self.audio_to_video_attn = Attention(
|
||||||
|
query_dim=video.dim,
|
||||||
|
context_dim=audio.dim,
|
||||||
|
heads=audio.heads,
|
||||||
|
dim_head=audio.d_head,
|
||||||
|
rope_type=rope_type,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
# Video-to-Audio: Q from audio, K/V from video
|
||||||
|
self.video_to_audio_attn = Attention(
|
||||||
|
query_dim=audio.dim,
|
||||||
|
context_dim=video.dim,
|
||||||
|
heads=audio.heads,
|
||||||
|
dim_head=audio.d_head,
|
||||||
|
rope_type=rope_type,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
)
|
||||||
|
# Scale-shift tables for cross-attention
|
||||||
|
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
|
||||||
|
self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))
|
||||||
|
|
||||||
|
def get_ada_values(
|
||||||
|
self,
|
||||||
|
scale_shift_table: mx.array,
|
||||||
|
batch_size: int,
|
||||||
|
timestep: mx.array,
|
||||||
|
indices: slice,
|
||||||
|
) -> Tuple[mx.array, ...]:
|
||||||
|
"""Get adaptive normalization values from scale-shift table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scale_shift_table: Table of shape (num_params, dim)
|
||||||
|
batch_size: Batch size
|
||||||
|
timestep: Timestep embeddings of shape (B, 1, num_params * dim) or similar
|
||||||
|
indices: Slice for which parameters to extract
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of scale-shift values
|
||||||
|
"""
|
||||||
|
num_ada_params = scale_shift_table.shape[0]
|
||||||
|
|
||||||
|
# scale_shift_table[indices]: (num_selected, dim)
|
||||||
|
# Add batch and sequence dimensions: (1, 1, num_selected, dim)
|
||||||
|
table_slice = scale_shift_table[indices]
|
||||||
|
table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)
|
||||||
|
|
||||||
|
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
||||||
|
timestep_reshaped = mx.reshape(
|
||||||
|
timestep,
|
||||||
|
(batch_size, timestep.shape[1], num_ada_params, -1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the relevant indices
|
||||||
|
timestep_slice = timestep_reshaped[:, :, indices, :]
|
||||||
|
|
||||||
|
# Add table values to timestep
|
||||||
|
ada_values = table_expanded + timestep_slice
|
||||||
|
|
||||||
|
# Unbind along the parameter dimension
|
||||||
|
# Result: tuple of tensors, each of shape (B, seq, dim)
|
||||||
|
num_sliced = ada_values.shape[2]
|
||||||
|
result = tuple(ada_values[:, :, i, :] for i in range(num_sliced))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_av_ca_ada_values(
|
||||||
|
self,
|
||||||
|
scale_shift_table: mx.array,
|
||||||
|
batch_size: int,
|
||||||
|
scale_shift_timestep: mx.array,
|
||||||
|
gate_timestep: mx.array,
|
||||||
|
num_scale_shift_values: int = 4,
|
||||||
|
) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
|
||||||
|
"""Get adaptive values for cross-modal attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scale_shift_table: Table with 5 parameters (4 scale-shift + 1 gate)
|
||||||
|
batch_size: Batch size
|
||||||
|
scale_shift_timestep: Timestep for scale-shift
|
||||||
|
gate_timestep: Timestep for gating
|
||||||
|
num_scale_shift_values: Number of scale-shift values (default 4)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of 5 tensors: (scale1, shift1, scale2, shift2, gate)
|
||||||
|
"""
|
||||||
|
# Get scale-shift values
|
||||||
|
scale_shift_ada = self.get_ada_values(
|
||||||
|
scale_shift_table[:num_scale_shift_values, :],
|
||||||
|
batch_size,
|
||||||
|
scale_shift_timestep,
|
||||||
|
slice(None, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get gate values
|
||||||
|
gate_ada = self.get_ada_values(
|
||||||
|
scale_shift_table[num_scale_shift_values:, :],
|
||||||
|
batch_size,
|
||||||
|
gate_timestep,
|
||||||
|
slice(None, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Squeeze the sequence dimension if it's 1
|
||||||
|
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
|
||||||
|
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
|
||||||
|
|
||||||
|
return (*scale_shift_squeezed, *gate_squeezed)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
video: Optional[TransformerArgs] = None,
|
||||||
|
audio: Optional[TransformerArgs] = None,
|
||||||
|
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||||
|
"""Forward pass through transformer block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video: Video modality arguments
|
||||||
|
audio: Audio modality arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (updated_video, updated_audio) TransformerArgs
|
||||||
|
"""
|
||||||
|
batch_size = video.x.shape[0] if video is not None else audio.x.shape[0]
|
||||||
|
|
||||||
|
vx = video.x if video is not None else None
|
||||||
|
ax = audio.x if audio is not None else None
|
||||||
|
|
||||||
|
# Check which modalities to run
|
||||||
|
run_vx = video is not None and video.enabled and vx.size > 0
|
||||||
|
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||||
|
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
|
||||||
|
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
|
||||||
|
|
||||||
|
# Process video self-attention and cross-attention with text
|
||||||
|
if run_vx:
|
||||||
|
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
|
||||||
|
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Self-attention with RoPE
|
||||||
|
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||||
|
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
|
||||||
|
|
||||||
|
# Cross-attention with text context
|
||||||
|
vx = vx + self.attn2(
|
||||||
|
rms_norm(vx, eps=self.norm_eps),
|
||||||
|
context=video.context,
|
||||||
|
mask=video.context_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process audio self-attention and cross-attention with text
|
||||||
|
if run_ax:
|
||||||
|
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
||||||
|
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Self-attention with RoPE
|
||||||
|
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||||
|
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
|
||||||
|
|
||||||
|
# Cross-attention with text context
|
||||||
|
ax = ax + self.audio_attn2(
|
||||||
|
rms_norm(ax, eps=self.norm_eps),
|
||||||
|
context=audio.context,
|
||||||
|
mask=audio.context_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio-Video cross-modal attention
|
||||||
|
if run_a2v or run_v2a:
|
||||||
|
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
||||||
|
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
|
||||||
|
|
||||||
|
# Get adaptive values for audio cross-attention
|
||||||
|
(
|
||||||
|
scale_ca_audio_a2v,
|
||||||
|
shift_ca_audio_a2v,
|
||||||
|
scale_ca_audio_v2a,
|
||||||
|
shift_ca_audio_v2a,
|
||||||
|
gate_out_v2a,
|
||||||
|
) = self.get_av_ca_ada_values(
|
||||||
|
self.scale_shift_table_a2v_ca_audio,
|
||||||
|
ax.shape[0],
|
||||||
|
audio.cross_scale_shift_timestep,
|
||||||
|
audio.cross_gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get adaptive values for video cross-attention
|
||||||
|
(
|
||||||
|
scale_ca_video_a2v,
|
||||||
|
shift_ca_video_a2v,
|
||||||
|
scale_ca_video_v2a,
|
||||||
|
shift_ca_video_v2a,
|
||||||
|
gate_out_a2v,
|
||||||
|
) = self.get_av_ca_ada_values(
|
||||||
|
self.scale_shift_table_a2v_ca_video,
|
||||||
|
vx.shape[0],
|
||||||
|
video.cross_scale_shift_timestep,
|
||||||
|
video.cross_gate_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio-to-Video cross-attention
|
||||||
|
if run_a2v:
|
||||||
|
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
|
||||||
|
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
|
||||||
|
vx = vx + (
|
||||||
|
self.audio_to_video_attn(
|
||||||
|
vx_scaled,
|
||||||
|
context=ax_scaled,
|
||||||
|
pe=video.cross_positional_embeddings,
|
||||||
|
k_pe=audio.cross_positional_embeddings,
|
||||||
|
)
|
||||||
|
* gate_out_a2v
|
||||||
|
)
|
||||||
|
|
||||||
|
# Video-to-Audio cross-attention
|
||||||
|
if run_v2a:
|
||||||
|
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
|
||||||
|
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
|
||||||
|
ax = ax + (
|
||||||
|
self.video_to_audio_attn(
|
||||||
|
ax_scaled,
|
||||||
|
context=vx_scaled,
|
||||||
|
pe=audio.cross_positional_embeddings,
|
||||||
|
k_pe=video.cross_positional_embeddings,
|
||||||
|
)
|
||||||
|
* gate_out_v2a
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process video feed-forward
|
||||||
|
if run_vx:
|
||||||
|
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
||||||
|
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
|
||||||
|
)
|
||||||
|
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
||||||
|
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
||||||
|
|
||||||
|
# Process audio feed-forward
|
||||||
|
if run_ax:
|
||||||
|
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
||||||
|
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
|
||||||
|
)
|
||||||
|
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
||||||
|
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
||||||
|
|
||||||
|
# Return updated TransformerArgs
|
||||||
|
video_out = replace(video, x=vx) if video is not None else None
|
||||||
|
audio_out = replace(audio, x=ax) if audio is not None else None
|
||||||
|
|
||||||
|
return video_out, audio_out
|
||||||
364
mlx_video/models/ltx/upsampler.py
Normal file
364
mlx_video/models/ltx/upsampler.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Conv3d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||||
|
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||||
|
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||||
|
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||||
|
groups: int = 1,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = (stride, stride, stride)
|
||||||
|
if isinstance(padding, int):
|
||||||
|
padding = (padding, padding, padding)
|
||||||
|
if isinstance(dilation, int):
|
||||||
|
dilation = (dilation, dilation, dilation)
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.dilation = dilation
|
||||||
|
self.groups = groups
|
||||||
|
|
||||||
|
# Weight shape: (C_out, KD, KH, KW, C_in)
|
||||||
|
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||||
|
self.weight = mx.random.uniform(
|
||||||
|
low=-scale,
|
||||||
|
high=scale,
|
||||||
|
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias = mx.zeros((out_channels,))
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (N, D, H, W, C_in)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor of shape (N, D', H', W', C_out)
|
||||||
|
"""
|
||||||
|
y = mx.conv3d(
|
||||||
|
x,
|
||||||
|
self.weight,
|
||||||
|
stride=self.stride,
|
||||||
|
padding=self.padding,
|
||||||
|
dilation=self.dilation,
|
||||||
|
groups=self.groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
y = y + self.bias
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNorm3d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.num_groups = num_groups
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = mx.ones((num_channels,))
|
||||||
|
self.bias = mx.zeros((num_channels,))
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
# x: (N, D, H, W, C)
|
||||||
|
n, d, h, w, c = x.shape
|
||||||
|
|
||||||
|
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||||
|
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
||||||
|
|
||||||
|
# Compute mean and var over spatial and channel group dims
|
||||||
|
mean = mx.mean(x, axis=(1, 3), keepdims=True)
|
||||||
|
var = mx.var(x, axis=(1, 3), keepdims=True)
|
||||||
|
|
||||||
|
# Normalize
|
||||||
|
x = (x - mean) / mx.sqrt(var + self.eps)
|
||||||
|
|
||||||
|
# Reshape back
|
||||||
|
x = mx.reshape(x, (n, d, h, w, c))
|
||||||
|
|
||||||
|
# Apply weight and bias
|
||||||
|
x = x * self.weight + self.bias
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PixelShuffle2D(nn.Module):
|
||||||
|
"""Pixel shuffle for 2D spatial upsampling."""
|
||||||
|
|
||||||
|
def __init__(self, upscale_factor: int = 2):
|
||||||
|
super().__init__()
|
||||||
|
self.upscale_factor = upscale_factor
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
# x: (N, H, W, C) where C = out_channels * upscale_factor^2
|
||||||
|
n, h, w, c = x.shape
|
||||||
|
r = self.upscale_factor
|
||||||
|
out_c = c // (r * r)
|
||||||
|
|
||||||
|
# Reshape: (N, H, W, out_c, r, r)
|
||||||
|
x = mx.reshape(x, (n, h, w, out_c, r, r))
|
||||||
|
|
||||||
|
# Permute: (N, H, r, W, r, out_c)
|
||||||
|
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
|
||||||
|
|
||||||
|
# Reshape: (N, H*r, W*r, out_c)
|
||||||
|
x = mx.reshape(x, (n, h * r, w * r, out_c))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialRationalResampler(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
# 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
|
||||||
|
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
# Blur kernel for antialiasing
|
||||||
|
self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
|
||||||
|
|
||||||
|
self.pixel_shuffle = PixelShuffle2D(2)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
# x: (N, D, H, W, C) - channels last 3D format
|
||||||
|
|
||||||
|
n, d, h, w, c = x.shape
|
||||||
|
|
||||||
|
# Process frame by frame
|
||||||
|
# Reshape to (N*D, H, W, C) for 2D operations
|
||||||
|
x = mx.reshape(x, (n * d, h, w, c))
|
||||||
|
|
||||||
|
# Apply 2D conv
|
||||||
|
x = self.conv(x)
|
||||||
|
|
||||||
|
# Pixel shuffle for 2x upscaling
|
||||||
|
x = self.pixel_shuffle(x)
|
||||||
|
|
||||||
|
# Reshape back to (N, D, H*2, W*2, C)
|
||||||
|
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||||
|
self.norm1 = GroupNorm3d(32, channels)
|
||||||
|
self.conv2 = Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||||
|
self.norm2 = GroupNorm3d(32, channels)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
|
||||||
|
# Activation AFTER residual addition
|
||||||
|
x = nn.silu(x + residual)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LatentUpsampler(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 128,
|
||||||
|
mid_channels: int = 1024,
|
||||||
|
num_blocks_per_stage: int = 4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.mid_channels = mid_channels
|
||||||
|
|
||||||
|
# Initial projection
|
||||||
|
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.initial_norm = GroupNorm3d(32, mid_channels)
|
||||||
|
|
||||||
|
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||||
|
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||||
|
|
||||||
|
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||||
|
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
|
||||||
|
|
||||||
|
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||||
|
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||||
|
|
||||||
|
# Final projection
|
||||||
|
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
|
||||||
|
"""Upsample latents by 2x spatially.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latent: Input tensor of shape (B, C, F, H, W) - channels first
|
||||||
|
debug: If True, print intermediate values for debugging
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
|
||||||
|
"""
|
||||||
|
def debug_stats(name, t):
|
||||||
|
if debug:
|
||||||
|
mx.eval(t)
|
||||||
|
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(" [DEBUG] LatentUpsampler forward pass:")
|
||||||
|
debug_stats("Input (channels first)", latent)
|
||||||
|
|
||||||
|
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
|
||||||
|
x = mx.transpose(latent, (0, 2, 3, 4, 1))
|
||||||
|
if debug:
|
||||||
|
debug_stats("After transpose to channels-last", x)
|
||||||
|
|
||||||
|
# Initial conv
|
||||||
|
x = self.initial_conv(x)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After initial_conv", x)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After initial_norm", x)
|
||||||
|
x = nn.silu(x)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After silu", x)
|
||||||
|
|
||||||
|
# Pre-upsample blocks
|
||||||
|
for i in sorted(self.res_blocks.keys()):
|
||||||
|
x = self.res_blocks[i](x)
|
||||||
|
if debug:
|
||||||
|
debug_stats(f"After res_blocks[{i}]", x)
|
||||||
|
|
||||||
|
# Upsample (2D spatial, frame-by-frame)
|
||||||
|
x = self.upsampler(x)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After upsampler (spatial 2x)", x)
|
||||||
|
|
||||||
|
# Post-upsample blocks
|
||||||
|
for i in sorted(self.post_upsample_res_blocks.keys()):
|
||||||
|
x = self.post_upsample_res_blocks[i](x)
|
||||||
|
if debug:
|
||||||
|
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
|
||||||
|
|
||||||
|
# Final conv
|
||||||
|
x = self.final_conv(x)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After final_conv", x)
|
||||||
|
|
||||||
|
# Convert back to channels first (B, C, F, H, W)
|
||||||
|
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||||
|
if debug:
|
||||||
|
debug_stats("Output (channels first)", x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def upsample_latents(
|
||||||
|
latent: mx.array,
|
||||||
|
upsampler: LatentUpsampler,
|
||||||
|
latent_mean: mx.array,
|
||||||
|
latent_std: mx.array,
|
||||||
|
debug: bool = False,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
# Un-normalize: latent * std + mean
|
||||||
|
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
|
||||||
|
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
|
||||||
|
latent = latent * latent_std + latent_mean
|
||||||
|
|
||||||
|
# Upsample
|
||||||
|
latent = upsampler(latent, debug=debug)
|
||||||
|
|
||||||
|
# Re-normalize: (latent - mean) / std
|
||||||
|
latent = (latent - latent_mean) / latent_std
|
||||||
|
|
||||||
|
return latent
|
||||||
|
|
||||||
|
|
||||||
|
def load_upsampler(weights_path: str) -> LatentUpsampler:
|
||||||
|
"""Load upsampler from safetensors weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights_path: Path to upsampler weights file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded LatentUpsampler model
|
||||||
|
"""
|
||||||
|
print(f"Loading spatial upsampler from {weights_path}...")
|
||||||
|
raw_weights = mx.load(weights_path)
|
||||||
|
|
||||||
|
# Check weight shapes to determine mid_channels
|
||||||
|
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
|
||||||
|
sample_key = "res_blocks.0.conv1.weight"
|
||||||
|
if sample_key in raw_weights:
|
||||||
|
mid_channels = raw_weights[sample_key].shape[0]
|
||||||
|
else:
|
||||||
|
mid_channels = 1024 # default
|
||||||
|
|
||||||
|
print(f" Detected mid_channels: {mid_channels}")
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
upsampler = LatentUpsampler(
|
||||||
|
in_channels=128,
|
||||||
|
mid_channels=mid_channels,
|
||||||
|
num_blocks_per_stage=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sanitize weights - convert from PyTorch to MLX format
|
||||||
|
sanitized = {}
|
||||||
|
for key, value in raw_weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||||
|
if "conv" in key and "weight" in key and value.ndim == 5:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||||
|
if "conv" in key and "weight" in key and value.ndim == 4:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
|
||||||
|
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
|
||||||
|
if key.startswith("upsampler."):
|
||||||
|
new_key = key # Keep as is for SpatialRationalResampler
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
upsampler.load_weights(list(sanitized.items()), strict=False)
|
||||||
|
|
||||||
|
print(f" Loaded {len(sanitized)} weights")
|
||||||
|
|
||||||
|
return upsampler
|
||||||
1
mlx_video/models/ltx/video_vae/__init__.py
Normal file
1
mlx_video/models/ltx/video_vae/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||||
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
294
mlx_video/models/ltx/video_vae/convolution.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class PaddingModeType(Enum):
|
||||||
|
ZEROS = "zeros"
|
||||||
|
REFLECT = "reflect"
|
||||||
|
|
||||||
|
|
||||||
|
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||||
|
"""Apply reflect padding to spatial dimensions of a 5D tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (B, D, H, W, C) - channels last
|
||||||
|
pad_h: Padding for height dimension
|
||||||
|
pad_w: Padding for width dimension
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Padded tensor
|
||||||
|
"""
|
||||||
|
if pad_h == 0 and pad_w == 0:
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Height padding (axis 2)
|
||||||
|
if pad_h > 0:
|
||||||
|
# Get reflection indices - exclude boundary
|
||||||
|
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||||
|
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||||
|
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||||
|
|
||||||
|
# Width padding (axis 3)
|
||||||
|
if pad_w > 0:
|
||||||
|
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||||
|
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||||
|
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def make_conv_nd(
|
||||||
|
dims: int,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, Tuple[int, ...]],
|
||||||
|
stride: Union[int, Tuple[int, ...]] = 1,
|
||||||
|
padding: Union[int, Tuple[int, ...], str] = 0,
|
||||||
|
causal: bool = False,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
) -> nn.Module:
|
||||||
|
|
||||||
|
if dims == 2:
|
||||||
|
return CausalConv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
causal=causal,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif dims == 3:
|
||||||
|
return CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
causal=causal,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported number of dimensions: {dims}")
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv3d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||||
|
padding: Union[int, Tuple[int, int, int], str] = 0,
|
||||||
|
causal: bool = False,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.causal = causal
|
||||||
|
self.spatial_padding_mode = spatial_padding_mode
|
||||||
|
|
||||||
|
# Normalize kernel_size and stride to tuples
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = (stride, stride, stride)
|
||||||
|
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.time_kernel_size = kernel_size[0]
|
||||||
|
|
||||||
|
# Calculate spatial padding (temporal is handled separately via frame replication)
|
||||||
|
height_pad = kernel_size[1] // 2
|
||||||
|
width_pad = kernel_size[2] // 2
|
||||||
|
self.spatial_padding = (height_pad, width_pad)
|
||||||
|
|
||||||
|
# Create the base convolution (without padding, we'll handle it manually)
|
||||||
|
self.conv = nn.Conv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0, # We handle padding manually
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||||
|
|
||||||
|
use_causal = causal if causal is not None else self.causal
|
||||||
|
|
||||||
|
# Apply temporal padding via frame replication
|
||||||
|
# Only apply if kernel_size > 1
|
||||||
|
if self.time_kernel_size > 1:
|
||||||
|
if use_causal:
|
||||||
|
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||||
|
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||||
|
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||||
|
else:
|
||||||
|
# Non-causal: replicate first frame at start, last frame at end
|
||||||
|
pad_size = (self.time_kernel_size - 1) // 2
|
||||||
|
if pad_size > 0:
|
||||||
|
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
|
||||||
|
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
|
||||||
|
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
|
||||||
|
|
||||||
|
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
|
||||||
|
x = mx.transpose(x, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
# Apply spatial padding
|
||||||
|
pad_h, pad_w = self.spatial_padding
|
||||||
|
if pad_h > 0 or pad_w > 0:
|
||||||
|
if self.spatial_padding_mode == PaddingModeType.REFLECT:
|
||||||
|
# Use reflect padding for spatial dimensions
|
||||||
|
x = reflect_pad_2d(x, pad_h, pad_w)
|
||||||
|
else:
|
||||||
|
# Use zero padding for spatial dimensions
|
||||||
|
pad_width = [
|
||||||
|
(0, 0), # Batch
|
||||||
|
(0, 0), # D (temporal - already padded)
|
||||||
|
(pad_h, pad_h), # H
|
||||||
|
(pad_w, pad_w), # W
|
||||||
|
(0, 0), # C
|
||||||
|
]
|
||||||
|
x = mx.pad(x, pad_width)
|
||||||
|
|
||||||
|
# Apply convolution with chunking for large tensors
|
||||||
|
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
|
||||||
|
x = self._chunked_conv3d(x)
|
||||||
|
|
||||||
|
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
|
||||||
|
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _chunked_conv3d(self, x: mx.array) -> mx.array:
|
||||||
|
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (B, D, H, W, C) in channels-last format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor after conv3d
|
||||||
|
"""
|
||||||
|
b, d, h, w, c = x.shape
|
||||||
|
|
||||||
|
|
||||||
|
total_elements = d * h * w * c
|
||||||
|
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||||
|
|
||||||
|
if total_elements <= max_safe_elements:
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
elements_per_frame = h * w * c
|
||||||
|
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
|
||||||
|
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
|
||||||
|
|
||||||
|
kernel_t = self.time_kernel_size
|
||||||
|
|
||||||
|
overlap = kernel_t - 1
|
||||||
|
|
||||||
|
|
||||||
|
expected_output_frames = d - overlap
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
out_idx = 0
|
||||||
|
|
||||||
|
# Process chunks
|
||||||
|
in_start = 0
|
||||||
|
while out_idx < expected_output_frames:
|
||||||
|
remaining = expected_output_frames - out_idx
|
||||||
|
out_frames_this_chunk = min(chunk_size, remaining)
|
||||||
|
|
||||||
|
in_frames_needed = out_frames_this_chunk + overlap
|
||||||
|
in_end = min(in_start + in_frames_needed, d)
|
||||||
|
|
||||||
|
chunk = x[:, in_start:in_end, :, :, :]
|
||||||
|
|
||||||
|
chunk_out = self.conv(chunk)
|
||||||
|
mx.eval(chunk_out)
|
||||||
|
|
||||||
|
outputs.append(chunk_out)
|
||||||
|
|
||||||
|
out_idx += chunk_out.shape[1]
|
||||||
|
in_start += chunk_out.shape[1]
|
||||||
|
|
||||||
|
# Concatenate all chunks
|
||||||
|
if len(outputs) == 1:
|
||||||
|
return outputs[0]
|
||||||
|
return mx.concatenate(outputs, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv2d(nn.Module):
|
||||||
|
"""2D convolution with optional causal padding."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: Union[int, Tuple[int, int]],
|
||||||
|
stride: Union[int, Tuple[int, int]] = 1,
|
||||||
|
padding: Union[int, Tuple[int, int], str] = 0,
|
||||||
|
causal: bool = False,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
):
|
||||||
|
"""Initialize CausalConv2d."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.causal = causal
|
||||||
|
self.spatial_padding_mode = spatial_padding_mode
|
||||||
|
|
||||||
|
# Normalize kernel_size and stride
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
kernel_size = (kernel_size, kernel_size)
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = (stride, stride)
|
||||||
|
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
# Calculate padding
|
||||||
|
if isinstance(padding, str) and padding == "same":
|
||||||
|
self.padding = (
|
||||||
|
(kernel_size[0] - 1) // 2,
|
||||||
|
(kernel_size[1] - 1) // 2,
|
||||||
|
)
|
||||||
|
elif isinstance(padding, int):
|
||||||
|
self.padding = (padding, padding)
|
||||||
|
else:
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=0,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||||
|
"""Forward pass."""
|
||||||
|
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
|
||||||
|
x = mx.transpose(x, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
# Apply padding
|
||||||
|
pad_h, pad_w = self.padding
|
||||||
|
if pad_h != 0 or pad_w != 0:
|
||||||
|
pad_width = [
|
||||||
|
(0, 0), # Batch
|
||||||
|
(pad_h, pad_h), # H
|
||||||
|
(pad_w, pad_w), # W
|
||||||
|
(0, 0), # C
|
||||||
|
]
|
||||||
|
x = mx.pad(x, pad_width)
|
||||||
|
|
||||||
|
x = self.conv(x)
|
||||||
|
|
||||||
|
# Transpose back: (B, H, W, C) -> (B, C, H, W)
|
||||||
|
x = mx.transpose(x, (0, 3, 1, 2))
|
||||||
|
|
||||||
|
return x
|
||||||
524
mlx_video/models/ltx/video_vae/decoder.py
Normal file
524
mlx_video/models/ltx/video_vae/decoder.py
Normal file
@@ -0,0 +1,524 @@
|
|||||||
|
"""Video VAE Decoder for LTX-2 with timestep conditioning.
|
||||||
|
|
||||||
|
Architecture (from PyTorch weights):
|
||||||
|
- conv_in: 128 -> 1024
|
||||||
|
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||||
|
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
|
||||||
|
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||||
|
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
|
||||||
|
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||||
|
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
|
||||||
|
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||||
|
- pixel_norm + timestep modulation (last_scale_shift_table)
|
||||||
|
- conv_out: 128 -> 48
|
||||||
|
- unpatchify: 48 -> 3 with patch_size=4
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
|
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||||
|
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestep_embedding(
|
||||||
|
timesteps: mx.array,
|
||||||
|
embedding_dim: int,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
downscale_freq_shift: float = 0,
|
||||||
|
scale: float = 1,
|
||||||
|
max_period: int = 10000,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Create sinusoidal timestep embeddings."""
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
|
||||||
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
|
emb = mx.exp(exponent)
|
||||||
|
emb = timesteps[:, None].astype(mx.float32) * emb[None, :]
|
||||||
|
emb = scale * emb
|
||||||
|
|
||||||
|
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
|
||||||
|
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = mx.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
|
||||||
|
|
||||||
|
if embedding_dim % 2 == 1:
|
||||||
|
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
||||||
|
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
"""MLP for timestep embedding."""
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||||
|
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def __call__(self, sample: mx.array) -> mx.array:
|
||||||
|
sample = self.linear_1(sample)
|
||||||
|
sample = self.act(sample)
|
||||||
|
sample = self.linear_2(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||||
|
"""Combined timestep embedding (sinusoidal + MLP)."""
|
||||||
|
|
||||||
|
def __init__(self, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.timestep_embedder = TimestepEmbedding(
|
||||||
|
in_channels=256,
|
||||||
|
time_embed_dim=embedding_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||||
|
timesteps_proj = get_timestep_embedding(
|
||||||
|
timestep,
|
||||||
|
embedding_dim=256,
|
||||||
|
flip_sin_to_cos=True,
|
||||||
|
downscale_freq_shift=0
|
||||||
|
)
|
||||||
|
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||||
|
return timesteps_emb
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock3DSimple(nn.Module):
|
||||||
|
"""ResNet block with optional timestep conditioning.
|
||||||
|
|
||||||
|
Weight keys: conv1.conv, conv2.conv, scale_shift_table
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
# Nested conv structure to match PyTorch naming: conv1.conv.weight
|
||||||
|
self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||||
|
self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||||
|
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
# Scale-shift table for timestep conditioning: [shift1, scale1, shift2, scale2]
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.scale_shift_table = mx.zeros((4, channels))
|
||||||
|
|
||||||
|
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||||
|
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||||
|
class ConvWrapper(nn.Module):
|
||||||
|
def __init__(self_inner):
|
||||||
|
super().__init__()
|
||||||
|
self_inner.conv = CausalConv3d(
|
||||||
|
in_channels=in_ch,
|
||||||
|
out_channels=out_ch,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
spatial_padding_mode=padding_mode,
|
||||||
|
)
|
||||||
|
def __call__(self_inner, x, causal=False):
|
||||||
|
return self_inner.conv(x, causal=causal)
|
||||||
|
return ConvWrapper()
|
||||||
|
|
||||||
|
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||||
|
"""Apply pixel normalization."""
|
||||||
|
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
causal: bool = False,
|
||||||
|
timestep_embed: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
residual = x
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# Block 1 with optional timestep conditioning
|
||||||
|
x = self.pixel_norm(x)
|
||||||
|
|
||||||
|
if self.timestep_conditioning and timestep_embed is not None:
|
||||||
|
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||||
|
# Combine table with timestep embedding
|
||||||
|
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||||
|
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||||
|
channels = self.scale_shift_table.shape[1]
|
||||||
|
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||||
|
ada_values = ada_values + ts_reshaped
|
||||||
|
|
||||||
|
shift1 = ada_values[:, 0] # (B, C, 1, 1, 1)
|
||||||
|
scale1 = ada_values[:, 1]
|
||||||
|
shift2 = ada_values[:, 2]
|
||||||
|
scale2 = ada_values[:, 3]
|
||||||
|
|
||||||
|
x = x * (1 + scale1) + shift1
|
||||||
|
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.conv1(x, causal=causal)
|
||||||
|
|
||||||
|
# Block 2 with optional timestep conditioning
|
||||||
|
x = self.pixel_norm(x)
|
||||||
|
|
||||||
|
if self.timestep_conditioning and timestep_embed is not None:
|
||||||
|
x = x * (1 + scale2) + shift2
|
||||||
|
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.conv2(x, causal=causal)
|
||||||
|
|
||||||
|
return x + residual
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlockGroup(nn.Module):
|
||||||
|
"""Group of ResNet blocks with shared timestep embedding.
|
||||||
|
|
||||||
|
PyTorch naming: res_blocks.0, res_blocks.1, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_layers: int = 5,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
# Time embedder for this block group: embed_dim = 4 * channels
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||||
|
embedding_dim=channels * 4
|
||||||
|
)
|
||||||
|
|
||||||
|
self.res_blocks = [
|
||||||
|
ResnetBlock3DSimple(
|
||||||
|
channels,
|
||||||
|
spatial_padding_mode,
|
||||||
|
timestep_conditioning=timestep_conditioning
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
causal: bool = False,
|
||||||
|
timestep: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
timestep_embed = None
|
||||||
|
|
||||||
|
if self.timestep_conditioning and timestep is not None:
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
timestep_embed = self.time_embedder(
|
||||||
|
timestep.flatten(),
|
||||||
|
hidden_dtype=x.dtype
|
||||||
|
)
|
||||||
|
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||||
|
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||||
|
|
||||||
|
for res_block in self.res_blocks:
|
||||||
|
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2VideoDecoder(nn.Module):
|
||||||
|
"""LTX-2 Video VAE Decoder with timestep conditioning.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
- conv_in: 128 -> 1024
|
||||||
|
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||||
|
- up_blocks.1: Upsampler 1024 -> 512
|
||||||
|
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||||
|
- up_blocks.3: Upsampler 512 -> 256
|
||||||
|
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||||
|
- up_blocks.5: Upsampler 256 -> 128
|
||||||
|
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||||
|
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 128,
|
||||||
|
out_channels: int = 3,
|
||||||
|
patch_size: int = 4,
|
||||||
|
num_layers_per_block: int = 5,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||||
|
timestep_conditioning: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
|
# Decode parameters (configurable via constructor)
|
||||||
|
self.decode_noise_scale = 0.025 # Set to 0.0 to disable noise
|
||||||
|
self.decode_timestep = 0.05
|
||||||
|
|
||||||
|
# Per-channel statistics for denormalization (loaded from weights)
|
||||||
|
self.latents_mean = mx.zeros((in_channels,))
|
||||||
|
self.latents_std = mx.ones((in_channels,))
|
||||||
|
|
||||||
|
# Initial conv: 128 -> 1024
|
||||||
|
class ConvInWrapper(nn.Module):
|
||||||
|
def __init__(self_inner):
|
||||||
|
super().__init__()
|
||||||
|
self_inner.conv = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=1024,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
def __call__(self_inner, x, causal=False):
|
||||||
|
return self_inner.conv(x, causal=causal)
|
||||||
|
self.conv_in = ConvInWrapper()
|
||||||
|
|
||||||
|
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
||||||
|
|
||||||
|
self.up_blocks = [
|
||||||
|
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
|
DepthToSpaceUpsample(
|
||||||
|
dims=3,
|
||||||
|
in_channels=1024,
|
||||||
|
stride=(2, 2, 2),
|
||||||
|
residual=True, # CRITICAL: Must match PyTorch config!
|
||||||
|
out_channels_reduction_factor=2,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
),
|
||||||
|
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
|
DepthToSpaceUpsample(
|
||||||
|
dims=3,
|
||||||
|
in_channels=512,
|
||||||
|
stride=(2, 2, 2),
|
||||||
|
residual=True, # CRITICAL: Must match PyTorch config!
|
||||||
|
out_channels_reduction_factor=2,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
),
|
||||||
|
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
|
DepthToSpaceUpsample(
|
||||||
|
dims=3,
|
||||||
|
in_channels=256,
|
||||||
|
stride=(2, 2, 2),
|
||||||
|
residual=True, # CRITICAL: Must match PyTorch config!
|
||||||
|
out_channels_reduction_factor=2,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
),
|
||||||
|
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
|
]
|
||||||
|
|
||||||
|
final_out_channels = out_channels * patch_size * patch_size
|
||||||
|
class ConvOutWrapper(nn.Module):
|
||||||
|
def __init__(self_inner):
|
||||||
|
super().__init__()
|
||||||
|
self_inner.conv = CausalConv3d(
|
||||||
|
in_channels=128,
|
||||||
|
out_channels=final_out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
def __call__(self_inner, x, causal=False):
|
||||||
|
return self_inner.conv(x, causal=causal)
|
||||||
|
self.conv_out = ConvOutWrapper()
|
||||||
|
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
if timestep_conditioning:
|
||||||
|
self.timestep_scale_multiplier = mx.array(1000.0)
|
||||||
|
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
|
||||||
|
embedding_dim=128 * 2 # 256, matches (2, 128) table
|
||||||
|
)
|
||||||
|
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||||
|
|
||||||
|
def denormalize(self, x: mx.array) -> mx.array:
|
||||||
|
"""Denormalize latents using per-channel statistics."""
|
||||||
|
mean = self.latents_mean.reshape(1, -1, 1, 1, 1)
|
||||||
|
std = self.latents_std.reshape(1, -1, 1, 1, 1)
|
||||||
|
return x * std + mean
|
||||||
|
|
||||||
|
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||||
|
"""Apply pixel normalization."""
|
||||||
|
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sample: mx.array,
|
||||||
|
causal: bool = False,
|
||||||
|
timestep: Optional[mx.array] = None,
|
||||||
|
debug: bool = False,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
def debug_stats(name, t):
|
||||||
|
if debug:
|
||||||
|
mx.eval(t)
|
||||||
|
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||||
|
|
||||||
|
batch_size = sample.shape[0]
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
debug_stats("Input", sample)
|
||||||
|
|
||||||
|
# Add noise if timestep conditioning is enabled
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||||
|
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||||
|
if debug:
|
||||||
|
debug_stats("After noise", sample)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
|
||||||
|
sample = self.denormalize(sample)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After denormalize", sample)
|
||||||
|
|
||||||
|
if timestep is None and self.timestep_conditioning:
|
||||||
|
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||||
|
|
||||||
|
scaled_timestep = None
|
||||||
|
if self.timestep_conditioning and timestep is not None:
|
||||||
|
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
|
||||||
|
x = self.conv_in(sample, causal=causal)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After conv_in", x)
|
||||||
|
|
||||||
|
for i, block in enumerate(self.up_blocks):
|
||||||
|
if isinstance(block, ResBlockGroup):
|
||||||
|
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||||
|
else:
|
||||||
|
x = block(x, causal=causal)
|
||||||
|
if debug:
|
||||||
|
block_type = type(block).__name__
|
||||||
|
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
|
||||||
|
|
||||||
|
x = self.pixel_norm(x)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After pixel_norm", x)
|
||||||
|
|
||||||
|
if self.timestep_conditioning and scaled_timestep is not None:
|
||||||
|
embedded_timestep = self.last_time_embedder(
|
||||||
|
scaled_timestep.flatten(),
|
||||||
|
hidden_dtype=x.dtype
|
||||||
|
)
|
||||||
|
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||||
|
|
||||||
|
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||||
|
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||||
|
ada_values = ada_values + ts_reshaped
|
||||||
|
|
||||||
|
shift = ada_values[:, 0] # (B, 128, 1, 1, 1)
|
||||||
|
scale = ada_values[:, 1]
|
||||||
|
|
||||||
|
x = x * (1 + scale) + shift
|
||||||
|
if debug:
|
||||||
|
debug_stats("After timestep modulation", x)
|
||||||
|
|
||||||
|
x = self.act(x)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After activation", x)
|
||||||
|
|
||||||
|
x = self.conv_out(x, causal=causal)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After conv_out", x)
|
||||||
|
|
||||||
|
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||||
|
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
if debug:
|
||||||
|
debug_stats("After unpatchify", x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||||
|
|
||||||
|
model_path = Path(model_path)
|
||||||
|
|
||||||
|
# Try to find the weights file
|
||||||
|
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||||
|
weights_path = model_path
|
||||||
|
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||||
|
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||||
|
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||||
|
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||||
|
|
||||||
|
print(f"Loading VAE decoder from {weights_path}...")
|
||||||
|
weights = mx.load(str(weights_path))
|
||||||
|
|
||||||
|
# Determine prefix based on weight keys
|
||||||
|
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||||
|
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
|
||||||
|
|
||||||
|
if has_vae_prefix:
|
||||||
|
prefix = "vae.decoder."
|
||||||
|
stats_prefix = "vae.per_channel_statistics."
|
||||||
|
elif has_decoder_prefix:
|
||||||
|
prefix = "decoder."
|
||||||
|
stats_prefix = ""
|
||||||
|
else:
|
||||||
|
prefix = ""
|
||||||
|
stats_prefix = ""
|
||||||
|
|
||||||
|
# Load per-channel statistics for denormalization
|
||||||
|
# Note: use std-of-means (not mean-of-stds) for proper denormalization
|
||||||
|
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
|
||||||
|
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
|
||||||
|
|
||||||
|
if mean_key in weights:
|
||||||
|
decoder.latents_mean = weights[mean_key]
|
||||||
|
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
|
||||||
|
if std_key in weights:
|
||||||
|
decoder.latents_std = weights[std_key]
|
||||||
|
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
|
||||||
|
|
||||||
|
# Build decoder weights dict with key remapping
|
||||||
|
decoder_weights = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
if not key.startswith(prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove prefix
|
||||||
|
new_key = key[len(prefix):]
|
||||||
|
|
||||||
|
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||||
|
if ".conv.weight" in key and value.ndim == 5:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
if ".conv.bias" in key:
|
||||||
|
pass # bias doesn't need transpose
|
||||||
|
|
||||||
|
|
||||||
|
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||||
|
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||||
|
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||||
|
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||||
|
|
||||||
|
decoder_weights[new_key] = value
|
||||||
|
|
||||||
|
print(f" Found {len(decoder_weights)} decoder weights")
|
||||||
|
|
||||||
|
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
|
||||||
|
print(f" Found {len(ts_keys)} timestep conditioning weights")
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
decoder.load_weights(list(decoder_weights.items()), strict=False)
|
||||||
|
|
||||||
|
print("VAE decoder loaded successfully")
|
||||||
|
return decoder
|
||||||
120
mlx_video/models/ltx/video_vae/ops.py
Normal file
120
mlx_video/models/ltx/video_vae/ops.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Operations for Video VAE."""
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||||
|
"""Convert video to patches.
|
||||||
|
|
||||||
|
Moves spatial pixels from H, W dimensions to channel dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (B, C, F, H, W)
|
||||||
|
patch_size_hw: Spatial patch size
|
||||||
|
patch_size_t: Temporal patch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
|
||||||
|
"""
|
||||||
|
b, c, f, h, w = x.shape
|
||||||
|
|
||||||
|
# Check dimensions are divisible
|
||||||
|
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
|
||||||
|
assert f % patch_size_t == 0
|
||||||
|
|
||||||
|
# New dimensions
|
||||||
|
new_h = h // patch_size_hw
|
||||||
|
new_w = w // patch_size_hw
|
||||||
|
new_f = f // patch_size_t
|
||||||
|
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||||
|
|
||||||
|
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||||
|
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||||
|
|
||||||
|
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, ph, pw, F', H', W')
|
||||||
|
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||||
|
|
||||||
|
# Reshape: (B, C, pt, ph, pw, F', H', W') -> (B, C*pt*ph*pw, F', H', W')
|
||||||
|
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||||
|
"""Convert patches back to video.
|
||||||
|
|
||||||
|
Inverse of patchify - moves pixels from channel dimension back to spatial.
|
||||||
|
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
|
||||||
|
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
|
||||||
|
patch_size_hw: Spatial patch size
|
||||||
|
patch_size_t: Temporal patch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
|
||||||
|
"""
|
||||||
|
b, c_packed, f, h, w = x.shape
|
||||||
|
|
||||||
|
# Calculate original channel count
|
||||||
|
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
|
||||||
|
|
||||||
|
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
|
||||||
|
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
|
||||||
|
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
|
||||||
|
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
|
||||||
|
|
||||||
|
# Permute to interleave patches with spatial dims:
|
||||||
|
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
|
||||||
|
|
||||||
|
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
|
||||||
|
|
||||||
|
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
|
||||||
|
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PerChannelStatistics(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, latent_channels: int = 128):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.latent_channels = latent_channels
|
||||||
|
|
||||||
|
# Learnable per-channel mean and std
|
||||||
|
self.mean = mx.zeros((latent_channels,))
|
||||||
|
self.std = mx.ones((latent_channels,))
|
||||||
|
|
||||||
|
def normalize(self, x: mx.array) -> mx.array:
|
||||||
|
"""Normalize latents using per-channel statistics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor of shape (B, C, ...)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized tensor
|
||||||
|
"""
|
||||||
|
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||||
|
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||||
|
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||||
|
|
||||||
|
return (x - mean) / std
|
||||||
|
|
||||||
|
def un_normalize(self, x: mx.array) -> mx.array:
|
||||||
|
"""Denormalize latents using per-channel statistics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Normalized tensor of shape (B, C, ...)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Denormalized tensor
|
||||||
|
"""
|
||||||
|
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||||
|
std = self.std.reshape(1, -1, 1, 1, 1)
|
||||||
|
|
||||||
|
return x * std + mean
|
||||||
171
mlx_video/models/ltx/video_vae/resnet.py
Normal file
171
mlx_video/models/ltx/video_vae/resnet.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""ResNet blocks for Video VAE."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
|
from mlx_video.utils import PixelNorm
|
||||||
|
|
||||||
|
|
||||||
|
class NormLayerType(Enum):
|
||||||
|
GROUP_NORM = "group_norm"
|
||||||
|
PIXEL_NORM = "pixel_norm"
|
||||||
|
|
||||||
|
|
||||||
|
def get_norm_layer(
|
||||||
|
norm_type: NormLayerType,
|
||||||
|
num_channels: int,
|
||||||
|
num_groups: int = 32,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
) -> nn.Module:
|
||||||
|
|
||||||
|
if norm_type == NormLayerType.GROUP_NORM:
|
||||||
|
return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
|
||||||
|
elif norm_type == NormLayerType.PIXEL_NORM:
|
||||||
|
return PixelNorm(eps=eps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown norm type: {norm_type}")
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
groups: int = 32,
|
||||||
|
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||||
|
inject_noise: bool = False,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
out_channels = out_channels or in_channels
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.inject_noise = inject_noise
|
||||||
|
|
||||||
|
# First normalization and convolution
|
||||||
|
self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
|
||||||
|
self.conv1 = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second normalization and convolution
|
||||||
|
self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
|
||||||
|
self.conv2 = CausalConv3d(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shortcut connection if channels change
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.shortcut = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.shortcut = None
|
||||||
|
|
||||||
|
# Activation
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
causal: bool = True,
|
||||||
|
generator: Optional[int] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
# First block
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.conv1(x, causal=causal)
|
||||||
|
|
||||||
|
# Inject noise if enabled
|
||||||
|
if self.inject_noise and generator is not None:
|
||||||
|
noise = mx.random.normal(x.shape)
|
||||||
|
x = x + noise * 0.01
|
||||||
|
|
||||||
|
# Second block
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.conv2(x, causal=causal)
|
||||||
|
|
||||||
|
# Shortcut
|
||||||
|
if self.shortcut is not None:
|
||||||
|
residual = self.shortcut(residual, causal=causal)
|
||||||
|
|
||||||
|
return x + residual
|
||||||
|
|
||||||
|
|
||||||
|
class UNetMidBlock3D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
in_channels: int,
|
||||||
|
num_layers: int = 1,
|
||||||
|
resnet_eps: float = 1e-6,
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||||
|
inject_noise: bool = False,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
|
attention_head_dim: Optional[int] = None,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
# Create ResNet blocks
|
||||||
|
self.resnets = [
|
||||||
|
ResnetBlock3D(
|
||||||
|
dims=dims,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=inject_noise,
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
causal: bool = True,
|
||||||
|
timestep: Optional[mx.array] = None,
|
||||||
|
generator: Optional[int] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
x = resnet(x, causal=causal, generator=generator)
|
||||||
|
|
||||||
|
return x
|
||||||
173
mlx_video/models/ltx/video_vae/sampling.py
Normal file
173
mlx_video/models/ltx/video_vae/sampling.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""Sampling operations for Video VAE (upsampling/downsampling)."""
|
||||||
|
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceToDepthDownsample(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
stride: Union[int, Tuple[int, int, int]],
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = (stride, stride, stride)
|
||||||
|
|
||||||
|
self.stride = stride
|
||||||
|
self.dims = dims
|
||||||
|
|
||||||
|
# Calculate the multiplier for channels
|
||||||
|
multiplier = stride[0] * stride[1] * stride[2]
|
||||||
|
intermediate_channels = in_channels * multiplier
|
||||||
|
|
||||||
|
# 1x1x1 convolution to adjust channels
|
||||||
|
self.conv = CausalConv3d(
|
||||||
|
in_channels=intermediate_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||||
|
|
||||||
|
b, c, d, h, w = x.shape
|
||||||
|
st, sh, sw = self.stride
|
||||||
|
|
||||||
|
# Pad if necessary to make dimensions divisible by stride
|
||||||
|
pad_d = (st - d % st) % st
|
||||||
|
pad_h = (sh - h % sh) % sh
|
||||||
|
pad_w = (sw - w % sw) % sw
|
||||||
|
|
||||||
|
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
||||||
|
# For causal, pad at the end of temporal dimension
|
||||||
|
if causal:
|
||||||
|
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
||||||
|
else:
|
||||||
|
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
|
||||||
|
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
|
||||||
|
|
||||||
|
b, c, d, h, w = x.shape
|
||||||
|
|
||||||
|
# Reshape to group spatial elements
|
||||||
|
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
|
||||||
|
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
||||||
|
|
||||||
|
# Permute to move stride elements to channel dim
|
||||||
|
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
||||||
|
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||||
|
|
||||||
|
# Reshape to combine channels
|
||||||
|
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
|
||||||
|
new_c = c * st * sh * sw
|
||||||
|
new_d = d // st
|
||||||
|
new_h = h // sh
|
||||||
|
new_w = w // sw
|
||||||
|
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
||||||
|
|
||||||
|
# Apply 1x1 conv to adjust channels
|
||||||
|
x = self.conv(x, causal=causal)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DepthToSpaceUpsample(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
in_channels: int,
|
||||||
|
stride: Union[int, Tuple[int, int, int]],
|
||||||
|
residual: bool = False,
|
||||||
|
out_channels_reduction_factor: int = 1,
|
||||||
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if isinstance(stride, int):
|
||||||
|
stride = (stride, stride, stride)
|
||||||
|
|
||||||
|
self.stride = stride
|
||||||
|
self.dims = dims
|
||||||
|
self.residual = residual
|
||||||
|
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||||
|
|
||||||
|
# Calculate output channels
|
||||||
|
multiplier = stride[0] * stride[1] * stride[2]
|
||||||
|
out_channels = in_channels // out_channels_reduction_factor
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
# 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
|
||||||
|
self.conv = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels * multiplier,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _depth_to_space(self, x: mx.array) -> mx.array:
|
||||||
|
b, c_packed, d, h, w = x.shape
|
||||||
|
st, sh, sw = self.stride
|
||||||
|
c = c_packed // (st * sh * sw)
|
||||||
|
|
||||||
|
# (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
|
||||||
|
x = mx.reshape(x, (b, c, st, sh, sw, d, h, w))
|
||||||
|
|
||||||
|
# (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
|
||||||
|
x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
|
||||||
|
|
||||||
|
# (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
|
||||||
|
x = mx.reshape(x, (b, c, d * st, h * sh, w * sw))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||||
|
|
||||||
|
b, c, d, h, w = x.shape
|
||||||
|
st, sh, sw = self.stride
|
||||||
|
|
||||||
|
# Compute residual path if enabled
|
||||||
|
x_residual = None
|
||||||
|
if self.residual:
|
||||||
|
# Reshape input: treat channels as spatial factors
|
||||||
|
# "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
|
||||||
|
x_residual = self._depth_to_space(x)
|
||||||
|
|
||||||
|
# Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
|
||||||
|
# num_repeat = prod(stride) / out_channels_reduction_factor
|
||||||
|
num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
|
||||||
|
x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
|
||||||
|
|
||||||
|
# Remove first temporal frame if temporal upsampling
|
||||||
|
if st > 1:
|
||||||
|
x_residual = x_residual[:, :, 1:, :, :]
|
||||||
|
|
||||||
|
# Apply conv
|
||||||
|
x = self.conv(x, causal=causal)
|
||||||
|
|
||||||
|
# Depth to space rearrangement
|
||||||
|
x = self._depth_to_space(x)
|
||||||
|
|
||||||
|
# Remove first frame for causal temporal upsampling
|
||||||
|
if st > 1:
|
||||||
|
x = x[:, :, 1:, :, :]
|
||||||
|
|
||||||
|
# Add residual
|
||||||
|
if self.residual and x_residual is not None:
|
||||||
|
x = x + x_residual
|
||||||
|
|
||||||
|
return x
|
||||||
528
mlx_video/models/ltx/video_vae/video_vae.py
Normal file
528
mlx_video/models/ltx/video_vae/video_vae.py
Normal file
@@ -0,0 +1,528 @@
|
|||||||
|
"""Video VAE Encoder and Decoder for LTX-2."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
|
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||||
|
from mlx_video.models.ltx.video_vae.resnet import (
|
||||||
|
NormLayerType,
|
||||||
|
ResnetBlock3D,
|
||||||
|
UNetMidBlock3D,
|
||||||
|
get_norm_layer,
|
||||||
|
)
|
||||||
|
from mlx_video.models.ltx.video_vae.sampling import (
|
||||||
|
DepthToSpaceUpsample,
|
||||||
|
SpaceToDepthDownsample,
|
||||||
|
)
|
||||||
|
from mlx_video.utils import PixelNorm
|
||||||
|
|
||||||
|
|
||||||
|
class LogVarianceType(Enum):
|
||||||
|
"""Log variance mode for VAE."""
|
||||||
|
PER_CHANNEL = "per_channel"
|
||||||
|
UNIFORM = "uniform"
|
||||||
|
CONSTANT = "constant"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_encoder_block(
|
||||||
|
block_name: str,
|
||||||
|
block_config: Dict[str, Any],
|
||||||
|
in_channels: int,
|
||||||
|
convolution_dimensions: int,
|
||||||
|
norm_layer: NormLayerType,
|
||||||
|
norm_num_groups: int,
|
||||||
|
spatial_padding_mode: PaddingModeType,
|
||||||
|
) -> Tuple[nn.Module, int]:
|
||||||
|
"""Create an encoder block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block_name: Type of block
|
||||||
|
block_config: Block configuration
|
||||||
|
in_channels: Input channels
|
||||||
|
convolution_dimensions: Number of dimensions
|
||||||
|
norm_layer: Normalization layer type
|
||||||
|
norm_num_groups: Number of groups for group norm
|
||||||
|
spatial_padding_mode: Padding mode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (block, output_channels)
|
||||||
|
"""
|
||||||
|
out_channels = in_channels
|
||||||
|
|
||||||
|
if block_name == "res_x":
|
||||||
|
block = UNetMidBlock3D(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
num_layers=block_config["num_layers"],
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "res_x_y":
|
||||||
|
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||||
|
block = ResnetBlock3D(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
groups=norm_num_groups,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_time":
|
||||||
|
block = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=(2, 1, 1),
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_space":
|
||||||
|
block = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=(1, 2, 2),
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_all":
|
||||||
|
block = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=(2, 2, 2),
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_all_x_y":
|
||||||
|
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||||
|
block = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=(2, 2, 2),
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_all_res":
|
||||||
|
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||||
|
block = SpaceToDepthDownsample(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
stride=(2, 2, 2),
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_space_res":
|
||||||
|
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||||
|
block = SpaceToDepthDownsample(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
stride=(1, 2, 2),
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_time_res":
|
||||||
|
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||||
|
block = SpaceToDepthDownsample(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
stride=(2, 1, 1),
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown encoder block: {block_name}")
|
||||||
|
|
||||||
|
return block, out_channels
|
||||||
|
|
||||||
|
|
||||||
|
def _make_decoder_block(
|
||||||
|
block_name: str,
|
||||||
|
block_config: Dict[str, Any],
|
||||||
|
in_channels: int,
|
||||||
|
convolution_dimensions: int,
|
||||||
|
norm_layer: NormLayerType,
|
||||||
|
timestep_conditioning: bool,
|
||||||
|
norm_num_groups: int,
|
||||||
|
spatial_padding_mode: PaddingModeType,
|
||||||
|
) -> Tuple[nn.Module, int]:
|
||||||
|
"""Create a decoder block."""
|
||||||
|
out_channels = in_channels
|
||||||
|
|
||||||
|
if block_name == "res_x":
|
||||||
|
block = UNetMidBlock3D(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
num_layers=block_config["num_layers"],
|
||||||
|
resnet_eps=1e-6,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_config.get("inject_noise", False),
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "res_x_y":
|
||||||
|
out_channels = in_channels // block_config.get("multiplier", 2)
|
||||||
|
block = ResnetBlock3D(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
groups=norm_num_groups,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
inject_noise=block_config.get("inject_noise", False),
|
||||||
|
timestep_conditioning=False,
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_time":
|
||||||
|
block = DepthToSpaceUpsample(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
stride=(2, 1, 1),
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_space":
|
||||||
|
block = DepthToSpaceUpsample(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
stride=(1, 2, 2),
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
elif block_name == "compress_all":
|
||||||
|
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||||
|
block = DepthToSpaceUpsample(
|
||||||
|
dims=convolution_dimensions,
|
||||||
|
in_channels=in_channels,
|
||||||
|
stride=(2, 2, 2),
|
||||||
|
residual=block_config.get("residual", False),
|
||||||
|
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||||
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown decoder block: {block_name}")
|
||||||
|
|
||||||
|
return block, out_channels
|
||||||
|
|
||||||
|
|
||||||
|
class VideoEncoder(nn.Module):
|
||||||
|
|
||||||
|
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
convolution_dimensions: int = 3,
|
||||||
|
in_channels: int = 3,
|
||||||
|
out_channels: int = 128,
|
||||||
|
encoder_blocks: List[Tuple[str, Any]] = None,
|
||||||
|
patch_size: int = 4,
|
||||||
|
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||||
|
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
|
||||||
|
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
|
):
|
||||||
|
"""Initialize VideoEncoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
convolution_dimensions: Number of dimensions (3 for video)
|
||||||
|
in_channels: Input channels (3 for RGB)
|
||||||
|
out_channels: Output latent channels
|
||||||
|
encoder_blocks: List of (block_name, config) tuples
|
||||||
|
patch_size: Spatial patch size
|
||||||
|
norm_layer: Normalization layer type
|
||||||
|
latent_log_var: Log variance mode
|
||||||
|
encoder_spatial_padding_mode: Padding mode
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if encoder_blocks is None:
|
||||||
|
encoder_blocks = []
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.norm_layer = norm_layer
|
||||||
|
self.latent_channels = out_channels
|
||||||
|
self.latent_log_var = latent_log_var
|
||||||
|
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||||
|
|
||||||
|
# Per-channel statistics for normalizing latents
|
||||||
|
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
|
||||||
|
|
||||||
|
# After patchify, channels increase by patch_size^2
|
||||||
|
in_channels = in_channels * patch_size ** 2
|
||||||
|
feature_channels = out_channels
|
||||||
|
|
||||||
|
# Initial convolution
|
||||||
|
self.conv_in = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=feature_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build encoder blocks
|
||||||
|
self.down_blocks = []
|
||||||
|
for block_name, block_params in encoder_blocks:
|
||||||
|
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||||
|
|
||||||
|
block, feature_channels = _make_encoder_block(
|
||||||
|
block_name=block_name,
|
||||||
|
block_config=block_config,
|
||||||
|
in_channels=feature_channels,
|
||||||
|
convolution_dimensions=convolution_dimensions,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
norm_num_groups=self._norm_num_groups,
|
||||||
|
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||||
|
)
|
||||||
|
self.down_blocks.append(block)
|
||||||
|
|
||||||
|
# Output normalization and convolution
|
||||||
|
if norm_layer == NormLayerType.GROUP_NORM:
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
num_groups=self._norm_num_groups,
|
||||||
|
dims=feature_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
)
|
||||||
|
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||||
|
self.conv_norm_out = PixelNorm()
|
||||||
|
|
||||||
|
self.conv_act = nn.SiLU()
|
||||||
|
|
||||||
|
# Calculate output convolution channels
|
||||||
|
conv_out_channels = out_channels
|
||||||
|
if latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||||
|
conv_out_channels *= 2
|
||||||
|
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||||
|
conv_out_channels += 1
|
||||||
|
|
||||||
|
self.conv_out = CausalConv3d(
|
||||||
|
in_channels=feature_channels,
|
||||||
|
out_channels=conv_out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, sample: mx.array) -> mx.array:
|
||||||
|
"""Encode video to latent representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: Input video of shape (B, C, F, H, W).
|
||||||
|
F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized latent means of shape (B, 128, F', H', W')
|
||||||
|
"""
|
||||||
|
# Validate frame count
|
||||||
|
frames_count = sample.shape[2]
|
||||||
|
if ((frames_count - 1) % 8) != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid number of frames: Encode input must have 1 + 8 * x frames "
|
||||||
|
f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initial patchify
|
||||||
|
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
sample = self.conv_in(sample, causal=True)
|
||||||
|
|
||||||
|
# Process through encoder blocks
|
||||||
|
for down_block in self.down_blocks:
|
||||||
|
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||||
|
sample = down_block(sample, causal=True)
|
||||||
|
else:
|
||||||
|
sample = down_block(sample, causal=True)
|
||||||
|
|
||||||
|
# Output processing
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample, causal=True)
|
||||||
|
|
||||||
|
# Handle log variance modes
|
||||||
|
if self.latent_log_var == LogVarianceType.UNIFORM:
|
||||||
|
means = sample[:, :-1, ...]
|
||||||
|
logvar = sample[:, -1:, ...]
|
||||||
|
num_channels = means.shape[1]
|
||||||
|
repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
|
||||||
|
sample = mx.concatenate([means, repeated_logvar], axis=1)
|
||||||
|
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||||
|
sample = sample[:, :-1, ...]
|
||||||
|
approx_ln_0 = -30
|
||||||
|
sample = mx.concatenate([
|
||||||
|
sample,
|
||||||
|
mx.full_like(sample, approx_ln_0),
|
||||||
|
], axis=1)
|
||||||
|
|
||||||
|
# Split into means and logvar, normalize means
|
||||||
|
means = sample[:, :self.latent_channels, ...]
|
||||||
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDecoder(nn.Module):
|
||||||
|
|
||||||
|
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
convolution_dimensions: int = 3,
|
||||||
|
in_channels: int = 128,
|
||||||
|
out_channels: int = 3,
|
||||||
|
decoder_blocks: List[Tuple[str, Any]] = None,
|
||||||
|
patch_size: int = 4,
|
||||||
|
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||||
|
causal: bool = False,
|
||||||
|
timestep_conditioning: bool = False,
|
||||||
|
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||||
|
):
|
||||||
|
"""Initialize VideoDecoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
convolution_dimensions: Number of dimensions
|
||||||
|
in_channels: Input latent channels
|
||||||
|
out_channels: Output channels (3 for RGB)
|
||||||
|
decoder_blocks: List of (block_name, config) tuples
|
||||||
|
patch_size: Spatial patch size
|
||||||
|
norm_layer: Normalization layer type
|
||||||
|
causal: Whether to use causal convolutions
|
||||||
|
timestep_conditioning: Whether to use timestep conditioning
|
||||||
|
decoder_spatial_padding_mode: Padding mode
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if decoder_blocks is None:
|
||||||
|
decoder_blocks = []
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
out_channels = out_channels * patch_size ** 2
|
||||||
|
self.causal = causal
|
||||||
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||||
|
|
||||||
|
# Per-channel statistics for denormalizing latents
|
||||||
|
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||||
|
|
||||||
|
# Noise and timestep parameters
|
||||||
|
self.decode_noise_scale = 0.025
|
||||||
|
self.decode_timestep = 0.05
|
||||||
|
|
||||||
|
# Compute initial feature channels
|
||||||
|
feature_channels = in_channels
|
||||||
|
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||||
|
block_config = block_params if isinstance(block_params, dict) else {}
|
||||||
|
if block_name == "res_x_y":
|
||||||
|
feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||||
|
if block_name == "compress_all":
|
||||||
|
feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||||
|
|
||||||
|
# Initial convolution
|
||||||
|
self.conv_in = CausalConv3d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=feature_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build decoder blocks (reversed order)
|
||||||
|
self.up_blocks = []
|
||||||
|
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||||
|
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||||
|
|
||||||
|
block, feature_channels = _make_decoder_block(
|
||||||
|
block_name=block_name,
|
||||||
|
block_config=block_config,
|
||||||
|
in_channels=feature_channels,
|
||||||
|
convolution_dimensions=convolution_dimensions,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
timestep_conditioning=timestep_conditioning,
|
||||||
|
norm_num_groups=self._norm_num_groups,
|
||||||
|
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||||
|
)
|
||||||
|
self.up_blocks.append(block)
|
||||||
|
|
||||||
|
# Output normalization
|
||||||
|
if norm_layer == NormLayerType.GROUP_NORM:
|
||||||
|
self.conv_norm_out = nn.GroupNorm(
|
||||||
|
num_groups=self._norm_num_groups,
|
||||||
|
dims=feature_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
)
|
||||||
|
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||||
|
self.conv_norm_out = PixelNorm()
|
||||||
|
|
||||||
|
self.conv_act = nn.SiLU()
|
||||||
|
self.conv_out = CausalConv3d(
|
||||||
|
in_channels=feature_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
causal=True,
|
||||||
|
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sample: mx.array,
|
||||||
|
timestep: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Decode latent to video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: Latent tensor of shape (B, 128, F', H', W')
|
||||||
|
timestep: Optional timestep for conditioning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decoded video of shape (B, 3, F, H, W)
|
||||||
|
"""
|
||||||
|
batch_size = sample.shape[0]
|
||||||
|
|
||||||
|
# Add noise if timestep conditioning is enabled
|
||||||
|
if self.timestep_conditioning:
|
||||||
|
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||||
|
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||||
|
|
||||||
|
# Denormalize latents
|
||||||
|
sample = self.per_channel_statistics.un_normalize(sample)
|
||||||
|
|
||||||
|
# Use default timestep if not provided
|
||||||
|
if timestep is None and self.timestep_conditioning:
|
||||||
|
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||||
|
|
||||||
|
# Initial convolution
|
||||||
|
sample = self.conv_in(sample, causal=self.causal)
|
||||||
|
|
||||||
|
# Process through decoder blocks
|
||||||
|
for up_block in self.up_blocks:
|
||||||
|
if isinstance(up_block, UNetMidBlock3D):
|
||||||
|
sample = up_block(sample, causal=self.causal)
|
||||||
|
elif isinstance(up_block, ResnetBlock3D):
|
||||||
|
sample = up_block(sample, causal=self.causal)
|
||||||
|
else:
|
||||||
|
sample = up_block(sample, causal=self.causal)
|
||||||
|
|
||||||
|
# Output processing
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample, causal=self.causal)
|
||||||
|
|
||||||
|
# Unpatchify to restore spatial resolution
|
||||||
|
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
|
||||||
|
return sample
|
||||||
165
mlx_video/postprocess.py
Normal file
165
mlx_video/postprocess.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
||||||
|
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image as uint8 numpy array (H, W, C)
|
||||||
|
d: Diameter of each pixel neighborhood
|
||||||
|
sigma_color: Filter sigma in the color space
|
||||||
|
sigma_space: Filter sigma in the coordinate space
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to simple Gaussian blur if cv2 not available
|
||||||
|
return gaussian_blur(image, kernel_size=3)
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
|
||||||
|
"""Apply Gaussian blur.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image as uint8 numpy array (H, W, C)
|
||||||
|
kernel_size: Size of the Gaussian kernel (must be odd)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Blurred image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
||||||
|
except ImportError:
|
||||||
|
# Simple box blur fallback
|
||||||
|
from scipy.ndimage import uniform_filter
|
||||||
|
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
|
||||||
|
"""Apply unsharp masking to enhance edges after blur.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image as uint8 numpy array
|
||||||
|
kernel_size: Size of the Gaussian kernel
|
||||||
|
sigma: Gaussian sigma
|
||||||
|
amount: Strength of sharpening
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sharpened image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
|
||||||
|
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
|
||||||
|
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
||||||
|
except ImportError:
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_grid_artifacts(
|
||||||
|
video: np.ndarray,
|
||||||
|
method: str = "bilateral",
|
||||||
|
strength: float = 1.0,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Reduce grid artifacts in video frames.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video: Video as numpy array (F, H, W, C) uint8
|
||||||
|
method: "bilateral", "gaussian", or "frequency"
|
||||||
|
strength: How strong to apply the filter (0-1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed video
|
||||||
|
"""
|
||||||
|
if method == "bilateral":
|
||||||
|
d = max(3, int(5 * strength))
|
||||||
|
sigma = 50 + 50 * strength
|
||||||
|
processed = np.stack([
|
||||||
|
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
||||||
|
for frame in video
|
||||||
|
])
|
||||||
|
elif method == "gaussian":
|
||||||
|
kernel_size = max(3, int(3 + 4 * strength))
|
||||||
|
if kernel_size % 2 == 0:
|
||||||
|
kernel_size += 1
|
||||||
|
processed = np.stack([
|
||||||
|
gaussian_blur(frame, kernel_size=kernel_size)
|
||||||
|
for frame in video
|
||||||
|
])
|
||||||
|
elif method == "frequency":
|
||||||
|
processed = np.stack([
|
||||||
|
remove_grid_frequency(frame, grid_size=8)
|
||||||
|
for frame in video
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method: {method}")
|
||||||
|
|
||||||
|
# Optionally sharpen to recover some detail
|
||||||
|
if strength < 1.0:
|
||||||
|
# Blend with original based on strength
|
||||||
|
alpha = strength
|
||||||
|
processed = (alpha * processed + (1 - alpha) * video).astype(np.uint8)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
|
||||||
|
def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
|
||||||
|
"""Remove grid-frequency components using FFT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (H, W, C) uint8
|
||||||
|
grid_size: Expected grid periodicity in pixels
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered frame
|
||||||
|
"""
|
||||||
|
result = np.zeros_like(frame)
|
||||||
|
|
||||||
|
for c in range(frame.shape[2]):
|
||||||
|
channel = frame[:, :, c].astype(np.float32)
|
||||||
|
h, w = channel.shape
|
||||||
|
|
||||||
|
# FFT
|
||||||
|
fft = np.fft.fft2(channel)
|
||||||
|
fft_shifted = np.fft.fftshift(fft)
|
||||||
|
|
||||||
|
# Create notch filter at grid frequencies
|
||||||
|
cy, cx = h // 2, w // 2
|
||||||
|
mask = np.ones((h, w), dtype=np.float32)
|
||||||
|
|
||||||
|
# Attenuate frequencies at grid periodicity
|
||||||
|
freq_y = h // grid_size
|
||||||
|
freq_x = w // grid_size
|
||||||
|
|
||||||
|
for fy in range(-2, 3):
|
||||||
|
for fx in range(-2, 3):
|
||||||
|
if fy == 0 and fx == 0:
|
||||||
|
continue
|
||||||
|
y_pos = cy + fy * freq_y
|
||||||
|
x_pos = cx + fx * freq_x
|
||||||
|
if 0 <= y_pos < h and 0 <= x_pos < w:
|
||||||
|
# Gaussian attenuation around the frequency
|
||||||
|
for dy in range(-2, 3):
|
||||||
|
for dx in range(-2, 3):
|
||||||
|
yy, xx = y_pos + dy, x_pos + dx
|
||||||
|
if 0 <= yy < h and 0 <= xx < w:
|
||||||
|
dist = np.sqrt(dy**2 + dx**2)
|
||||||
|
mask[yy, xx] *= min(1.0, dist / 3.0)
|
||||||
|
|
||||||
|
# Apply mask and inverse FFT
|
||||||
|
fft_filtered = fft_shifted * mask
|
||||||
|
channel_filtered = np.fft.ifft2(np.fft.ifftshift(fft_filtered)).real
|
||||||
|
|
||||||
|
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
26
mlx_video/text_projection.py
Normal file
26
mlx_video/text_projection.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class PixArtAlphaTextProjection(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_size: int,
|
||||||
|
out_features: int | None = None,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
out_features = out_features or hidden_size
|
||||||
|
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||||
|
self.act = nn.GELU(approx="precise")
|
||||||
|
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
x = self.linear1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
return x
|
||||||
127
mlx_video/utils.py
Normal file
127
mlx_video/utils.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Utility functions for MLX Video."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
|
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||||
|
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
|
def to_denoised(
|
||||||
|
noisy: mx.array,
|
||||||
|
velocity: mx.array,
|
||||||
|
sigma: mx.array | float
|
||||||
|
) -> mx.array:
|
||||||
|
"""Convert velocity prediction to denoised output.
|
||||||
|
|
||||||
|
Given noisy input x_t and velocity prediction v, compute denoised x_0:
|
||||||
|
x_0 = x_t - sigma * v
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noisy: Noisy input tensor x_t
|
||||||
|
velocity: Velocity prediction v
|
||||||
|
sigma: Noise level (scalar or per-sample)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Denoised tensor x_0
|
||||||
|
"""
|
||||||
|
if isinstance(sigma, (int, float)):
|
||||||
|
return noisy - sigma * velocity
|
||||||
|
else:
|
||||||
|
# sigma is per-sample
|
||||||
|
while sigma.ndim < velocity.ndim:
|
||||||
|
sigma = mx.expand_dims(sigma, axis=-1)
|
||||||
|
return noisy - sigma * velocity
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array:
|
||||||
|
"""Repeat elements of tensor along an axis, similar to torch.repeat_interleave.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor
|
||||||
|
repeats: Number of repetitions for each element
|
||||||
|
axis: The axis along which to repeat values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor with repeated values
|
||||||
|
"""
|
||||||
|
# Handle negative axis
|
||||||
|
if axis < 0:
|
||||||
|
axis = x.ndim + axis
|
||||||
|
|
||||||
|
# Get shape
|
||||||
|
shape = list(x.shape)
|
||||||
|
|
||||||
|
# Expand dims, repeat, then reshape
|
||||||
|
x = mx.expand_dims(x, axis=axis + 1)
|
||||||
|
|
||||||
|
# Create tile pattern
|
||||||
|
tile_pattern = [1] * x.ndim
|
||||||
|
tile_pattern[axis + 1] = repeats
|
||||||
|
|
||||||
|
x = mx.tile(x, tile_pattern)
|
||||||
|
|
||||||
|
# Reshape to merge the repeated dimension
|
||||||
|
new_shape = shape.copy()
|
||||||
|
new_shape[axis] *= repeats
|
||||||
|
|
||||||
|
return mx.reshape(x, new_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class PixelNorm(nn.Module):
|
||||||
|
def __init__(self, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return x / mx.sqrt(mx.mean(x * x, axis=1, keepdims=True) + self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestep_embedding(
|
||||||
|
timesteps: mx.array,
|
||||||
|
embedding_dim: int,
|
||||||
|
flip_sin_to_cos: bool = False,
|
||||||
|
downscale_freq_shift: float = 1.0,
|
||||||
|
scale: float = 1.0,
|
||||||
|
max_period: int = 10000,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Create sinusoidal timestep embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timesteps: 1D tensor of timesteps
|
||||||
|
embedding_dim: Dimension of the embeddings to create
|
||||||
|
flip_sin_to_cos: If True, flip sin and cos ordering
|
||||||
|
downscale_freq_shift: Frequency shift factor
|
||||||
|
scale: Scale factor for timesteps
|
||||||
|
max_period: Maximum period for the sinusoids
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of shape (len(timesteps), embedding_dim)
|
||||||
|
"""
|
||||||
|
assert timesteps.ndim == 1, "Timesteps should be 1D"
|
||||||
|
|
||||||
|
half_dim = embedding_dim // 2
|
||||||
|
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
|
||||||
|
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||||
|
|
||||||
|
emb = mx.exp(exponent)
|
||||||
|
emb = (timesteps[:, None].astype(mx.float32) * scale) * emb[None, :]
|
||||||
|
|
||||||
|
# Compute sin and cos embeddings
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = mx.concatenate([mx.cos(emb), mx.sin(emb)], axis=-1)
|
||||||
|
else:
|
||||||
|
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
|
||||||
|
|
||||||
|
# Zero pad if odd embedding dimension
|
||||||
|
if embedding_dim % 2 == 1:
|
||||||
|
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
||||||
|
|
||||||
|
return emb
|
||||||
26
pyproject.toml
Normal file
26
pyproject.toml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
[project]
|
||||||
|
name = "mlx-video"
|
||||||
|
version = "0.0.1"
|
||||||
|
description = "MLX-Video is the best package for inference and finetuning of Image-Video-Audio generation models on your Mac using MLX."
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"mlx>=0.22.0",
|
||||||
|
"numpy",
|
||||||
|
"safetensors",
|
||||||
|
"huggingface_hub",
|
||||||
|
"tqdm",
|
||||||
|
]
|
||||||
|
license = {text="MIT"}
|
||||||
|
authors = [
|
||||||
|
{name = "Prince Canuma", email = "prince.gdt@gmail.com"}
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
479
uv.lock
generated
Normal file
479
uv.lock
generated
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
version = 1
|
||||||
|
revision = 3
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anyio"
|
||||||
|
version = "4.12.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "idna" },
|
||||||
|
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "certifi"
|
||||||
|
version = "2026.1.4"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/e0/2d/a891ca51311197f6ad14a7ef42e2399f36cf2f9bd44752b3dc4eab60fdc5/certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120", size = 154268, upload-time = "2026-01-04T02:42:41.825Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c", size = 152900, upload-time = "2026-01-04T02:42:40.15Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "click"
|
||||||
|
version = "8.3.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "colorama"
|
||||||
|
version = "0.4.6"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "filelock"
|
||||||
|
version = "3.20.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/c1/e0/a75dbe4bca1e7d41307323dad5ea2efdd95408f74ab2de8bd7dba9b51a1a/filelock-3.20.2.tar.gz", hash = "sha256:a2241ff4ddde2a7cebddf78e39832509cb045d18ec1a09d7248d6bfc6bfbbe64", size = 19510, upload-time = "2026-01-02T15:33:32.582Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9a/30/ab407e2ec752aa541704ed8f93c11e2a5d92c168b8a755d818b74a3c5c2d/filelock-3.20.2-py3-none-any.whl", hash = "sha256:fbba7237d6ea277175a32c54bb71ef814a8546d8601269e1bfc388de333974e8", size = 16697, upload-time = "2026-01-02T15:33:31.133Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fsspec"
|
||||||
|
version = "2025.12.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b6/27/954057b0d1f53f086f681755207dda6de6c660ce133c829158e8e8fe7895/fsspec-2025.12.0.tar.gz", hash = "sha256:c505de011584597b1060ff778bb664c1bc022e87921b0e4f10cc9c44f9635973", size = 309748, upload-time = "2025-12-03T15:23:42.687Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl", hash = "sha256:8bf1fe301b7d8acfa6e8571e3b1c3d158f909666642431cc78a1b7b4dbc5ec5b", size = 201422, upload-time = "2025-12-03T15:23:41.434Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "h11"
|
||||||
|
version = "0.16.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hf-xet"
|
||||||
|
version = "1.2.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, upload-time = "2025-10-24T19:04:32.129Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9e/a5/85ef910a0aa034a2abcfadc360ab5ac6f6bc4e9112349bd40ca97551cff0/hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649", size = 2861870, upload-time = "2025-10-24T19:04:11.422Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ea/40/e2e0a7eb9a51fe8828ba2d47fe22a7e74914ea8a0db68a18c3aa7449c767/hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813", size = 2717584, upload-time = "2025-10-24T19:04:09.586Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a5/7d/daf7f8bc4594fdd59a8a596f9e3886133fdc68e675292218a5e4c1b7e834/hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc", size = 3315004, upload-time = "2025-10-24T19:04:00.314Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b1/ba/45ea2f605fbf6d81c8b21e4d970b168b18a53515923010c312c06cd83164/hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5", size = 3222636, upload-time = "2025-10-24T19:03:58.111Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4a/1d/04513e3cab8f29ab8c109d309ddd21a2705afab9d52f2ba1151e0c14f086/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f", size = 3408448, upload-time = "2025-10-24T19:04:20.951Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f0/7c/60a2756d7feec7387db3a1176c632357632fbe7849fce576c5559d4520c7/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832", size = 3503401, upload-time = "2025-10-24T19:04:22.549Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4e/64/48fffbd67fb418ab07451e4ce641a70de1c40c10a13e25325e24858ebe5a/hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382", size = 2900866, upload-time = "2025-10-24T19:04:33.461Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e2/51/f7e2caae42f80af886db414d4e9885fac959330509089f97cccb339c6b87/hf_xet-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:10bfab528b968c70e062607f663e21e34e2bba349e8038db546646875495179e", size = 2861861, upload-time = "2025-10-24T19:04:19.01Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6e/1d/a641a88b69994f9371bd347f1dd35e5d1e2e2460a2e350c8d5165fc62005/hf_xet-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2a212e842647b02eb6a911187dc878e79c4aa0aa397e88dd3b26761676e8c1f8", size = 2717699, upload-time = "2025-10-24T19:04:17.306Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/df/e0/e5e9bba7d15f0318955f7ec3f4af13f92e773fbb368c0b8008a5acbcb12f/hf_xet-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e06daccb3a7d4c065f34fc26c14c74f4653069bb2b194e7f18f17cbe9939c0", size = 3314885, upload-time = "2025-10-24T19:04:07.642Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/21/90/b7fe5ff6f2b7b8cbdf1bd56145f863c90a5807d9758a549bf3d916aa4dec/hf_xet-1.2.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:29c8fc913a529ec0a91867ce3d119ac1aac966e098cf49501800c870328cc090", size = 3221550, upload-time = "2025-10-24T19:04:05.55Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6f/cb/73f276f0a7ce46cc6a6ec7d6c7d61cbfe5f2e107123d9bbd0193c355f106/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e159cbfcfbb29f920db2c09ed8b660eb894640d284f102ada929b6e3dc410a", size = 3408010, upload-time = "2025-10-24T19:04:28.598Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b8/1e/d642a12caa78171f4be64f7cd9c40e3ca5279d055d0873188a58c0f5fbb9/hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c91d5ae931510107f148874e9e2de8a16052b6f1b3ca3c1b12f15ccb491390f", size = 3503264, upload-time = "2025-10-24T19:04:30.397Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/17/b5/33764714923fa1ff922770f7ed18c2daae034d21ae6e10dbf4347c854154/hf_xet-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:210d577732b519ac6ede149d2f2f34049d44e8622bf14eb3d63bbcd2d4b332dc", size = 2901071, upload-time = "2025-10-24T19:04:37.463Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0b/dd/7ac658d54b9fb7999a0ccb07ad863b413cbaf5cf172f48ebcd9497ec7263/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737", size = 3413812, upload-time = "2025-10-24T19:04:24.585Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/44/870d44b30e1dcfb6a65932e3e1506c103a8a5aea9103c337e7a53180322c/hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69", size = 2905735, upload-time = "2025-10-24T19:04:35.928Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "httpcore"
|
||||||
|
version = "1.0.9"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "certifi" },
|
||||||
|
{ name = "h11" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "httpx"
|
||||||
|
version = "0.28.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "anyio" },
|
||||||
|
{ name = "certifi" },
|
||||||
|
{ name = "httpcore" },
|
||||||
|
{ name = "idna" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "huggingface-hub"
|
||||||
|
version = "1.3.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "filelock" },
|
||||||
|
{ name = "fsspec" },
|
||||||
|
{ name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
|
||||||
|
{ name = "httpx" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
{ name = "pyyaml" },
|
||||||
|
{ name = "shellingham" },
|
||||||
|
{ name = "tqdm" },
|
||||||
|
{ name = "typer-slim" },
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/dd/dd/1cc985c5dda36298b152f75e82a1c81f52243b78fb7e9cad637a29561ad1/huggingface_hub-1.3.1.tar.gz", hash = "sha256:e80e0cfb4a75557c51ab20d575bdea6bb6106c2f97b7c75d8490642f1efb6df5", size = 622356, upload-time = "2026-01-09T14:08:16.888Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/90/fb/cb8fe5f71d5622427f20bcab9e06a696a5aaf21bfe7bd0a8a0c63c88abf5/huggingface_hub-1.3.1-py3-none-any.whl", hash = "sha256:efbc7f3153cb84e2bb69b62ed90985e21ecc9343d15647a419fc0ee4b85f0ac3", size = 533351, upload-time = "2026-01-09T14:08:14.519Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "idna"
|
||||||
|
version = "3.11"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iniconfig"
|
||||||
|
version = "2.3.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mlx"
|
||||||
|
version = "0.30.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||||
|
]
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/07/14/74acbd677ececd17a44dafda1b472aebacef54f60ff9a41a801f711de9a7/mlx-0.30.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:acfd7d1b8e5b9fa1b7e9fab4cc5ba6a492c559fbb1c5aeab16c1d7a148ab4f1b", size = 593048, upload-time = "2025-12-18T01:55:34.9Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/58/8c/5309848afb9c53d363f59b88ae5811de248e2817e91aeadf007e2ac8d22b/mlx-0.30.1-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:b62030471272d1835b8137164bd43d863cc93ff1d67ec4f1f87bb4c8613dd5a6", size = 593043, upload-time = "2025-12-18T01:55:36.839Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e8/5a/0039815a930f0193e2cffb27c57dc6971004bce0086c2bbbdb10395c272c/mlx-0.30.1-cp311-cp311-macosx_26_0_arm64.whl", hash = "sha256:0489cd340f2d262cb3aaad4368e40e84b152e182e4cea37ba018e56c72e1d020", size = 567014, upload-time = "2025-12-18T00:15:51.731Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/de/c7/6bdb5497c1f5ed3e33afa7785761ad87fd3436c071805d9a93c905943f04/mlx-0.30.1-cp311-cp311-manylinux_2_35_aarch64.whl", hash = "sha256:fbdcfc3ed556a7e701a8eb67da299e2a25f52615193833ca6374decca3be5bf4", size = 658930, upload-time = "2025-12-18T01:55:38.441Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/91/02/2d86a1c116e951eb4d88fe313c321e23628ce7404712e1258cacf925a8b8/mlx-0.30.1-cp311-cp311-manylinux_2_35_x86_64.whl", hash = "sha256:68ec854e7b5f89454e67d6c2fa7bb416b8afb148003ccd775904ec6ec4744818", size = 692484, upload-time = "2025-12-18T01:55:40.254Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3a/4b/ad57b2f0ede3f0d009c0e3e1270c219bd18f9025388855ee149680cffa20/mlx-0.30.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:deaef3ecd2f99930867a29de748e3bffa9cc7e4dfa834f2501c37ed29aece1cc", size = 593397, upload-time = "2025-12-18T01:55:41.814Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ef/14/7fa03a0f66ac3cfb2fd6752178a1488f13c7233fff26eed0f832d961db35/mlx-0.30.1-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:86ccdcda0b5ea4768b87da25beae5b83ac7cc802506116b6845cea6f450e2377", size = 593397, upload-time = "2025-12-18T01:55:43Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9c/c8/9f1343dbe2381f9653df4e0a62dc8bf38f575a2553dc2aa6916de32d2a85/mlx-0.30.1-cp312-cp312-macosx_26_0_arm64.whl", hash = "sha256:a625cb434b2acc5674fe10683374641dab9671fb354ae7c2c67a1fb0405eeb37", size = 567576, upload-time = "2025-12-18T00:15:55.114Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/15/ff/485ed9c99c18ef89ac987178c0a526cb4148ba38b14838d315311d9d76a8/mlx-0.30.1-cp312-cp312-manylinux_2_35_aarch64.whl", hash = "sha256:ccc1ff3aca8fb1073c7dcd1274cebe48ae75f852d14b16c7db8228fbbad594dd", size = 643654, upload-time = "2025-12-18T01:55:44.165Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8a/d3/54d3bf5e404c3b6424b49c505dc8b3c06c6bb498fe720195b1fafbd69b5e/mlx-0.30.1-cp312-cp312-manylinux_2_35_x86_64.whl", hash = "sha256:55ed7fc4b389d6e49dac6d34a97b41e61cbe3662ac29c3d29cf612e6b2ed9827", size = 687305, upload-time = "2025-12-18T01:55:45.526Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f9/fd/c6f56cd87d48763ed63655ace627c06db9819eae7d43d132f40d4965947a/mlx-0.30.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743520758bc8261b2ed8f3b3dc96e4e9236769dd8f61fb17877c5e44037e2058", size = 593366, upload-time = "2025-12-18T01:55:46.786Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/dc/53/96d8c48b21f91c4216b6d2ef6dfc10862e5fb0b811a2aaf02c96c78601de/mlx-0.30.1-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:fc9745bc1860ca60128e3a6d36157da06d936e2b4007a4dcba990b40202f598f", size = 593368, upload-time = "2025-12-18T01:55:48.363Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/70/ce/476c3b7d3a4153bd0e1c5af1f1b6c09a804b652bbed34072404b322c22e0/mlx-0.30.1-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:a1480399c67bb327a66c5527b73915132e3fcaae3bce9634e5c81ccad9f43229", size = 567561, upload-time = "2025-12-18T00:15:56.153Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/33/41/7ad1e639fd7dd1cf01a62c1c5b051024a859888c27504996e9d8380e6754/mlx-0.30.1-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:8e19850a4236a8e174f851f5789b8b62a8eb74f5a8fa49ad8ba286c5ddb5f9bf", size = 643122, upload-time = "2025-12-18T01:55:49.607Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d0/dc/72d3737c5b0662eb5e785d353dbc5e34d793d27b09b99e39993ee051bd19/mlx-0.30.1-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:1c8ed5bcd9f1910fca209e95859ac737e60b3e1954181b820fa269158f81049a", size = 687254, upload-time = "2025-12-18T01:55:51.239Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9b/cc/523448996247bb05d9d68e23bccf3dafdda660befb9330f6bd5fa13361e8/mlx-0.30.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:d34cc2c25b0ee41c1349f14650db760e282685339858e305453f62405c12bc1b", size = 596006, upload-time = "2025-12-18T01:55:52.463Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/23/0e/f9f2f9659c34c87be8f4167f6a1d6ed7e826f4889d20eecd4c0d8122f0e9/mlx-0.30.1-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:4e47d301e9095b87f0bda8827bfd6ffe744223aba5cee8f28e25894d647f5823", size = 596008, upload-time = "2025-12-18T01:55:54.02Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/56/a7/49e41fb141de95b6a376091a963c737839c9cda04e423c67f57460a50458/mlx-0.30.1-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:cfba13e2a52255d663a1ad62f0f83eb3991e42147edf9a8d38cdd224e48ca49b", size = 570406, upload-time = "2025-12-18T00:15:57.177Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/73/99/a43cb112167cf865c069f5e108ae42f5314663930ff3dd86c2d23d984191/mlx-0.30.1-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:bebfec377208eb29cc88aa86c897c7446aa0984838669e138f273f9225d627ff", size = 646461, upload-time = "2025-12-18T01:55:55.285Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mlx-metal"
|
||||||
|
version = "0.30.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/09/3f/0be35ddad7e13d8ecd33a9185895f9739bb00b96ef0cce36cf0405d4aec0/mlx_metal-0.30.1-py3-none-macosx_14_0_arm64.whl", hash = "sha256:e7e92c6bdbd7ac8083f528a4c6640552ae106a57bb3d99856ac10a32e93a4b5e", size = 36864966, upload-time = "2025-12-18T01:55:31.473Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/1e/1f/c0bddd0d5bf3871411aabe32121e09e1b7cdbece8917a49d5a442310e3e5/mlx_metal-0.30.1-py3-none-macosx_15_0_arm64.whl", hash = "sha256:bb50f57418af7fc3c42a2da2c4bde0e7ab7ac0b997de1f6f642a6680ac65d626", size = 36859011, upload-time = "2025-12-18T01:55:34.541Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/67/b3/73cc2f584ac612a476096d35a61eed75ee7ed8b4e320b0c36cf60a14d4eb/mlx_metal-0.30.1-py3-none-macosx_26_0_arm64.whl", hash = "sha256:e0b151a0053ac00b4226710bfb6dbf54b87283fb01e10fb3877f9ea969f680aa", size = 44981160, upload-time = "2025-12-18T00:15:47.518Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mlx-video"
|
||||||
|
version = "0.0.1"
|
||||||
|
source = { editable = "." }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "huggingface-hub" },
|
||||||
|
{ name = "mlx" },
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "safetensors" },
|
||||||
|
{ name = "tqdm" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
{ name = "pytest" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.metadata]
|
||||||
|
requires-dist = [
|
||||||
|
{ name = "huggingface-hub" },
|
||||||
|
{ name = "mlx", specifier = ">=0.22.0" },
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "pytest", marker = "extra == 'dev'" },
|
||||||
|
{ name = "safetensors" },
|
||||||
|
{ name = "tqdm" },
|
||||||
|
]
|
||||||
|
provides-extras = ["dev"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "numpy"
|
||||||
|
version = "2.4.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/a4/7a/6a3d14e205d292b738db449d0de649b373a59edb0d0b4493821d0a3e8718/numpy-2.4.0.tar.gz", hash = "sha256:6e504f7b16118198f138ef31ba24d985b124c2c469fe8467007cf30fd992f934", size = 20685720, upload-time = "2025-12-20T16:18:19.023Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/26/7e/7bae7cbcc2f8132271967aa03e03954fc1e48aa1f3bf32b29ca95fbef352/numpy-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:316b2f2584682318539f0bcaca5a496ce9ca78c88066579ebd11fd06f8e4741e", size = 16940166, upload-time = "2025-12-20T16:15:43.434Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0f/27/6c13f5b46776d6246ec884ac5817452672156a506d08a1f2abb39961930a/numpy-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a2718c1de8504121714234b6f8241d0019450353276c88b9453c9c3d92e101db", size = 12641781, upload-time = "2025-12-20T16:15:45.701Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/14/1c/83b4998d4860d15283241d9e5215f28b40ac31f497c04b12fa7f428ff370/numpy-2.4.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:21555da4ec4a0c942520ead42c3b0dc9477441e085c42b0fbdd6a084869a6f6b", size = 5470247, upload-time = "2025-12-20T16:15:47.943Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/54/08/cbce72c835d937795571b0464b52069f869c9e78b0c076d416c5269d2718/numpy-2.4.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:413aa561266a4be2d06cd2b9665e89d9f54c543f418773076a76adcf2af08bc7", size = 6799807, upload-time = "2025-12-20T16:15:49.795Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ff/be/2e647961cd8c980591d75cdcd9e8f647d69fbe05e2a25613dc0a2ea5fb1a/numpy-2.4.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0feafc9e03128074689183031181fac0897ff169692d8492066e949041096548", size = 14701992, upload-time = "2025-12-20T16:15:51.615Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a2/fb/e1652fb8b6fd91ce6ed429143fe2e01ce714711e03e5b762615e7b36172c/numpy-2.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8fdfed3deaf1928fb7667d96e0567cdf58c2b370ea2ee7e586aa383ec2cb346", size = 16646871, upload-time = "2025-12-20T16:15:54.129Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/62/23/d841207e63c4322842f7cd042ae981cffe715c73376dcad8235fb31debf1/numpy-2.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e06a922a469cae9a57100864caf4f8a97a1026513793969f8ba5b63137a35d25", size = 16487190, upload-time = "2025-12-20T16:15:56.147Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bc/a0/6a842c8421ebfdec0a230e65f61e0dabda6edbef443d999d79b87c273965/numpy-2.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:927ccf5cd17c48f801f4ed43a7e5673a2724bd2171460be3e3894e6e332ef83a", size = 18580762, upload-time = "2025-12-20T16:15:58.524Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0a/d1/c79e0046641186f2134dde05e6181825b911f8bdcef31b19ddd16e232847/numpy-2.4.0-cp311-cp311-win32.whl", hash = "sha256:882567b7ae57c1b1a0250208cc21a7976d8cbcc49d5a322e607e6f09c9e0bd53", size = 6233359, upload-time = "2025-12-20T16:16:00.938Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fc/f0/74965001d231f28184d6305b8cdc1b6fcd4bf23033f6cb039cfe76c9fca7/numpy-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:8b986403023c8f3bf8f487c2e6186afda156174d31c175f747d8934dfddf3479", size = 12601132, upload-time = "2025-12-20T16:16:02.484Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/65/32/55408d0f46dfebce38017f5bd931affa7256ad6beac1a92a012e1fbc67a7/numpy-2.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:3f3096405acc48887458bbf9f6814d43785ac7ba2a57ea6442b581dedbc60ce6", size = 10573977, upload-time = "2025-12-20T16:16:04.77Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8b/ff/f6400ffec95de41c74b8e73df32e3fff1830633193a7b1e409be7fb1bb8c/numpy-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2a8b6bb8369abefb8bd1801b054ad50e02b3275c8614dc6e5b0373c305291037", size = 16653117, upload-time = "2025-12-20T16:16:06.709Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fd/28/6c23e97450035072e8d830a3c411bf1abd1f42c611ff9d29e3d8f55c6252/numpy-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e284ca13d5a8367e43734148622caf0b261b275673823593e3e3634a6490f83", size = 12369711, upload-time = "2025-12-20T16:16:08.758Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bc/af/acbef97b630ab1bb45e6a7d01d1452e4251aa88ce680ac36e56c272120ec/numpy-2.4.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:49ff32b09f5aa0cd30a20c2b39db3e669c845589f2b7fc910365210887e39344", size = 5198355, upload-time = "2025-12-20T16:16:10.902Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c1/c8/4e0d436b66b826f2e53330adaa6311f5cac9871a5b5c31ad773b27f25a74/numpy-2.4.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:36cbfb13c152b1c7c184ddac43765db8ad672567e7bafff2cc755a09917ed2e6", size = 6545298, upload-time = "2025-12-20T16:16:12.607Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ef/27/e1f5d144ab54eac34875e79037011d511ac57b21b220063310cb96c80fbc/numpy-2.4.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:35ddc8f4914466e6fc954c76527aa91aa763682a4f6d73249ef20b418fe6effb", size = 14398387, upload-time = "2025-12-20T16:16:14.257Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/67/64/4cb909dd5ab09a9a5d086eff9586e69e827b88a5585517386879474f4cf7/numpy-2.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc578891de1db95b2a35001b695451767b580bb45753717498213c5ff3c41d63", size = 16363091, upload-time = "2025-12-20T16:16:17.32Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9d/9c/8efe24577523ec6809261859737cf117b0eb6fdb655abdfdc81b2e468ce4/numpy-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98e81648e0b36e325ab67e46b5400a7a6d4a22b8a7c8e8bbfe20e7db7906bf95", size = 16176394, upload-time = "2025-12-20T16:16:19.524Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/61/f0/1687441ece7b47a62e45a1f82015352c240765c707928edd8aef875d5951/numpy-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d57b5046c120561ba8fa8e4030fbb8b822f3063910fa901ffadf16e2b7128ad6", size = 18287378, upload-time = "2025-12-20T16:16:22.866Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d3/6f/f868765d44e6fc466467ed810ba9d8d6db1add7d4a748abfa2a4c99a3194/numpy-2.4.0-cp312-cp312-win32.whl", hash = "sha256:92190db305a6f48734d3982f2c60fa30d6b5ee9bff10f2887b930d7b40119f4c", size = 5955432, upload-time = "2025-12-20T16:16:25.06Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d4/b5/94c1e79fcbab38d1ca15e13777477b2914dd2d559b410f96949d6637b085/numpy-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:680060061adb2d74ce352628cb798cfdec399068aa7f07ba9fb818b2b3305f98", size = 12306201, upload-time = "2025-12-20T16:16:26.979Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/70/09/c39dadf0b13bb0768cd29d6a3aaff1fb7c6905ac40e9aaeca26b1c086e06/numpy-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:39699233bc72dd482da1415dcb06076e32f60eddc796a796c5fb6c5efce94667", size = 10308234, upload-time = "2025-12-20T16:16:29.417Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a7/0d/853fd96372eda07c824d24adf02e8bc92bb3731b43a9b2a39161c3667cc4/numpy-2.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a152d86a3ae00ba5f47b3acf3b827509fd0b6cb7d3259665e63dafbad22a75ea", size = 16649088, upload-time = "2025-12-20T16:16:31.421Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e3/37/cc636f1f2a9f585434e20a3e6e63422f70bfe4f7f6698e941db52ea1ac9a/numpy-2.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:39b19251dec4de8ff8496cd0806cbe27bf0684f765abb1f4809554de93785f2d", size = 12364065, upload-time = "2025-12-20T16:16:33.491Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/69/0b78f37ca3690969beee54103ce5f6021709134e8020767e93ba691a72f1/numpy-2.4.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:009bd0ea12d3c784b6639a8457537016ce5172109e585338e11334f6a7bb88ee", size = 5192640, upload-time = "2025-12-20T16:16:35.636Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/1d/2a/08569f8252abf590294dbb09a430543ec8f8cc710383abfb3e75cc73aeda/numpy-2.4.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5fe44e277225fd3dff6882d86d3d447205d43532c3627313d17e754fb3905a0e", size = 6541556, upload-time = "2025-12-20T16:16:37.276Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/93/e9/a949885a4e177493d61519377952186b6cbfdf1d6002764c664ba28349b5/numpy-2.4.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f935c4493eda9069851058fa0d9e39dbf6286be690066509305e52912714dbb2", size = 14396562, upload-time = "2025-12-20T16:16:38.953Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/99/98/9d4ad53b0e9ef901c2ef1d550d2136f5ac42d3fd2988390a6def32e23e48/numpy-2.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8cfa5f29a695cb7438965e6c3e8d06e0416060cf0d709c1b1c1653a939bf5c2a", size = 16351719, upload-time = "2025-12-20T16:16:41.503Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/28/de/5f3711a38341d6e8dd619f6353251a0cdd07f3d6d101a8fd46f4ef87f895/numpy-2.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba0cb30acd3ef11c94dc27fbfba68940652492bc107075e7ffe23057f9425681", size = 16176053, upload-time = "2025-12-20T16:16:44.552Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2a/5b/2a3753dc43916501b4183532e7ace862e13211042bceafa253afb5c71272/numpy-2.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60e8c196cd82cbbd4f130b5290007e13e6de3eca79f0d4d38014769d96a7c475", size = 18277859, upload-time = "2025-12-20T16:16:47.174Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2c/c5/a18bcdd07a941db3076ef489d036ab16d2bfc2eae0cf27e5a26e29189434/numpy-2.4.0-cp313-cp313-win32.whl", hash = "sha256:5f48cb3e88fbc294dc90e215d86fbaf1c852c63dbdb6c3a3e63f45c4b57f7344", size = 5953849, upload-time = "2025-12-20T16:16:49.554Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4f/f1/719010ff8061da6e8a26e1980cf090412d4f5f8060b31f0c45d77dd67a01/numpy-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:a899699294f28f7be8992853c0c60741f16ff199205e2e6cdca155762cbaa59d", size = 12302840, upload-time = "2025-12-20T16:16:51.227Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f5/5a/b3d259083ed8b4d335270c76966cb6cf14a5d1b69e1a608994ac57a659e6/numpy-2.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:9198f447e1dc5647d07c9a6bbe2063cc0132728cc7175b39dbc796da5b54920d", size = 10308509, upload-time = "2025-12-20T16:16:53.313Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/31/01/95edcffd1bb6c0633df4e808130545c4f07383ab629ac7e316fb44fff677/numpy-2.4.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74623f2ab5cc3f7c886add4f735d1031a1d2be4a4ae63c0546cfd74e7a31ddf6", size = 12491815, upload-time = "2025-12-20T16:16:55.496Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/59/ea/5644b8baa92cc1c7163b4b4458c8679852733fa74ca49c942cfa82ded4e0/numpy-2.4.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0804a8e4ab070d1d35496e65ffd3cf8114c136a2b81f61dfab0de4b218aacfd5", size = 5320321, upload-time = "2025-12-20T16:16:57.468Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/26/4e/e10938106d70bc21319bd6a86ae726da37edc802ce35a3a71ecdf1fdfe7f/numpy-2.4.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:02a2038eb27f9443a8b266a66911e926566b5a6ffd1a689b588f7f35b81e7dc3", size = 6641635, upload-time = "2025-12-20T16:16:59.379Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b3/8d/a8828e3eaf5c0b4ab116924df82f24ce3416fa38d0674d8f708ddc6c8aac/numpy-2.4.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1889b3a3f47a7b5bee16bc25a2145bd7cb91897f815ce3499db64c7458b6d91d", size = 14456053, upload-time = "2025-12-20T16:17:01.768Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/68/a1/17d97609d87d4520aa5ae2dcfb32305654550ac6a35effb946d303e594ce/numpy-2.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85eef4cb5625c47ee6425c58a3502555e10f45ee973da878ac8248ad58c136f3", size = 16401702, upload-time = "2025-12-20T16:17:04.235Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/18/32/0f13c1b2d22bea1118356b8b963195446f3af124ed7a5adfa8fdecb1b6ca/numpy-2.4.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6dc8b7e2f4eb184b37655195f421836cfae6f58197b67e3ffc501f1333d993fa", size = 16242493, upload-time = "2025-12-20T16:17:06.856Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ae/23/48f21e3d309fbc137c068a1475358cbd3a901b3987dcfc97a029ab3068e2/numpy-2.4.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:44aba2f0cafd287871a495fb3163408b0bd25bbce135c6f621534a07f4f7875c", size = 18324222, upload-time = "2025-12-20T16:17:09.392Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ac/52/41f3d71296a3dcaa4f456aaa3c6fc8e745b43d0552b6bde56571bb4b4a0f/numpy-2.4.0-cp313-cp313t-win32.whl", hash = "sha256:20c115517513831860c573996e395707aa9fb691eb179200125c250e895fcd93", size = 6076216, upload-time = "2025-12-20T16:17:11.437Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/35/ff/46fbfe60ab0710d2a2b16995f708750307d30eccbb4c38371ea9e986866e/numpy-2.4.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b48e35f4ab6f6a7597c46e301126ceba4c44cd3280e3750f85db48b082624fa4", size = 12444263, upload-time = "2025-12-20T16:17:13.182Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a3/e3/9189ab319c01d2ed556c932ccf55064c5d75bb5850d1df7a482ce0badead/numpy-2.4.0-cp313-cp313t-win_arm64.whl", hash = "sha256:4d1cfce39e511069b11e67cd0bd78ceff31443b7c9e5c04db73c7a19f572967c", size = 10378265, upload-time = "2025-12-20T16:17:15.211Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ab/ed/52eac27de39d5e5a6c9aadabe672bc06f55e24a3d9010cd1183948055d76/numpy-2.4.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c95eb6db2884917d86cde0b4d4cf31adf485c8ec36bf8696dd66fa70de96f36b", size = 16647476, upload-time = "2025-12-20T16:17:17.671Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/77/c0/990ce1b7fcd4e09aeaa574e2a0a839589e4b08b2ca68070f1acb1fea6736/numpy-2.4.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:65167da969cd1ec3a1df31cb221ca3a19a8aaa25370ecb17d428415e93c1935e", size = 12374563, upload-time = "2025-12-20T16:17:20.216Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/37/7c/8c5e389c6ae8f5fd2277a988600d79e9625db3fff011a2d87ac80b881a4c/numpy-2.4.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:3de19cfecd1465d0dcf8a5b5ea8b3155b42ed0b639dba4b71e323d74f2a3be5e", size = 5203107, upload-time = "2025-12-20T16:17:22.47Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e6/94/ca5b3bd6a8a70a5eec9a0b8dd7f980c1eff4b8a54970a9a7fef248ef564f/numpy-2.4.0-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:6c05483c3136ac4c91b4e81903cb53a8707d316f488124d0398499a4f8e8ef51", size = 6538067, upload-time = "2025-12-20T16:17:24.001Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/79/43/993eb7bb5be6761dde2b3a3a594d689cec83398e3f58f4758010f3b85727/numpy-2.4.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36667db4d6c1cea79c8930ab72fadfb4060feb4bfe724141cd4bd064d2e5f8ce", size = 14411926, upload-time = "2025-12-20T16:17:25.822Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/03/75/d4c43b61de473912496317a854dac54f1efec3eeb158438da6884b70bb90/numpy-2.4.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9a818668b674047fd88c4cddada7ab8f1c298812783e8328e956b78dc4807f9f", size = 16354295, upload-time = "2025-12-20T16:17:28.308Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b8/0a/b54615b47ee8736a6461a4bb6749128dd3435c5a759d5663f11f0e9af4ac/numpy-2.4.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:1ee32359fb7543b7b7bd0b2f46294db27e29e7bbdf70541e81b190836cd83ded", size = 16190242, upload-time = "2025-12-20T16:17:30.993Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/98/ce/ea207769aacad6246525ec6c6bbd66a2bf56c72443dc10e2f90feed29290/numpy-2.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e493962256a38f58283de033d8af176c5c91c084ea30f15834f7545451c42059", size = 18280875, upload-time = "2025-12-20T16:17:33.327Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/17/ef/ec409437aa962ea372ed601c519a2b141701683ff028f894b7466f0ab42b/numpy-2.4.0-cp314-cp314-win32.whl", hash = "sha256:6bbaebf0d11567fa8926215ae731e1d58e6ec28a8a25235b8a47405d301332db", size = 6002530, upload-time = "2025-12-20T16:17:35.729Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5f/4a/5cb94c787a3ed1ac65e1271b968686521169a7b3ec0b6544bb3ca32960b0/numpy-2.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:3d857f55e7fdf7c38ab96c4558c95b97d1c685be6b05c249f5fdafcbd6f9899e", size = 12435890, upload-time = "2025-12-20T16:17:37.599Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/48/a0/04b89db963af9de1104975e2544f30de89adbf75b9e75f7dd2599be12c79/numpy-2.4.0-cp314-cp314-win_arm64.whl", hash = "sha256:bb50ce5fb202a26fd5404620e7ef820ad1ab3558b444cb0b55beb7ef66cd2d63", size = 10591892, upload-time = "2025-12-20T16:17:39.649Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/53/e5/d74b5ccf6712c06c7a545025a6a71bfa03bdc7e0568b405b0d655232fd92/numpy-2.4.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:355354388cba60f2132df297e2d53053d4063f79077b67b481d21276d61fc4df", size = 12494312, upload-time = "2025-12-20T16:17:41.714Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c2/08/3ca9cc2ddf54dfee7ae9a6479c071092a228c68aef08252aa08dac2af002/numpy-2.4.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:1d8f9fde5f6dc1b6fc34df8162f3b3079365468703fee7f31d4e0cc8c63baed9", size = 5322862, upload-time = "2025-12-20T16:17:44.145Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/87/74/0bb63a68394c0c1e52670cfff2e309afa41edbe11b3327d9af29e4383f34/numpy-2.4.0-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:e0434aa22c821f44eeb4c650b81c7fbdd8c0122c6c4b5a576a76d5a35625ecd9", size = 6644986, upload-time = "2025-12-20T16:17:46.203Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/06/8f/9264d9bdbcf8236af2823623fe2f3981d740fc3461e2787e231d97c38c28/numpy-2.4.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:40483b2f2d3ba7aad426443767ff5632ec3156ef09742b96913787d13c336471", size = 14457958, upload-time = "2025-12-20T16:17:48.017Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8c/d9/f9a69ae564bbc7236a35aa883319364ef5fd41f72aa320cc1cbe66148fe2/numpy-2.4.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9e6a7664ddd9746e20b7325351fe1a8408d0a2bf9c63b5e898290ddc8f09544", size = 16398394, upload-time = "2025-12-20T16:17:50.409Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/34/c7/39241501408dde7f885d241a98caba5421061a2c6d2b2197ac5e3aa842d8/numpy-2.4.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ecb0019d44f4cdb50b676c5d0cb4b1eae8e15d1ed3d3e6639f986fc92b2ec52c", size = 16241044, upload-time = "2025-12-20T16:17:52.661Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7c/95/cae7effd90e065a95e59fe710eeee05d7328ed169776dfdd9f789e032125/numpy-2.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d0ffd9e2e4441c96a9c91ec1783285d80bf835b677853fc2770a89d50c1e48ac", size = 18321772, upload-time = "2025-12-20T16:17:54.947Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/96/df/3c6c279accd2bfb968a76298e5b276310bd55d243df4fa8ac5816d79347d/numpy-2.4.0-cp314-cp314t-win32.whl", hash = "sha256:77f0d13fa87036d7553bf81f0e1fe3ce68d14c9976c9851744e4d3e91127e95f", size = 6148320, upload-time = "2025-12-20T16:17:57.249Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/92/8d/f23033cce252e7a75cae853d17f582e86534c46404dea1c8ee094a9d6d84/numpy-2.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b1f5b45829ac1848893f0ddf5cb326110604d6df96cdc255b0bf9edd154104d4", size = 12623460, upload-time = "2025-12-20T16:17:58.963Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a4/4f/1f8475907d1a7c4ef9020edf7f39ea2422ec896849245f00688e4b268a71/numpy-2.4.0-cp314-cp314t-win_arm64.whl", hash = "sha256:23a3e9d1a6f360267e8fbb38ba5db355a6a7e9be71d7fce7ab3125e88bb646c8", size = 10661799, upload-time = "2025-12-20T16:18:01.078Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4b/ef/088e7c7342f300aaf3ee5f2c821c4b9996a1bef2aaf6a49cc8ab4883758e/numpy-2.4.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b54c83f1c0c0f1d748dca0af516062b8829d53d1f0c402be24b4257a9c48ada6", size = 16819003, upload-time = "2025-12-20T16:18:03.41Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ff/ce/a53017b5443b4b84517182d463fc7bcc2adb4faa8b20813f8e5f5aeb5faa/numpy-2.4.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:aabb081ca0ec5d39591fc33018cd4b3f96e1a2dd6756282029986d00a785fba4", size = 12567105, upload-time = "2025-12-20T16:18:05.594Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/77/58/5ff91b161f2ec650c88a626c3905d938c89aaadabd0431e6d9c1330c83e2/numpy-2.4.0-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:8eafe7c36c8430b7794edeab3087dec7bf31d634d92f2af9949434b9d1964cba", size = 5395590, upload-time = "2025-12-20T16:18:08.031Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/1d/4e/f1a084106df8c2df8132fc437e56987308e0524836aa7733721c8429d4fe/numpy-2.4.0-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:2f585f52b2baf07ff3356158d9268ea095e221371f1074fadea2f42544d58b4d", size = 6709947, upload-time = "2025-12-20T16:18:09.836Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/63/09/3d8aeb809c0332c3f642da812ac2e3d74fc9252b3021f8c30c82e99e3f3d/numpy-2.4.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:32ed06d0fe9cae27d8fb5f400c63ccee72370599c75e683a6358dd3a4fb50aaf", size = 14535119, upload-time = "2025-12-20T16:18:12.105Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fd/7f/68f0fc43a2cbdc6bb239160c754d87c922f60fbaa0fa3cd3d312b8a7f5ee/numpy-2.4.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:57c540ed8fb1f05cb997c6761cd56db72395b0d6985e90571ff660452ade4f98", size = 16475815, upload-time = "2025-12-20T16:18:14.433Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/11/73/edeacba3167b1ca66d51b1a5a14697c2c40098b5ffa01811c67b1785a5ab/numpy-2.4.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a39fb973a726e63223287adc6dafe444ce75af952d711e400f3bf2b36ef55a7b", size = 12489376, upload-time = "2025-12-20T16:18:16.524Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "packaging"
|
||||||
|
version = "25.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pluggy"
|
||||||
|
version = "1.6.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pygments"
|
||||||
|
version = "2.19.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest"
|
||||||
|
version = "9.0.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
{ name = "iniconfig" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
{ name = "pluggy" },
|
||||||
|
{ name = "pygments" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pyyaml"
|
||||||
|
version = "6.0.3"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6d/16/a95b6757765b7b031c9374925bb718d55e0a9ba8a1b6a12d25962ea44347/pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e", size = 185826, upload-time = "2025-09-25T21:31:58.655Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/16/19/13de8e4377ed53079ee996e1ab0a9c33ec2faf808a4647b7b4c0d46dd239/pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824", size = 175577, upload-time = "2025-09-25T21:32:00.088Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/0c/62/d2eb46264d4b157dae1275b573017abec435397aa59cbcdab6fc978a8af4/pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c", size = 775556, upload-time = "2025-09-25T21:32:01.31Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/10/cb/16c3f2cf3266edd25aaa00d6c4350381c8b012ed6f5276675b9eba8d9ff4/pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00", size = 882114, upload-time = "2025-09-25T21:32:03.376Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/71/60/917329f640924b18ff085ab889a11c763e0b573da888e8404ff486657602/pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d", size = 806638, upload-time = "2025-09-25T21:32:04.553Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/dd/6f/529b0f316a9fd167281a6c3826b5583e6192dba792dd55e3203d3f8e655a/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a", size = 767463, upload-time = "2025-09-25T21:32:06.152Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f2/6a/b627b4e0c1dd03718543519ffb2f1deea4a1e6d42fbab8021936a4d22589/pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4", size = 794986, upload-time = "2025-09-25T21:32:07.367Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/45/91/47a6e1c42d9ee337c4839208f30d9f09caa9f720ec7582917b264defc875/pyyaml-6.0.3-cp311-cp311-win32.whl", hash = "sha256:8098f252adfa6c80ab48096053f512f2321f0b998f98150cea9bd23d83e1467b", size = 142543, upload-time = "2025-09-25T21:32:08.95Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/da/e3/ea007450a105ae919a72393cb06f122f288ef60bba2dc64b26e2646fa315/pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf", size = 158763, upload-time = "2025-09-25T21:32:09.96Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9d/8c/f4bd7f6465179953d3ac9bc44ac1a8a3e6122cf8ada906b4f96c60172d43/pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac", size = 181814, upload-time = "2025-09-25T21:32:35.712Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bd/9c/4d95bb87eb2063d20db7b60faa3840c1b18025517ae857371c4dd55a6b3a/pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310", size = 173809, upload-time = "2025-09-25T21:32:36.789Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/92/b5/47e807c2623074914e29dabd16cbbdd4bf5e9b2db9f8090fa64411fc5382/pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7", size = 766454, upload-time = "2025-09-25T21:32:37.966Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/02/9e/e5e9b168be58564121efb3de6859c452fccde0ab093d8438905899a3a483/pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788", size = 836355, upload-time = "2025-09-25T21:32:39.178Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/88/f9/16491d7ed2a919954993e48aa941b200f38040928474c9e85ea9e64222c3/pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5", size = 794175, upload-time = "2025-09-25T21:32:40.865Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/dd/3f/5989debef34dc6397317802b527dbbafb2b4760878a53d4166579111411e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764", size = 755228, upload-time = "2025-09-25T21:32:42.084Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d7/ce/af88a49043cd2e265be63d083fc75b27b6ed062f5f9fd6cdc223ad62f03e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35", size = 789194, upload-time = "2025-09-25T21:32:43.362Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/23/20/bb6982b26a40bb43951265ba29d4c246ef0ff59c9fdcdf0ed04e0687de4d/pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac", size = 156429, upload-time = "2025-09-25T21:32:57.844Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f4/f4/a4541072bb9422c8a883ab55255f918fa378ecf083f5b85e87fc2b4eda1b/pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3", size = 143912, upload-time = "2025-09-25T21:32:59.247Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7c/f9/07dd09ae774e4616edf6cda684ee78f97777bdd15847253637a6f052a62f/pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3", size = 189108, upload-time = "2025-09-25T21:32:44.377Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4e/78/8d08c9fb7ce09ad8c38ad533c1191cf27f7ae1effe5bb9400a46d9437fcf/pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba", size = 183641, upload-time = "2025-09-25T21:32:45.407Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/7b/5b/3babb19104a46945cf816d047db2788bcaf8c94527a805610b0289a01c6b/pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c", size = 831901, upload-time = "2025-09-25T21:32:48.83Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8b/cc/dff0684d8dc44da4d22a13f35f073d558c268780ce3c6ba1b87055bb0b87/pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702", size = 861132, upload-time = "2025-09-25T21:32:50.149Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b1/5e/f77dc6b9036943e285ba76b49e118d9ea929885becb0a29ba8a7c75e29fe/pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c", size = 839261, upload-time = "2025-09-25T21:32:51.808Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ce/88/a9db1376aa2a228197c58b37302f284b5617f56a5d959fd1763fb1675ce6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065", size = 805272, upload-time = "2025-09-25T21:32:52.941Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f0/7a/1c7270340330e575b92f397352af856a8c06f230aa3e76f86b39d01b416a/pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9", size = 174062, upload-time = "2025-09-25T21:32:55.767Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "safetensors"
|
||||||
|
version = "0.7.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" },
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shellingham"
|
||||||
|
version = "1.5.4"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tqdm"
|
||||||
|
version = "4.67.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typer-slim"
|
||||||
|
version = "0.21.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "click" },
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/17/d4/064570dec6358aa9049d4708e4a10407d74c99258f8b2136bb8702303f1a/typer_slim-0.21.1.tar.gz", hash = "sha256:73495dd08c2d0940d611c5a8c04e91c2a0a98600cbd4ee19192255a233b6dbfd", size = 110478, upload-time = "2026-01-06T11:21:11.176Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/c8/0a/4aca634faf693e33004796b6cee0ae2e1dba375a800c16ab8d3eff4bb800/typer_slim-0.21.1-py3-none-any.whl", hash = "sha256:6e6c31047f171ac93cc5a973c9e617dbc5ab2bddc4d0a3135dc161b4e2020e0d", size = 47444, upload-time = "2026-01-06T11:21:12.441Z" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "typing-extensions"
|
||||||
|
version = "4.15.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user