458 lines
13 KiB
Python
458 lines
13 KiB
Python
import json
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional, Union
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
|
from mlx_video.models.ltx.ltx import LTXModel
|
|
|
|
|
|
def get_model_path(
|
|
path_or_hf_repo: str,
|
|
revision: Optional[str] = None,
|
|
) -> Path:
|
|
"""Get local path to model, downloading if necessary.
|
|
|
|
Args:
|
|
path_or_hf_repo: Local path or HuggingFace repo ID
|
|
revision: Git revision for HF repo
|
|
|
|
Returns:
|
|
Path to model directory
|
|
"""
|
|
model_path = Path(path_or_hf_repo)
|
|
|
|
if model_path.exists():
|
|
return model_path
|
|
|
|
# Download from HuggingFace
|
|
model_path = Path(
|
|
snapshot_download(
|
|
repo_id=path_or_hf_repo,
|
|
revision=revision,
|
|
allow_patterns=[
|
|
"*.safetensors",
|
|
"*.json",
|
|
"config.json",
|
|
],
|
|
)
|
|
)
|
|
|
|
return model_path
|
|
|
|
|
|
def load_safetensors(path: Path) -> Dict[str, mx.array]:
|
|
"""Load weights from safetensors file(s) using MLX.
|
|
|
|
Args:
|
|
path: Path to model directory or single safetensors file
|
|
|
|
Returns:
|
|
Dictionary of weights
|
|
"""
|
|
weights = {}
|
|
|
|
if path.is_file():
|
|
# Single file - use mx.load directly (handles bfloat16)
|
|
return mx.load(str(path))
|
|
else:
|
|
# Directory - load all safetensors files
|
|
safetensor_files = list(path.glob("*.safetensors"))
|
|
for sf_path in safetensor_files:
|
|
file_weights = mx.load(str(sf_path))
|
|
weights.update(file_weights)
|
|
|
|
return weights
|
|
|
|
|
|
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
|
|
"""Load transformer weights from LTX-2 model.
|
|
|
|
Args:
|
|
model_path: Path to LTX-2 model directory
|
|
|
|
Returns:
|
|
Dictionary of transformer weights
|
|
"""
|
|
# Try distilled model first, then dev
|
|
weight_files = [
|
|
model_path / "ltx-2-19b-distilled.safetensors",
|
|
model_path / "ltx-2-19b-dev.safetensors",
|
|
]
|
|
|
|
for weight_file in weight_files:
|
|
if weight_file.exists():
|
|
print(f"Loading transformer weights from {weight_file.name}...")
|
|
return mx.load(str(weight_file))
|
|
|
|
raise FileNotFoundError(f"No transformer weights found in {model_path}")
|
|
|
|
|
|
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
|
"""Load VAE weights from LTX-2 model.
|
|
|
|
Args:
|
|
model_path: Path to LTX-2 model directory
|
|
|
|
Returns:
|
|
Dictionary of VAE weights
|
|
"""
|
|
vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
|
if vae_path.exists():
|
|
print(f"Loading VAE weights from {vae_path}...")
|
|
return mx.load(str(vae_path))
|
|
|
|
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
|
|
|
|
|
|
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
|
|
|
|
Args:
|
|
weights: Dictionary of weights with PyTorch naming
|
|
|
|
Returns:
|
|
Dictionary with MLX-compatible naming for transformer
|
|
"""
|
|
sanitized = {}
|
|
|
|
for key, value in weights.items():
|
|
new_key = key
|
|
|
|
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
|
if not key.startswith("model.diffusion_model."):
|
|
continue
|
|
|
|
# Remove 'model.diffusion_model.' prefix
|
|
new_key = key.replace("model.diffusion_model.", "")
|
|
|
|
# Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering)
|
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
|
|
|
# Handle feed-forward net naming
|
|
# PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar)
|
|
# MLX FeedForward: uses proj_in, proj_out
|
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
|
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
|
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
|
|
|
# Handle AdaLN naming - keep emb wrapper, just fix linear naming
|
|
# PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1
|
|
new_key = new_key.replace(".linear_1.", ".linear1.")
|
|
new_key = new_key.replace(".linear_2.", ".linear2.")
|
|
|
|
# Handle caption projection (keep linear1/linear2 naming for compatibility)
|
|
# These are already mapped correctly in the sanitization
|
|
|
|
sanitized[new_key] = value
|
|
|
|
return sanitized
|
|
|
|
|
|
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|
"""Sanitize VAE weight names from PyTorch format to MLX format.
|
|
|
|
Args:
|
|
weights: Dictionary of weights with PyTorch naming
|
|
|
|
Returns:
|
|
Dictionary with MLX-compatible naming for VAE
|
|
"""
|
|
sanitized = {}
|
|
|
|
for key, value in weights.items():
|
|
new_key = key
|
|
|
|
# Skip position_ids (not needed)
|
|
if "position_ids" in key:
|
|
continue
|
|
|
|
# Handle Conv3d weight shape conversion
|
|
# PyTorch: (out_channels, in_channels, D, H, W)
|
|
# MLX: (out_channels, D, H, W, in_channels)
|
|
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
|
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
|
|
|
# Handle Conv2d weight shape conversion
|
|
# PyTorch: (out_channels, in_channels, H, W)
|
|
# MLX: (out_channels, H, W, in_channels)
|
|
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
|
value = mx.transpose(value, (0, 2, 3, 1))
|
|
|
|
sanitized[new_key] = value
|
|
|
|
return sanitized
|
|
|
|
|
|
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|
"""Sanitize weight names from PyTorch format to MLX format.
|
|
|
|
Generic function that handles both transformer and VAE weights.
|
|
|
|
Args:
|
|
weights: Dictionary of weights with PyTorch naming
|
|
|
|
Returns:
|
|
Dictionary with MLX-compatible naming
|
|
"""
|
|
sanitized = {}
|
|
|
|
for key, value in weights.items():
|
|
new_key = key
|
|
|
|
# Skip position_ids (not needed)
|
|
if "position_ids" in key:
|
|
continue
|
|
|
|
# Handle transformer weights
|
|
if key.startswith("model.diffusion_model."):
|
|
new_key = key.replace("model.diffusion_model.", "")
|
|
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
|
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
|
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
|
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
|
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
|
new_key = new_key.replace(".linear_1.", ".linear1.")
|
|
new_key = new_key.replace(".linear_2.", ".linear2.")
|
|
|
|
# Handle Conv3d weight shape conversion
|
|
# PyTorch: (out_channels, in_channels, D, H, W)
|
|
# MLX: (out_channels, D, H, W, in_channels)
|
|
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
|
|
|
# Handle Conv2d weight shape conversion
|
|
# PyTorch: (out_channels, in_channels, H, W)
|
|
# MLX: (out_channels, H, W, in_channels)
|
|
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
|
value = mx.transpose(value, (0, 2, 3, 1))
|
|
|
|
sanitized[new_key] = value
|
|
|
|
return sanitized
|
|
|
|
|
|
def load_config(model_path: Path) -> Dict[str, Any]:
|
|
"""Load model configuration.
|
|
|
|
Args:
|
|
model_path: Path to model directory
|
|
|
|
Returns:
|
|
Configuration dictionary
|
|
"""
|
|
config_path = model_path / "config.json"
|
|
|
|
if config_path.exists():
|
|
with open(config_path, "r") as f:
|
|
return json.load(f)
|
|
|
|
# Return default config
|
|
return {}
|
|
|
|
|
|
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
|
|
"""Create model instance from configuration.
|
|
|
|
Args:
|
|
config: Configuration dictionary
|
|
|
|
Returns:
|
|
LTXModel instance
|
|
"""
|
|
# Map config to LTXModelConfig
|
|
model_config = LTXModelConfig(
|
|
model_type=LTXModelType.AudioVideo,
|
|
num_attention_heads=config.get("num_attention_heads", 32),
|
|
attention_head_dim=config.get("attention_head_dim", 128),
|
|
in_channels=config.get("in_channels", 128),
|
|
out_channels=config.get("out_channels", 128),
|
|
num_layers=config.get("num_layers", 48),
|
|
cross_attention_dim=config.get("cross_attention_dim", 4096),
|
|
caption_channels=config.get("caption_channels", 3840),
|
|
audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
|
|
audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
|
|
audio_in_channels=config.get("audio_in_channels", 128),
|
|
audio_out_channels=config.get("audio_out_channels", 128),
|
|
audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
|
|
positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
|
|
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
|
|
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
|
|
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
|
|
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
|
|
norm_eps=config.get("norm_eps", 1e-6),
|
|
)
|
|
|
|
return LTXModel(model_config)
|
|
|
|
|
|
def convert(
|
|
hf_path: str,
|
|
mlx_path: str = "mlx_model",
|
|
dtype: Optional[str] = None,
|
|
quantize: bool = False,
|
|
q_bits: int = 4,
|
|
q_group_size: int = 64,
|
|
) -> Path:
|
|
"""Convert HuggingFace model to MLX format.
|
|
|
|
Args:
|
|
hf_path: HuggingFace model path or repo ID
|
|
mlx_path: Output path for MLX model
|
|
dtype: Target dtype (float16, float32, bfloat16)
|
|
quantize: Whether to quantize the model
|
|
q_bits: Quantization bits
|
|
q_group_size: Quantization group size
|
|
|
|
Returns:
|
|
Path to converted model
|
|
"""
|
|
print(f"Loading model from {hf_path}...")
|
|
model_path = get_model_path(hf_path)
|
|
|
|
# Load config
|
|
config = load_config(model_path)
|
|
|
|
# Load weights
|
|
print("Loading weights...")
|
|
weights = load_safetensors(model_path)
|
|
|
|
# Sanitize weights
|
|
print("Sanitizing weights...")
|
|
weights = sanitize_weights(weights)
|
|
|
|
# Convert dtype if specified
|
|
if dtype is not None:
|
|
dtype_map = {
|
|
"float16": mx.float16,
|
|
"float32": mx.float32,
|
|
"bfloat16": mx.bfloat16,
|
|
}
|
|
target_dtype = dtype_map.get(dtype, mx.float16)
|
|
print(f"Converting to {dtype}...")
|
|
weights = {
|
|
k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v
|
|
for k, v in weights.items()
|
|
}
|
|
|
|
# Create output directory
|
|
output_path = Path(mlx_path)
|
|
output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Save weights
|
|
print(f"Saving weights to {output_path}...")
|
|
save_weights(output_path, weights)
|
|
|
|
# Save config
|
|
config_out_path = output_path / "config.json"
|
|
with open(config_out_path, "w") as f:
|
|
json.dump(config, f, indent=2)
|
|
|
|
print(f"Model converted successfully to {output_path}")
|
|
return output_path
|
|
|
|
|
|
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
|
"""Save weights in safetensors format.
|
|
|
|
Args:
|
|
path: Output directory
|
|
weights: Dictionary of weights
|
|
"""
|
|
from safetensors.numpy import save_file
|
|
import numpy as np
|
|
|
|
# Convert to numpy for safetensors
|
|
np_weights = {k: np.array(v) for k, v in weights.items()}
|
|
|
|
# Save to file
|
|
save_file(np_weights, path / "model.safetensors")
|
|
|
|
|
|
def load_model(
|
|
path_or_hf_repo: str,
|
|
lazy: bool = False,
|
|
) -> LTXModel:
|
|
"""Load LTX model from path or HuggingFace.
|
|
|
|
Args:
|
|
path_or_hf_repo: Path to model or HuggingFace repo ID
|
|
lazy: Whether to use lazy loading
|
|
|
|
Returns:
|
|
Loaded LTXModel
|
|
"""
|
|
model_path = get_model_path(path_or_hf_repo)
|
|
|
|
# Load config
|
|
config = load_config(model_path)
|
|
|
|
# Create model
|
|
model = create_model_from_config(config)
|
|
|
|
# Load weights
|
|
weights = load_safetensors(model_path)
|
|
|
|
# Sanitize if needed
|
|
weights = sanitize_weights(weights)
|
|
|
|
# Load weights into model
|
|
model.load_weights(list(weights.items()))
|
|
|
|
if not lazy:
|
|
mx.eval(model.parameters())
|
|
|
|
return model
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
|
|
parser.add_argument(
|
|
"--hf-path",
|
|
type=str,
|
|
default="Lightricks/LTX-2",
|
|
help="HuggingFace model path or repo ID",
|
|
)
|
|
parser.add_argument(
|
|
"--mlx-path",
|
|
type=str,
|
|
default="mlx_model",
|
|
help="Output path for MLX model",
|
|
)
|
|
parser.add_argument(
|
|
"--dtype",
|
|
type=str,
|
|
choices=["float16", "float32", "bfloat16"],
|
|
default="float16",
|
|
help="Target dtype",
|
|
)
|
|
parser.add_argument(
|
|
"--quantize",
|
|
action="store_true",
|
|
help="Quantize the model",
|
|
)
|
|
parser.add_argument(
|
|
"--q-bits",
|
|
type=int,
|
|
default=4,
|
|
help="Quantization bits",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
convert(
|
|
hf_path=args.hf_path,
|
|
mlx_path=args.mlx_path,
|
|
dtype=args.dtype,
|
|
quantize=args.quantize,
|
|
q_bits=args.q_bits,
|
|
)
|