Refactor weight loading and utility functions for LTX-2 model; remove deprecated weight loading file and update imports accordingly
This commit is contained in:
@@ -1633,7 +1633,7 @@ def generate_video(
|
||||
a2v_sr = None
|
||||
if is_a2v:
|
||||
from mlx_video.models.ltx_2.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from mlx_video.convert import convert_audio_encoder
|
||||
from mlx_video.models.ltx_2.utils import convert_audio_encoder
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioEncoder
|
||||
|
||||
with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"):
|
||||
|
||||
161
mlx_video/models/ltx_2/utils.py
Normal file
161
mlx_video/models/ltx_2/utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Shared utilities for LTX-2 model loading and conversion."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
def get_model_path(
|
||||
path_or_hf_repo: str,
|
||||
revision: Optional[str] = None,
|
||||
) -> Path:
|
||||
"""Get local path to model, downloading if necessary.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo: Local path or HuggingFace repo ID
|
||||
revision: Git revision for HF repo
|
||||
|
||||
Returns:
|
||||
Path to model directory
|
||||
"""
|
||||
model_path = Path(path_or_hf_repo)
|
||||
|
||||
if model_path.exists():
|
||||
return model_path
|
||||
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"*.safetensors",
|
||||
"*.json",
|
||||
"config.json",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
def load_safetensors(path: Path) -> Dict[str, mx.array]:
|
||||
"""Load weights from safetensors file(s) using MLX.
|
||||
|
||||
Args:
|
||||
path: Path to model directory or single safetensors file
|
||||
|
||||
Returns:
|
||||
Dictionary of weights
|
||||
"""
|
||||
if path.is_file():
|
||||
return mx.load(str(path))
|
||||
|
||||
weights = {}
|
||||
for sf_path in path.glob("*.safetensors"):
|
||||
weights.update(mx.load(str(sf_path)))
|
||||
return weights
|
||||
|
||||
|
||||
def load_config(model_path: Path) -> Dict[str, Any]:
|
||||
"""Load model configuration from config.json.
|
||||
|
||||
Args:
|
||||
model_path: Path to model directory
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
"""
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||
"""Save weights in safetensors format.
|
||||
|
||||
Args:
|
||||
path: Output directory
|
||||
weights: Dictionary of weights
|
||||
"""
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
mx.save_safetensors(str(path / "model.safetensors"), weights)
|
||||
|
||||
|
||||
def convert_audio_encoder(
|
||||
model_path,
|
||||
source_repo: str = "Lightricks/LTX-2",
|
||||
) -> Path:
|
||||
"""Convert and save audio encoder weights from original HF checkpoint.
|
||||
|
||||
Extracts encoder weights from the combined audio VAE safetensors,
|
||||
transposes Conv2d for MLX, and saves for AudioEncoder.from_pretrained().
|
||||
|
||||
Args:
|
||||
model_path: Local model directory (output location).
|
||||
source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.
|
||||
|
||||
Returns:
|
||||
Path to the audio_vae/encoder directory.
|
||||
"""
|
||||
model_path = Path(model_path)
|
||||
encoder_dir = model_path / "audio_vae" / "encoder"
|
||||
|
||||
if (encoder_dir / "model.safetensors").exists():
|
||||
return encoder_dir
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
vae_path = hf_hub_download(
|
||||
source_repo,
|
||||
"audio_vae/diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
|
||||
raw_weights = mx.load(vae_path)
|
||||
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioEncoder
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
|
||||
# Build config from the decoder config (same audio VAE architecture)
|
||||
decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json"
|
||||
if decoder_config_path.exists():
|
||||
with open(decoder_config_path) as f:
|
||||
dec_cfg = json.load(f)
|
||||
enc_config = {
|
||||
"ch": dec_cfg.get("ch", 128),
|
||||
"in_channels": dec_cfg.get("out_ch", 2),
|
||||
"ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
|
||||
"num_res_blocks": dec_cfg.get("num_res_blocks", 2),
|
||||
"attn_resolutions": dec_cfg.get("attn_resolutions", []),
|
||||
"resolution": dec_cfg.get("resolution", 256),
|
||||
"z_channels": dec_cfg.get("z_channels", 8),
|
||||
"double_z": True,
|
||||
"n_fft": 1024,
|
||||
"norm_type": dec_cfg.get("norm_type", "pixel"),
|
||||
"causality_axis": dec_cfg.get("causality_axis", "height"),
|
||||
"dropout": dec_cfg.get("dropout", 0.0),
|
||||
"mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
|
||||
"sample_rate": dec_cfg.get("sample_rate", 16000),
|
||||
"mel_hop_length": dec_cfg.get("mel_hop_length", 160),
|
||||
"is_causal": dec_cfg.get("is_causal", True),
|
||||
"mel_bins": dec_cfg.get("mel_bins", 64) or 64,
|
||||
"resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
|
||||
"attn_type": dec_cfg.get("attn_type", "vanilla"),
|
||||
}
|
||||
else:
|
||||
enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}
|
||||
|
||||
config = AudioEncoderModelConfig.from_dict(enc_config)
|
||||
encoder = AudioEncoder(config)
|
||||
sanitized = encoder.sanitize(raw_weights)
|
||||
|
||||
encoder_dir.mkdir(parents=True, exist_ok=True)
|
||||
mx.save_safetensors(str(encoder_dir / "model.safetensors"), sanitized)
|
||||
with open(encoder_dir / "config.json", "w") as f:
|
||||
json.dump(enc_config, f, indent=2)
|
||||
|
||||
print(f"Audio encoder weights saved to {encoder_dir}")
|
||||
return encoder_dir
|
||||
@@ -1,770 +0,0 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXModelType
|
||||
from mlx_video.models.ltx_2.ltx import LTXModel
|
||||
|
||||
|
||||
def get_model_path(
|
||||
path_or_hf_repo: str,
|
||||
revision: Optional[str] = None,
|
||||
) -> Path:
|
||||
"""Get local path to model, downloading if necessary.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo: Local path or HuggingFace repo ID
|
||||
revision: Git revision for HF repo
|
||||
|
||||
Returns:
|
||||
Path to model directory
|
||||
"""
|
||||
model_path = Path(path_or_hf_repo)
|
||||
|
||||
if model_path.exists():
|
||||
return model_path
|
||||
|
||||
# Download from HuggingFace
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"*.safetensors",
|
||||
"*.json",
|
||||
"config.json",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
def load_safetensors(path: Path) -> Dict[str, mx.array]:
|
||||
"""Load weights from safetensors file(s) using MLX.
|
||||
|
||||
Args:
|
||||
path: Path to model directory or single safetensors file
|
||||
|
||||
Returns:
|
||||
Dictionary of weights
|
||||
"""
|
||||
weights = {}
|
||||
|
||||
if path.is_file():
|
||||
# Single file - use mx.load directly (handles bfloat16)
|
||||
return mx.load(str(path))
|
||||
else:
|
||||
# Directory - load all safetensors files
|
||||
safetensor_files = list(path.glob("*.safetensors"))
|
||||
for sf_path in safetensor_files:
|
||||
file_weights = mx.load(str(sf_path))
|
||||
weights.update(file_weights)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Load transformer weights from LTX-2 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model directory
|
||||
|
||||
Returns:
|
||||
Dictionary of transformer weights
|
||||
"""
|
||||
# Try distilled model first, then dev
|
||||
weight_files = [
|
||||
model_path / "ltx-2-19b-distilled.safetensors",
|
||||
model_path / "ltx-2-19b-dev.safetensors",
|
||||
]
|
||||
|
||||
for weight_file in weight_files:
|
||||
if weight_file.exists():
|
||||
print(f"Loading transformer weights from {weight_file.name}...")
|
||||
return mx.load(str(weight_file))
|
||||
|
||||
raise FileNotFoundError(f"No transformer weights found in {model_path}")
|
||||
|
||||
|
||||
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Load VAE weights from LTX-2 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model directory
|
||||
|
||||
Returns:
|
||||
Dictionary of VAE weights
|
||||
"""
|
||||
vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
if vae_path.exists():
|
||||
print(f"Loading VAE weights from {vae_path}...")
|
||||
return mx.load(str(vae_path))
|
||||
|
||||
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
|
||||
|
||||
|
||||
def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Load audio VAE weights from LTX-2 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model directory
|
||||
|
||||
Returns:
|
||||
Dictionary of audio VAE weights
|
||||
"""
|
||||
# Try different possible paths for audio VAE weights
|
||||
audio_vae_paths = [
|
||||
model_path / "audio_vae" / "decoder" / "model.safetensors",
|
||||
model_path / "audio_vae" / "decoder" / "diffusion_pytorch_model.safetensors",
|
||||
model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
|
||||
model_path / "audio_vae.safetensors",
|
||||
]
|
||||
|
||||
# Also check in main model weights
|
||||
main_paths = [
|
||||
model_path / "ltx-2-19b-distilled.safetensors",
|
||||
model_path / "ltx-2-19b-dev.safetensors",
|
||||
]
|
||||
|
||||
for audio_path in audio_vae_paths:
|
||||
if audio_path.exists():
|
||||
print(f"Loading audio VAE weights from {audio_path}...")
|
||||
return mx.load(str(audio_path))
|
||||
|
||||
# Check main model weights for audio_vae keys
|
||||
for main_path in main_paths:
|
||||
if main_path.exists():
|
||||
print(f"Loading audio VAE weights from {main_path.name}...")
|
||||
all_weights = mx.load(str(main_path))
|
||||
# Filter to only audio_vae keys
|
||||
audio_weights = {k: v for k, v in all_weights.items() if "audio_vae" in k}
|
||||
if audio_weights:
|
||||
return audio_weights
|
||||
|
||||
raise FileNotFoundError(f"Audio VAE weights not found in {model_path}")
|
||||
|
||||
|
||||
def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]:
|
||||
"""Load vocoder weights from LTX-2 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model directory
|
||||
|
||||
Returns:
|
||||
Dictionary of vocoder weights
|
||||
"""
|
||||
# Try different possible paths for vocoder weights
|
||||
vocoder_paths = [
|
||||
model_path / "vocoder" / "diffusion_pytorch_model.safetensors",
|
||||
model_path / "vocoder.safetensors",
|
||||
]
|
||||
|
||||
# Also check in main model weights
|
||||
main_paths = [
|
||||
model_path / "ltx-2-19b-distilled.safetensors",
|
||||
model_path / "ltx-2-19b-dev.safetensors",
|
||||
]
|
||||
|
||||
for vocoder_path in vocoder_paths:
|
||||
if vocoder_path.exists():
|
||||
print(f"Loading vocoder weights from {vocoder_path}...")
|
||||
return mx.load(str(vocoder_path))
|
||||
|
||||
# Check main model weights for vocoder keys
|
||||
for main_path in main_paths:
|
||||
if main_path.exists():
|
||||
print(f"Loading vocoder weights from {main_path.name}...")
|
||||
all_weights = mx.load(str(main_path))
|
||||
# Filter to only vocoder keys
|
||||
vocoder_weights = {k: v for k, v in all_weights.items() if "vocoder" in k}
|
||||
if vocoder_weights:
|
||||
return vocoder_weights
|
||||
|
||||
raise FileNotFoundError(f"Vocoder weights not found in {model_path}")
|
||||
|
||||
|
||||
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for transformer
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
||||
if not key.startswith("model.diffusion_model."):
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
|
||||
# Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
|
||||
# Handle feed-forward net naming
|
||||
# PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar)
|
||||
# MLX FeedForward: uses proj_in, proj_out
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
|
||||
# Handle AdaLN naming - keep emb wrapper, just fix linear naming
|
||||
# PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
# Handle caption projection (keep linear1/linear2 naming for compatibility)
|
||||
# These are already mapped correctly in the sanitization
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize VAE weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Skip position_ids (not needed)
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Only process VAE decoder weights (skip audio_vae, etc.)
|
||||
if not key.startswith("vae."):
|
||||
continue
|
||||
|
||||
# Handle per-channel statistics key mapping
|
||||
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
|
||||
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
|
||||
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
|
||||
if "vae.per_channel_statistics" in key:
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
|
||||
continue
|
||||
elif key.startswith("vae.decoder."):
|
||||
# Strip the vae.decoder. prefix for decoder weights
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
else:
|
||||
# Skip other vae.* keys that are not decoder weights
|
||||
continue
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize VAE encoder weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for VAE encoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Skip position_ids (not needed)
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Only process VAE encoder weights
|
||||
if not key.startswith("vae."):
|
||||
continue
|
||||
|
||||
# Handle per-channel statistics key mapping
|
||||
if "vae.per_channel_statistics" in key:
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics._mean_of_means"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics._std_of_means"
|
||||
else:
|
||||
# Skip other per_channel_statistics keys
|
||||
continue
|
||||
elif key.startswith("vae.encoder."):
|
||||
# Strip the vae.encoder. prefix for encoder weights
|
||||
new_key = key.replace("vae.encoder.", "")
|
||||
else:
|
||||
# Skip other vae.* keys that are not encoder weights
|
||||
continue
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for audio VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
if "audio_vae." in weights:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle audio_vae.decoder weights
|
||||
if key.startswith("audio_vae.decoder."):
|
||||
new_key = key.replace("audio_vae.decoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
# Map per-channel statistics
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
else:
|
||||
continue # Skip non-decoder keys
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize vocoder weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for vocoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle vocoder weights
|
||||
if key.startswith("vocoder."):
|
||||
new_key = key.replace("vocoder.", "")
|
||||
|
||||
# Handle ModuleList indices -> dict keys
|
||||
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
|
||||
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
|
||||
|
||||
# Handle Conv1d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, kernel)
|
||||
# MLX: (out_channels, kernel, in_channels)
|
||||
if "weight" in new_key and value.ndim == 3:
|
||||
if "ups" in new_key:
|
||||
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = mx.transpose(value, (1, 2, 0))
|
||||
else:
|
||||
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = mx.transpose(value, (0, 2, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize weight names from PyTorch format to MLX format.
|
||||
|
||||
Generic function that handles both transformer and VAE weights.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Skip position_ids (not needed)
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Handle transformer weights
|
||||
if key.startswith("model.diffusion_model."):
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def load_config(model_path: Path) -> Dict[str, Any]:
|
||||
"""Load model configuration.
|
||||
|
||||
Args:
|
||||
model_path: Path to model directory
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
"""
|
||||
config_path = model_path / "config.json"
|
||||
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
# Return default config
|
||||
return {}
|
||||
|
||||
|
||||
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
|
||||
"""Create model instance from configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
LTXModel instance
|
||||
"""
|
||||
# Map config to LTXModelConfig
|
||||
model_config = LTXModelConfig(
|
||||
model_type=LTXModelType.AudioVideo,
|
||||
num_attention_heads=config.get("num_attention_heads", 32),
|
||||
attention_head_dim=config.get("attention_head_dim", 128),
|
||||
in_channels=config.get("in_channels", 128),
|
||||
out_channels=config.get("out_channels", 128),
|
||||
num_layers=config.get("num_layers", 48),
|
||||
cross_attention_dim=config.get("cross_attention_dim", 4096),
|
||||
caption_channels=config.get("caption_channels", 3840),
|
||||
audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
|
||||
audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
|
||||
audio_in_channels=config.get("audio_in_channels", 128),
|
||||
audio_out_channels=config.get("audio_out_channels", 128),
|
||||
audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
|
||||
positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
|
||||
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
|
||||
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
|
||||
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
|
||||
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1000),
|
||||
norm_eps=config.get("norm_eps", 1e-6),
|
||||
)
|
||||
|
||||
return LTXModel(model_config)
|
||||
|
||||
|
||||
def convert(
|
||||
hf_path: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
dtype: Optional[str] = None,
|
||||
quantize: bool = False,
|
||||
q_bits: int = 4,
|
||||
q_group_size: int = 64,
|
||||
) -> Path:
|
||||
"""Convert HuggingFace model to MLX format.
|
||||
|
||||
Args:
|
||||
hf_path: HuggingFace model path or repo ID
|
||||
mlx_path: Output path for MLX model
|
||||
dtype: Target dtype (float16, float32, bfloat16)
|
||||
quantize: Whether to quantize the model
|
||||
q_bits: Quantization bits
|
||||
q_group_size: Quantization group size
|
||||
|
||||
Returns:
|
||||
Path to converted model
|
||||
"""
|
||||
print(f"Loading model from {hf_path}...")
|
||||
model_path = get_model_path(hf_path)
|
||||
|
||||
# Load config
|
||||
config = load_config(model_path)
|
||||
|
||||
# Load weights
|
||||
print("Loading weights...")
|
||||
weights = load_safetensors(model_path)
|
||||
|
||||
# Sanitize weights
|
||||
print("Sanitizing weights...")
|
||||
weights = sanitize_weights(weights)
|
||||
|
||||
# Convert dtype if specified
|
||||
if dtype is not None:
|
||||
dtype_map = {
|
||||
"float16": mx.float16,
|
||||
"float32": mx.float32,
|
||||
"bfloat16": mx.bfloat16,
|
||||
}
|
||||
target_dtype = dtype_map.get(dtype, mx.float16)
|
||||
print(f"Converting to {dtype}...")
|
||||
weights = {
|
||||
k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v
|
||||
for k, v in weights.items()
|
||||
}
|
||||
|
||||
# Create output directory
|
||||
output_path = Path(mlx_path)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save weights
|
||||
print(f"Saving weights to {output_path}...")
|
||||
save_weights(output_path, weights)
|
||||
|
||||
# Save config
|
||||
config_out_path = output_path / "config.json"
|
||||
with open(config_out_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Model converted successfully to {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||
"""Save weights in safetensors format.
|
||||
|
||||
Uses mx.save_safetensors to preserve exact dtype (especially bfloat16).
|
||||
Converting through numpy loses bfloat16 fidelity since numpy lacks native
|
||||
bfloat16 support.
|
||||
|
||||
Args:
|
||||
path: Output directory
|
||||
weights: Dictionary of weights
|
||||
"""
|
||||
mx.save_safetensors(str(path / "model.safetensors"), weights)
|
||||
|
||||
|
||||
def convert_audio_encoder(
|
||||
model_path: Union[str, Path],
|
||||
source_repo: str = "Lightricks/LTX-2",
|
||||
) -> Path:
|
||||
"""Convert and save audio encoder weights from original HF checkpoint.
|
||||
|
||||
The audio VAE safetensors in the HF repo contains both encoder and decoder
|
||||
weights. This extracts encoder weights, transposes Conv2d for MLX, and saves
|
||||
them to a separate directory for AudioEncoder.from_pretrained().
|
||||
|
||||
Args:
|
||||
model_path: Local model directory (output location).
|
||||
source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.
|
||||
|
||||
Returns:
|
||||
Path to the audio_vae/encoder directory.
|
||||
"""
|
||||
model_path = Path(model_path)
|
||||
encoder_dir = model_path / "audio_vae" / "encoder"
|
||||
|
||||
if (encoder_dir / "model.safetensors").exists():
|
||||
return encoder_dir
|
||||
|
||||
# Download original audio VAE weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
vae_path = hf_hub_download(
|
||||
source_repo,
|
||||
"audio_vae/diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
|
||||
raw_weights = mx.load(vae_path)
|
||||
|
||||
# Extract encoder weights and per-channel statistics
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioEncoder
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
|
||||
# Build config from the decoder config (same audio VAE architecture)
|
||||
decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json"
|
||||
if decoder_config_path.exists():
|
||||
with open(decoder_config_path) as f:
|
||||
dec_cfg = json.load(f)
|
||||
enc_config = {
|
||||
"ch": dec_cfg.get("ch", 128),
|
||||
"in_channels": dec_cfg.get("out_ch", 2),
|
||||
"ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
|
||||
"num_res_blocks": dec_cfg.get("num_res_blocks", 2),
|
||||
"attn_resolutions": dec_cfg.get("attn_resolutions", []),
|
||||
"resolution": dec_cfg.get("resolution", 256),
|
||||
"z_channels": dec_cfg.get("z_channels", 8),
|
||||
"double_z": True,
|
||||
"n_fft": 1024,
|
||||
"norm_type": dec_cfg.get("norm_type", "pixel"),
|
||||
"causality_axis": dec_cfg.get("causality_axis", "height"),
|
||||
"dropout": dec_cfg.get("dropout", 0.0),
|
||||
"mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
|
||||
"sample_rate": dec_cfg.get("sample_rate", 16000),
|
||||
"mel_hop_length": dec_cfg.get("mel_hop_length", 160),
|
||||
"is_causal": dec_cfg.get("is_causal", True),
|
||||
"mel_bins": dec_cfg.get("mel_bins", 64) or 64,
|
||||
"resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
|
||||
"attn_type": dec_cfg.get("attn_type", "vanilla"),
|
||||
}
|
||||
else:
|
||||
enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}
|
||||
|
||||
# Sanitize weights
|
||||
config = AudioEncoderModelConfig.from_dict(enc_config)
|
||||
encoder = AudioEncoder(config)
|
||||
sanitized = encoder.sanitize(raw_weights)
|
||||
|
||||
# Save
|
||||
encoder_dir.mkdir(parents=True, exist_ok=True)
|
||||
mx.save_safetensors(str(encoder_dir / "model.safetensors"), sanitized)
|
||||
with open(encoder_dir / "config.json", "w") as f:
|
||||
json.dump(enc_config, f, indent=2)
|
||||
|
||||
print(f"Audio encoder weights saved to {encoder_dir}")
|
||||
return encoder_dir
|
||||
|
||||
|
||||
def load_model(
|
||||
path_or_hf_repo: str,
|
||||
lazy: bool = False,
|
||||
) -> LTXModel:
|
||||
"""Load LTX model from path or HuggingFace.
|
||||
|
||||
Args:
|
||||
path_or_hf_repo: Path to model or HuggingFace repo ID
|
||||
lazy: Whether to use lazy loading
|
||||
|
||||
Returns:
|
||||
Loaded LTXModel
|
||||
"""
|
||||
model_path = get_model_path(path_or_hf_repo)
|
||||
|
||||
# Load config
|
||||
config = load_config(model_path)
|
||||
|
||||
# Create model
|
||||
model = create_model_from_config(config)
|
||||
|
||||
# Load weights
|
||||
weights = load_safetensors(model_path)
|
||||
|
||||
# Sanitize if needed
|
||||
weights = sanitize_weights(weights)
|
||||
|
||||
# Load weights into model
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
if not lazy:
|
||||
mx.eval(model.parameters())
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
default="Lightricks/LTX-2",
|
||||
help="HuggingFace model path or repo ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="Output path for MLX model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["float16", "float32", "bfloat16"],
|
||||
default="float16",
|
||||
help="Target dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
action="store_true",
|
||||
help="Quantize the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Quantization bits",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert(
|
||||
hf_path=args.hf_path,
|
||||
mlx_path=args.mlx_path,
|
||||
dtype=args.dtype,
|
||||
quantize=args.quantize,
|
||||
q_bits=args.q_bits,
|
||||
)
|
||||
Reference in New Issue
Block a user