initial commit (LTX-2)
This commit is contained in:
457
mlx_video/convert.py
Normal file
457
mlx_video/convert.py
Normal file
@@ -0,0 +1,457 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
|
||||
|
||||
def get_model_path(
|
||||
path_or_hf_repo: str,
|
||||
revision: Optional[str] = None,
|
||||
) -> Path:
|
||||
"""Get local path to model, downloading if necessary.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo: Local path or HuggingFace repo ID
|
||||
revision: Git revision for HF repo
|
||||
|
||||
Returns:
|
||||
Path to model directory
|
||||
"""
|
||||
model_path = Path(path_or_hf_repo)
|
||||
|
||||
if model_path.exists():
|
||||
return model_path
|
||||
|
||||
# Download from HuggingFace
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"*.safetensors",
|
||||
"*.json",
|
||||
"config.json",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
def load_safetensors(path: Path) -> Dict[str, mx.array]:
|
||||
"""Load weights from safetensors file(s) using MLX.
|
||||
|
||||
Args:
|
||||
path: Path to model directory or single safetensors file
|
||||
|
||||
Returns:
|
||||
Dictionary of weights
|
||||
"""
|
||||
weights = {}
|
||||
|
||||
if path.is_file():
|
||||
# Single file - use mx.load directly (handles bfloat16)
|
||||
return mx.load(str(path))
|
||||
else:
|
||||
# Directory - load all safetensors files
|
||||
safetensor_files = list(path.glob("*.safetensors"))
|
||||
for sf_path in safetensor_files:
|
||||
file_weights = mx.load(str(sf_path))
|
||||
weights.update(file_weights)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Load transformer weights from LTX-2 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model directory
|
||||
|
||||
Returns:
|
||||
Dictionary of transformer weights
|
||||
"""
|
||||
# Try distilled model first, then dev
|
||||
weight_files = [
|
||||
model_path / "ltx-2-19b-distilled.safetensors",
|
||||
model_path / "ltx-2-19b-dev.safetensors",
|
||||
]
|
||||
|
||||
for weight_file in weight_files:
|
||||
if weight_file.exists():
|
||||
print(f"Loading transformer weights from {weight_file.name}...")
|
||||
return mx.load(str(weight_file))
|
||||
|
||||
raise FileNotFoundError(f"No transformer weights found in {model_path}")
|
||||
|
||||
|
||||
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Load VAE weights from LTX-2 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model directory
|
||||
|
||||
Returns:
|
||||
Dictionary of VAE weights
|
||||
"""
|
||||
vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
if vae_path.exists():
|
||||
print(f"Loading VAE weights from {vae_path}...")
|
||||
return mx.load(str(vae_path))
|
||||
|
||||
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
|
||||
|
||||
|
||||
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for transformer
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
||||
if not key.startswith("model.diffusion_model."):
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
|
||||
# Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
|
||||
# Handle feed-forward net naming
|
||||
# PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar)
|
||||
# MLX FeedForward: uses proj_in, proj_out
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
|
||||
# Handle AdaLN naming - keep emb wrapper, just fix linear naming
|
||||
# PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
# Handle caption projection (keep linear1/linear2 naming for compatibility)
|
||||
# These are already mapped correctly in the sanitization
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize VAE weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for VAE
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Skip position_ids (not needed)
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize weight names from PyTorch format to MLX format.
|
||||
|
||||
Generic function that handles both transformer and VAE weights.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Skip position_ids (not needed)
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Handle transformer weights
|
||||
if key.startswith("model.diffusion_model."):
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def load_config(model_path: Path) -> Dict[str, Any]:
|
||||
"""Load model configuration.
|
||||
|
||||
Args:
|
||||
model_path: Path to model directory
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
"""
|
||||
config_path = model_path / "config.json"
|
||||
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
# Return default config
|
||||
return {}
|
||||
|
||||
|
||||
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
|
||||
"""Create model instance from configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
LTXModel instance
|
||||
"""
|
||||
# Map config to LTXModelConfig
|
||||
model_config = LTXModelConfig(
|
||||
model_type=LTXModelType.AudioVideo,
|
||||
num_attention_heads=config.get("num_attention_heads", 32),
|
||||
attention_head_dim=config.get("attention_head_dim", 128),
|
||||
in_channels=config.get("in_channels", 128),
|
||||
out_channels=config.get("out_channels", 128),
|
||||
num_layers=config.get("num_layers", 48),
|
||||
cross_attention_dim=config.get("cross_attention_dim", 4096),
|
||||
caption_channels=config.get("caption_channels", 3840),
|
||||
audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
|
||||
audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
|
||||
audio_in_channels=config.get("audio_in_channels", 128),
|
||||
audio_out_channels=config.get("audio_out_channels", 128),
|
||||
audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
|
||||
positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
|
||||
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
|
||||
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
|
||||
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
|
||||
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
|
||||
norm_eps=config.get("norm_eps", 1e-6),
|
||||
)
|
||||
|
||||
return LTXModel(model_config)
|
||||
|
||||
|
||||
def convert(
|
||||
hf_path: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
dtype: Optional[str] = None,
|
||||
quantize: bool = False,
|
||||
q_bits: int = 4,
|
||||
q_group_size: int = 64,
|
||||
) -> Path:
|
||||
"""Convert HuggingFace model to MLX format.
|
||||
|
||||
Args:
|
||||
hf_path: HuggingFace model path or repo ID
|
||||
mlx_path: Output path for MLX model
|
||||
dtype: Target dtype (float16, float32, bfloat16)
|
||||
quantize: Whether to quantize the model
|
||||
q_bits: Quantization bits
|
||||
q_group_size: Quantization group size
|
||||
|
||||
Returns:
|
||||
Path to converted model
|
||||
"""
|
||||
print(f"Loading model from {hf_path}...")
|
||||
model_path = get_model_path(hf_path)
|
||||
|
||||
# Load config
|
||||
config = load_config(model_path)
|
||||
|
||||
# Load weights
|
||||
print("Loading weights...")
|
||||
weights = load_safetensors(model_path)
|
||||
|
||||
# Sanitize weights
|
||||
print("Sanitizing weights...")
|
||||
weights = sanitize_weights(weights)
|
||||
|
||||
# Convert dtype if specified
|
||||
if dtype is not None:
|
||||
dtype_map = {
|
||||
"float16": mx.float16,
|
||||
"float32": mx.float32,
|
||||
"bfloat16": mx.bfloat16,
|
||||
}
|
||||
target_dtype = dtype_map.get(dtype, mx.float16)
|
||||
print(f"Converting to {dtype}...")
|
||||
weights = {
|
||||
k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v
|
||||
for k, v in weights.items()
|
||||
}
|
||||
|
||||
# Create output directory
|
||||
output_path = Path(mlx_path)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save weights
|
||||
print(f"Saving weights to {output_path}...")
|
||||
save_weights(output_path, weights)
|
||||
|
||||
# Save config
|
||||
config_out_path = output_path / "config.json"
|
||||
with open(config_out_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Model converted successfully to {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||
"""Save weights in safetensors format.
|
||||
|
||||
Args:
|
||||
path: Output directory
|
||||
weights: Dictionary of weights
|
||||
"""
|
||||
from safetensors.numpy import save_file
|
||||
import numpy as np
|
||||
|
||||
# Convert to numpy for safetensors
|
||||
np_weights = {k: np.array(v) for k, v in weights.items()}
|
||||
|
||||
# Save to file
|
||||
save_file(np_weights, path / "model.safetensors")
|
||||
|
||||
|
||||
def load_model(
|
||||
path_or_hf_repo: str,
|
||||
lazy: bool = False,
|
||||
) -> LTXModel:
|
||||
"""Load LTX model from path or HuggingFace.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo: Path to model or HuggingFace repo ID
|
||||
lazy: Whether to use lazy loading
|
||||
|
||||
Returns:
|
||||
Loaded LTXModel
|
||||
"""
|
||||
model_path = get_model_path(path_or_hf_repo)
|
||||
|
||||
# Load config
|
||||
config = load_config(model_path)
|
||||
|
||||
# Create model
|
||||
model = create_model_from_config(config)
|
||||
|
||||
# Load weights
|
||||
weights = load_safetensors(model_path)
|
||||
|
||||
# Sanitize if needed
|
||||
weights = sanitize_weights(weights)
|
||||
|
||||
# Load weights into model
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
if not lazy:
|
||||
mx.eval(model.parameters())
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
default="Lightricks/LTX-2",
|
||||
help="HuggingFace model path or repo ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="Output path for MLX model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["float16", "float32", "bfloat16"],
|
||||
default="float16",
|
||||
help="Target dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
action="store_true",
|
||||
help="Quantize the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Quantization bits",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert(
|
||||
hf_path=args.hf_path,
|
||||
mlx_path=args.mlx_path,
|
||||
dtype=args.dtype,
|
||||
quantize=args.quantize,
|
||||
q_bits=args.q_bits,
|
||||
)
|
||||
Reference in New Issue
Block a user