Refactor weight loading and utility functions for LTX-2 model; remove deprecated weight loading file and update imports accordingly

This commit is contained in:
Prince Canuma
2026-03-16 22:25:22 +01:00
parent dd573d53d2
commit 7a576bfbf4
5 changed files with 182 additions and 789 deletions

View File

@@ -1,16 +1,9 @@
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.convert import (
load_transformer_weights,
load_vae_weights,
load_audio_vae_weights,
load_vocoder_weights,
sanitize_audio_vae_weights,
sanitize_vocoder_weights,
)
# Audio VAE components # Audio VAE components
from mlx_video.models.ltx_2.audio_vae import ( from mlx_video.models.ltx_2.audio_vae import (
AudioDecoder, AudioDecoder,
AudioEncoder,
Vocoder, Vocoder,
decode_audio, decode_audio,
AudioPatchifier, AudioPatchifier,
@@ -23,19 +16,22 @@ from mlx_video.models.ltx_2.conditioning import (
VideoConditionByLatentIndex, VideoConditionByLatentIndex,
) )
# Utilities
from mlx_video.models.ltx_2.utils import (
convert_audio_encoder,
get_model_path,
load_safetensors,
load_config,
save_weights,
)
__all__ = [ __all__ = [
# Models # Models
"LTXModel", "LTXModel",
"LTXModelConfig", "LTXModelConfig",
# Weight loading
"load_transformer_weights",
"load_vae_weights",
"load_audio_vae_weights",
"load_vocoder_weights",
"sanitize_audio_vae_weights",
"sanitize_vocoder_weights",
# Audio VAE # Audio VAE
"AudioDecoder", "AudioDecoder",
"AudioEncoder",
"Vocoder", "Vocoder",
"decode_audio", "decode_audio",
"AudioPatchifier", "AudioPatchifier",
@@ -43,4 +39,10 @@ __all__ = [
"PerChannelStatistics", "PerChannelStatistics",
# Conditioning # Conditioning
"VideoConditionByLatentIndex", "VideoConditionByLatentIndex",
# Utilities
"convert_audio_encoder",
"get_model_path",
"load_safetensors",
"load_config",
"save_weights",
] ]

View File

@@ -1,2 +1,2 @@
"""Stub — delegates to mlx_video.models.ltx_2.weight_loading.""" """Stub — delegates to mlx_video.models.ltx_2.utils."""
from mlx_video.models.ltx_2.weight_loading import * # noqa: F401,F403 from mlx_video.models.ltx_2.utils import * # noqa: F401,F403

View File

@@ -1633,7 +1633,7 @@ def generate_video(
a2v_sr = None a2v_sr = None
if is_a2v: if is_a2v:
from mlx_video.models.ltx_2.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel from mlx_video.models.ltx_2.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel
from mlx_video.convert import convert_audio_encoder from mlx_video.models.ltx_2.utils import convert_audio_encoder
from mlx_video.models.ltx_2.audio_vae import AudioEncoder from mlx_video.models.ltx_2.audio_vae import AudioEncoder
with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"): with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"):

View File

@@ -0,0 +1,161 @@
"""Shared utilities for LTX-2 model loading and conversion."""
import json
from pathlib import Path
from typing import Any, Dict, Optional
import mlx.core as mx
from huggingface_hub import snapshot_download
def get_model_path(
    path_or_hf_repo: str,
    revision: Optional[str] = None,
) -> Path:
    """Resolve a model reference to a directory on the local filesystem.

    Args:
        path_or_hf_repo: A path that already exists locally, or otherwise a
            HuggingFace Hub repo ID to download.
        revision: Revision forwarded to ``snapshot_download`` when a download
            is needed; not used for an existing local path.

    Returns:
        The existing local path, or the snapshot directory returned by
        ``snapshot_download``.
    """
    local_candidate = Path(path_or_hf_repo)
    if not local_candidate.exists():
        # Nothing on disk under that name: treat the argument as a Hub repo
        # ID and restrict the download to weight and JSON files.
        wanted = ["*.safetensors", "*.json", "config.json"]
        downloaded = snapshot_download(
            repo_id=path_or_hf_repo,
            revision=revision,
            allow_patterns=wanted,
        )
        return Path(downloaded)
    return local_candidate
def load_safetensors(path: Path) -> Dict[str, mx.array]:
    """Load weights from one safetensors file or from every one in a directory.

    Args:
        path: A single safetensors file, or a directory whose direct
            ``*.safetensors`` children are all loaded.

    Returns:
        For a file, whatever ``mx.load`` returns for it. For a directory, one
        dict merged from all files; when two files share a key, the file
        visited later wins.
    """
    if not path.is_file():
        merged = {}
        for shard in path.glob("*.safetensors"):
            shard_weights = mx.load(str(shard))
            merged.update(shard_weights)
        return merged
    return mx.load(str(path))
def load_config(model_path: Path) -> Dict[str, Any]:
    """Read ``config.json`` from a model directory.

    Args:
        model_path: Directory expected to contain ``config.json``.

    Returns:
        The parsed JSON content, or an empty dict when the file is absent.
    """
    config_path = model_path / "config.json"
    if not config_path.exists():
        return {}
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
    """Write a weight dictionary to ``model.safetensors`` inside ``path``.

    Args:
        path: Output directory; created (with parents) if it does not exist.
        weights: Mapping of parameter names to arrays to serialize.
    """
    path.mkdir(parents=True, exist_ok=True)
    target_file = path / "model.safetensors"
    mx.save_safetensors(str(target_file), weights)
def convert_audio_encoder(
    model_path,
    source_repo: str = "Lightricks/LTX-2",
) -> Path:
    """Produce ``audio_vae/encoder`` weights and config under ``model_path``.

    If ``audio_vae/encoder/model.safetensors`` already exists, nothing is
    downloaded or written. Otherwise the combined audio VAE checkpoint is
    fetched from ``source_repo``, passed through ``AudioEncoder.sanitize``,
    and saved together with a ``config.json`` describing the encoder.

    Args:
        model_path: Local model directory (output location); str or Path.
        source_repo: HF repo containing
            ``audio_vae/diffusion_pytorch_model.safetensors``.

    Returns:
        Path to the ``audio_vae/encoder`` directory.
    """
    root = Path(model_path)
    encoder_dir = root / "audio_vae" / "encoder"
    weights_file = encoder_dir / "model.safetensors"
    if weights_file.exists():
        # Converted on an earlier run; skip the download entirely.
        return encoder_dir

    from huggingface_hub import hf_hub_download

    checkpoint = hf_hub_download(
        source_repo,
        "audio_vae/diffusion_pytorch_model.safetensors",
    )
    raw_weights = mx.load(checkpoint)

    from mlx_video.models.ltx_2.audio_vae import AudioEncoder
    from mlx_video.models.ltx_2.config import AudioEncoderModelConfig

    # The encoder config is derived from the decoder's config.json when one
    # is present; otherwise only a minimal set of fields is supplied.
    decoder_config_path = root / "audio_vae" / "decoder" / "config.json"
    if not decoder_config_path.exists():
        enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}
    else:
        with open(decoder_config_path) as f:
            dec_cfg = json.load(f)
        pick = dec_cfg.get
        enc_config = {
            "ch": pick("ch", 128),
            # The encoder's input channels mirror the decoder's output channels.
            "in_channels": pick("out_ch", 2),
            "ch_mult": pick("ch_mult", [1, 2, 4]),
            "num_res_blocks": pick("num_res_blocks", 2),
            "attn_resolutions": pick("attn_resolutions", []),
            "resolution": pick("resolution", 256),
            "z_channels": pick("z_channels", 8),
            "double_z": True,
            "n_fft": 1024,
            "norm_type": pick("norm_type", "pixel"),
            "causality_axis": pick("causality_axis", "height"),
            "dropout": pick("dropout", 0.0),
            "mid_block_add_attention": pick("mid_block_add_attention", False),
            "sample_rate": pick("sample_rate", 16000),
            "mel_hop_length": pick("mel_hop_length", 160),
            "is_causal": pick("is_causal", True),
            # ``or 64`` also replaces an explicit null/0 in the decoder config.
            "mel_bins": pick("mel_bins", 64) or 64,
            "resamp_with_conv": pick("resamp_with_conv", True),
            "attn_type": pick("attn_type", "vanilla"),
        }

    encoder = AudioEncoder(AudioEncoderModelConfig.from_dict(enc_config))
    sanitized = encoder.sanitize(raw_weights)

    encoder_dir.mkdir(parents=True, exist_ok=True)
    mx.save_safetensors(str(weights_file), sanitized)
    with open(encoder_dir / "config.json", "w") as f:
        json.dump(enc_config, f, indent=2)
    print(f"Audio encoder weights saved to {encoder_dir}")
    return encoder_dir

View File

@@ -1,770 +0,0 @@
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx_2.ltx import LTXModel
def get_model_path(
    path_or_hf_repo: str,
    revision: Optional[str] = None,
) -> Path:
    """Return a local directory for a model, fetching it from the Hub if needed.

    Args:
        path_or_hf_repo: Existing local path, or a HuggingFace repo ID.
        revision: Git revision used only when downloading from the Hub.

    Returns:
        Path to the model directory.
    """
    candidate = Path(path_or_hf_repo)
    if candidate.exists():
        return candidate

    # Not present locally, so download only safetensors and JSON files.
    snapshot_dir = snapshot_download(
        repo_id=path_or_hf_repo,
        revision=revision,
        allow_patterns=["*.safetensors", "*.json", "config.json"],
    )
    return Path(snapshot_dir)
def load_safetensors(path: Path) -> Dict[str, mx.array]:
    """Load safetensors weights with MLX from a file or a directory.

    Args:
        path: A single safetensors file, or a directory whose direct
            ``*.safetensors`` children are all loaded.

    Returns:
        Dictionary of weights. For a directory, later files overwrite
        duplicate keys from earlier ones.
    """
    if path.is_file():
        # A single file is handed straight to mx.load.
        return mx.load(str(path))

    combined = {}
    shard_paths = list(path.glob("*.safetensors"))
    for loaded in (mx.load(str(shard)) for shard in shard_paths):
        combined.update(loaded)
    return combined
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
    """Load the transformer checkpoint from an LTX-2 model directory.

    The distilled checkpoint is preferred; the dev checkpoint is used only
    when the distilled file is absent.

    Args:
        model_path: Path to the LTX-2 model directory.

    Returns:
        Dictionary of transformer weights as returned by ``mx.load``.

    Raises:
        FileNotFoundError: If neither checkpoint file exists.
    """
    candidate_names = (
        "ltx-2-19b-distilled.safetensors",
        "ltx-2-19b-dev.safetensors",
    )
    checkpoint = next(
        (model_path / name for name in candidate_names if (model_path / name).exists()),
        None,
    )
    if checkpoint is None:
        raise FileNotFoundError(f"No transformer weights found in {model_path}")
    print(f"Loading transformer weights from {checkpoint.name}...")
    return mx.load(str(checkpoint))
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
    """Load the video VAE checkpoint from an LTX-2 model directory.

    Args:
        model_path: Path to the LTX-2 model directory.

    Returns:
        Dictionary of VAE weights as returned by ``mx.load``.

    Raises:
        FileNotFoundError: If ``vae/diffusion_pytorch_model.safetensors`` is
            missing.
    """
    vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
    if not vae_path.exists():
        raise FileNotFoundError(f"VAE weights not found at {vae_path}")
    print(f"Loading VAE weights from {vae_path}...")
    return mx.load(str(vae_path))
def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
    """Load audio VAE weights from an LTX-2 model directory.

    Dedicated audio VAE files are tried first, in a fixed order. If none
    exists, the combined distilled/dev checkpoints are loaded and filtered
    down to keys containing ``"audio_vae"``.

    Args:
        model_path: Path to the LTX-2 model directory.

    Returns:
        Dictionary of audio VAE weights.

    Raises:
        FileNotFoundError: If no dedicated file exists and no combined
            checkpoint yields any ``audio_vae`` key.
    """
    dedicated_files = (
        model_path / "audio_vae" / "decoder" / "model.safetensors",
        model_path / "audio_vae" / "decoder" / "diffusion_pytorch_model.safetensors",
        model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
        model_path / "audio_vae.safetensors",
    )
    found = next((p for p in dedicated_files if p.exists()), None)
    if found is not None:
        print(f"Loading audio VAE weights from {found}...")
        return mx.load(str(found))

    # Fall back to the combined checkpoints and keep only audio VAE entries.
    for name in ("ltx-2-19b-distilled.safetensors", "ltx-2-19b-dev.safetensors"):
        combined = model_path / name
        if not combined.exists():
            continue
        print(f"Loading audio VAE weights from {combined.name}...")
        everything = mx.load(str(combined))
        subset = {k: v for k, v in everything.items() if "audio_vae" in k}
        if subset:
            return subset

    raise FileNotFoundError(f"Audio VAE weights not found in {model_path}")
def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]:
    """Load vocoder weights from an LTX-2 model directory.

    Dedicated vocoder files are tried first. If none exists, the combined
    distilled/dev checkpoints are loaded and filtered down to keys containing
    ``"vocoder"``.

    Args:
        model_path: Path to the LTX-2 model directory.

    Returns:
        Dictionary of vocoder weights.

    Raises:
        FileNotFoundError: If no dedicated file exists and no combined
            checkpoint yields any ``vocoder`` key.
    """
    dedicated_files = (
        model_path / "vocoder" / "diffusion_pytorch_model.safetensors",
        model_path / "vocoder.safetensors",
    )
    found = next((p for p in dedicated_files if p.exists()), None)
    if found is not None:
        print(f"Loading vocoder weights from {found}...")
        return mx.load(str(found))

    # Fall back to the combined checkpoints and keep only vocoder entries.
    for name in ("ltx-2-19b-distilled.safetensors", "ltx-2-19b-dev.safetensors"):
        combined = model_path / name
        if not combined.exists():
            continue
        print(f"Loading vocoder weights from {combined.name}...")
        everything = mx.load(str(combined))
        subset = {k: v for k, v in everything.items() if "vocoder" in k}
        if subset:
            return subset

    raise FileNotFoundError(f"Vocoder weights not found in {model_path}")
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename transformer weights from PyTorch LTX-2 names to MLX names.

    Only keys starting with ``model.diffusion_model.`` are kept; every other
    key is dropped. Values are passed through untouched.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        Dictionary with MLX-compatible naming for the transformer.
    """
    prefix = "model.diffusion_model."
    # Applied in this order to each key after the prefix has been removed.
    renames = (
        (".to_out.0.", ".to_out."),
        (".ff.net.0.proj.", ".ff.proj_in."),
        (".ff.net.2.", ".ff.proj_out."),
        (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
        (".audio_ff.net.2.", ".audio_ff.proj_out."),
        (".linear_1.", ".linear1."),
        (".linear_2.", ".linear2."),
    )

    result = {}
    for name, tensor in weights.items():
        if not name.startswith(prefix):
            continue
        renamed = name.replace(prefix, "")
        for old, new in renames:
            renamed = renamed.replace(old, new)
        result[renamed] = tensor
    return result
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename and re-layout VAE decoder weights from PyTorch to MLX format.

    Kept keys are ``vae.decoder.*`` (prefix removed) and the two per-channel
    statistics ``mean-of-means`` / ``std-of-means``. Everything else,
    including any key containing ``position_ids``, is dropped.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        Dictionary with MLX-compatible naming for the VAE decoder.
    """
    # Exact-match table: other statistics keys (for example
    # "mean-of-stds_over_std-of-means") share a suffix and must not match.
    stats_names = {
        "vae.per_channel_statistics.mean-of-means": "per_channel_statistics.mean",
        "vae.per_channel_statistics.std-of-means": "per_channel_statistics.std",
    }

    result = {}
    for name, tensor in weights.items():
        if "position_ids" in name or not name.startswith("vae."):
            continue

        if "vae.per_channel_statistics" in name:
            target = stats_names.get(name)
            if target is None:
                continue
        elif name.startswith("vae.decoder."):
            target = name.replace("vae.decoder.", "")
        else:
            continue

        if "conv" in target.lower() and "weight" in target:
            if tensor.ndim == 5:
                # Conv3d: (O, I, D, H, W) -> (O, D, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 4, 1))
            elif tensor.ndim == 4:
                # Conv2d: (O, I, H, W) -> (O, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 1))

        result[target] = tensor
    return result
def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename and re-layout VAE encoder weights from PyTorch to MLX format.

    Kept keys are ``vae.encoder.*`` (prefix removed) and the two per-channel
    statistics ``mean-of-means`` / ``std-of-means``. Everything else,
    including any key containing ``position_ids``, is dropped.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        Dictionary with MLX-compatible naming for the VAE encoder.
    """
    stats_names = {
        "vae.per_channel_statistics.mean-of-means": "per_channel_statistics._mean_of_means",
        "vae.per_channel_statistics.std-of-means": "per_channel_statistics._std_of_means",
    }

    result = {}
    for name, tensor in weights.items():
        if "position_ids" in name or not name.startswith("vae."):
            continue

        if "vae.per_channel_statistics" in name:
            target = stats_names.get(name)
            if target is None:
                # Any other statistics entry is not used by the encoder.
                continue
        elif name.startswith("vae.encoder."):
            target = name.replace("vae.encoder.", "")
        else:
            continue

        if "conv" in target.lower() and "weight" in target:
            if tensor.ndim == 5:
                # Conv3d: (O, I, D, H, W) -> (O, D, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 4, 1))
            elif tensor.ndim == 4:
                # Conv2d: (O, I, H, W) -> (O, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 1))

        result[target] = tensor
    return result
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename and re-layout audio VAE decoder weights from PyTorch to MLX format.

    Kept keys are ``audio_vae.decoder.*`` (prefix removed) and the two
    per-channel statistics ``mean-of-means`` / ``std-of-means``. All other
    keys are dropped.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        Dictionary with MLX-compatible naming for the audio VAE decoder.
    """
    # NOTE(review): this is a dict-key lookup for a key literally named
    # "audio_vae.", not a prefix test over the keys. It was presumably meant
    # to detect already-sanitized input -- confirm the intent before changing.
    if "audio_vae." in weights:
        return weights

    decoder_prefix = "audio_vae.decoder."
    stats_prefix = "audio_vae.per_channel_statistics."
    # Statistics are matched exactly. A substring test for "std-of-means"
    # would also accept "mean-of-stds_over_std-of-means" and could overwrite
    # the real value.
    stats_names = {
        "mean-of-means": "per_channel_statistics.mean_of_means",
        "std-of-means": "per_channel_statistics.std_of_means",
    }

    sanitized = {}
    for key, value in weights.items():
        if key.startswith(decoder_prefix):
            new_key = key.replace(decoder_prefix, "")
        elif key.startswith(stats_prefix):
            new_key = stats_names.get(key[len(stats_prefix):])
            if new_key is None:
                continue  # Not one of the two statistics the decoder uses.
        else:
            continue  # Not a decoder key.

        # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
        if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
            value = mx.transpose(value, (0, 2, 3, 1))

        sanitized[new_key] = value
    return sanitized
def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Rename and re-layout vocoder weights from PyTorch to MLX format.

    Keys starting with ``vocoder.`` have that text removed; all other keys are
    kept under their original name. Three-dimensional ``weight`` tensors are
    transposed to the MLX convolution layout.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        Dictionary with MLX-compatible naming for the vocoder.
    """
    result = {}
    for name, tensor in weights.items():
        target = name.replace("vocoder.", "") if name.startswith("vocoder.") else name

        if "weight" in target and tensor.ndim == 3:
            # "ups" entries are ConvTranspose1d: (in, out, k) -> (out, k, in).
            # Everything else is Conv1d:        (out, in, k) -> (out, k, in).
            axes = (1, 2, 0) if "ups" in target else (0, 2, 1)
            tensor = mx.transpose(tensor, axes)

        result[target] = tensor
    return result
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Generic PyTorch-to-MLX sanitizer covering transformer and conv weights.

    Keys containing ``position_ids`` are dropped. Transformer keys
    (``model.diffusion_model.*``) are renamed; all other keys keep their
    name. Convolution weights are detected from the *original* key and
    transposed to channels-last layout.

    Args:
        weights: Dictionary of weights with PyTorch naming.

    Returns:
        Dictionary with MLX-compatible naming.
    """
    prefix = "model.diffusion_model."
    renames = (
        (".to_out.0.", ".to_out."),
        (".ff.net.0.proj.", ".ff.proj_in."),
        (".ff.net.2.", ".ff.proj_out."),
        (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
        (".audio_ff.net.2.", ".audio_ff.proj_out."),
        (".linear_1.", ".linear1."),
        (".linear_2.", ".linear2."),
    )

    result = {}
    for name, tensor in weights.items():
        if "position_ids" in name:
            continue

        target = name
        if name.startswith(prefix):
            target = name.replace(prefix, "")
            for old, new in renames:
                target = target.replace(old, new)

        if "conv" in name.lower() and "weight" in name:
            if tensor.ndim == 5:
                # Conv3d: (O, I, D, H, W) -> (O, D, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 4, 1))
            elif tensor.ndim == 4:
                # Conv2d: (O, I, H, W) -> (O, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 1))

        result[target] = tensor
    return result
def load_config(model_path: Path) -> Dict[str, Any]:
    """Load ``config.json`` from a model directory.

    Args:
        model_path: Path to the model directory.

    Returns:
        The parsed configuration, or an empty dict when ``config.json`` does
        not exist.
    """
    config_file = model_path / "config.json"
    if not config_file.exists():
        # No config on disk: callers receive an empty mapping.
        return {}
    with open(config_file, "r") as handle:
        parsed = json.load(handle)
    return parsed
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
    """Build an ``LTXModel`` from a plain configuration dictionary.

    Each field is read from ``config`` with the fallback listed below when the
    key is missing. The model type is always ``LTXModelType.AudioVideo``.

    Args:
        config: Configuration dictionary.

    Returns:
        LTXModel instance.
    """
    # Field name -> value used when the key is absent from ``config``.
    fallbacks = {
        "num_attention_heads": 32,
        "attention_head_dim": 128,
        "in_channels": 128,
        "out_channels": 128,
        "num_layers": 48,
        "cross_attention_dim": 4096,
        "caption_channels": 3840,
        "audio_num_attention_heads": 32,
        "audio_attention_head_dim": 64,
        "audio_in_channels": 128,
        "audio_out_channels": 128,
        "audio_cross_attention_dim": 2048,
        "positional_embedding_theta": 10000.0,
        "positional_embedding_max_pos": [20, 2048, 2048],
        "audio_positional_embedding_max_pos": [20],
        "timestep_scale_multiplier": 1000,
        "av_ca_timestep_scale_multiplier": 1000,
        "norm_eps": 1e-6,
    }
    settings = {field: config.get(field, default) for field, default in fallbacks.items()}
    model_config = LTXModelConfig(model_type=LTXModelType.AudioVideo, **settings)
    return LTXModel(model_config)
def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    dtype: Optional[str] = None,
    quantize: bool = False,
    q_bits: int = 4,
    q_group_size: int = 64,
) -> Path:
    """Convert a HuggingFace model to MLX format and write it to disk.

    The weights are loaded, sanitized, optionally cast to ``dtype``, and saved
    together with the source ``config.json`` content.

    NOTE(review): ``quantize``, ``q_bits`` and ``q_group_size`` are accepted
    but never read in this function body.

    Args:
        hf_path: HuggingFace model path or repo ID.
        mlx_path: Output path for the MLX model.
        dtype: Target dtype (float16, float32, bfloat16). Any other non-None
            value falls back to float16.
        quantize: Whether to quantize the model.
        q_bits: Quantization bits.
        q_group_size: Quantization group size.

    Returns:
        Path to the converted model.
    """
    print(f"Loading model from {hf_path}...")
    source_dir = get_model_path(hf_path)
    config = load_config(source_dir)

    print("Loading weights...")
    weights = load_safetensors(source_dir)

    print("Sanitizing weights...")
    weights = sanitize_weights(weights)

    if dtype is not None:
        target_dtype = {
            "float16": mx.float16,
            "float32": mx.float32,
            "bfloat16": mx.bfloat16,
        }.get(dtype, mx.float16)
        print(f"Converting to {dtype}...")
        # Only floating-point tensors are cast; other dtypes pass through.
        floating = [mx.float32, mx.float16, mx.bfloat16]
        converted = {}
        for name, tensor in weights.items():
            converted[name] = tensor.astype(target_dtype) if tensor.dtype in floating else tensor
        weights = converted

    output_path = Path(mlx_path)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Saving weights to {output_path}...")
    save_weights(output_path, weights)

    with open(output_path / "config.json", "w") as f:
        json.dump(config, f, indent=2)

    print(f"Model converted successfully to {output_path}")
    return output_path
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
    """Save weights as ``model.safetensors`` inside ``path``.

    Uses ``mx.save_safetensors`` to preserve exact dtype (especially
    bfloat16). Converting through numpy loses bfloat16 fidelity since numpy
    lacks native bfloat16 support.

    Args:
        path: Output directory; created (with parents) if it does not exist.
        weights: Dictionary of weights.
    """
    # Create the directory here so a call with a missing output directory
    # does not fail.
    path.mkdir(parents=True, exist_ok=True)
    mx.save_safetensors(str(path / "model.safetensors"), weights)
def _audio_encoder_config(decoder_config_path: Path) -> Dict[str, Any]:
    """Derive the encoder config from the decoder's config.json, if present."""
    if not decoder_config_path.exists():
        return {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}

    with open(decoder_config_path) as f:
        dec_cfg = json.load(f)
    pick = dec_cfg.get
    return {
        "ch": pick("ch", 128),
        # The encoder's input channels mirror the decoder's output channels.
        "in_channels": pick("out_ch", 2),
        "ch_mult": pick("ch_mult", [1, 2, 4]),
        "num_res_blocks": pick("num_res_blocks", 2),
        "attn_resolutions": pick("attn_resolutions", []),
        "resolution": pick("resolution", 256),
        "z_channels": pick("z_channels", 8),
        "double_z": True,
        "n_fft": 1024,
        "norm_type": pick("norm_type", "pixel"),
        "causality_axis": pick("causality_axis", "height"),
        "dropout": pick("dropout", 0.0),
        "mid_block_add_attention": pick("mid_block_add_attention", False),
        "sample_rate": pick("sample_rate", 16000),
        "mel_hop_length": pick("mel_hop_length", 160),
        "is_causal": pick("is_causal", True),
        # ``or 64`` also replaces an explicit null/0 in the decoder config.
        "mel_bins": pick("mel_bins", 64) or 64,
        "resamp_with_conv": pick("resamp_with_conv", True),
        "attn_type": pick("attn_type", "vanilla"),
    }


def convert_audio_encoder(
    model_path: Union[str, Path],
    source_repo: str = "Lightricks/LTX-2",
) -> Path:
    """Convert and save audio encoder weights from the original HF checkpoint.

    If ``audio_vae/encoder/model.safetensors`` already exists under
    ``model_path``, nothing is downloaded or written. Otherwise the combined
    audio VAE checkpoint is fetched from ``source_repo``, passed through
    ``AudioEncoder.sanitize``, and saved with a matching ``config.json``.

    Args:
        model_path: Local model directory (output location).
        source_repo: HF repo containing
            ``audio_vae/diffusion_pytorch_model.safetensors``.

    Returns:
        Path to the ``audio_vae/encoder`` directory.
    """
    root = Path(model_path)
    encoder_dir = root / "audio_vae" / "encoder"
    weights_file = encoder_dir / "model.safetensors"
    if weights_file.exists():
        return encoder_dir

    # Download the original combined audio VAE checkpoint.
    from huggingface_hub import hf_hub_download

    checkpoint = hf_hub_download(
        source_repo,
        "audio_vae/diffusion_pytorch_model.safetensors",
    )
    raw_weights = mx.load(checkpoint)

    from mlx_video.models.ltx_2.audio_vae import AudioEncoder
    from mlx_video.models.ltx_2.config import AudioEncoderModelConfig

    enc_config = _audio_encoder_config(root / "audio_vae" / "decoder" / "config.json")

    encoder = AudioEncoder(AudioEncoderModelConfig.from_dict(enc_config))
    sanitized = encoder.sanitize(raw_weights)

    encoder_dir.mkdir(parents=True, exist_ok=True)
    mx.save_safetensors(str(weights_file), sanitized)
    with open(encoder_dir / "config.json", "w") as f:
        json.dump(enc_config, f, indent=2)
    print(f"Audio encoder weights saved to {encoder_dir}")
    return encoder_dir
def load_model(
    path_or_hf_repo: str,
    lazy: bool = False,
) -> LTXModel:
    """Load an LTX model from a local path or a HuggingFace repo.

    Args:
        path_or_hf_repo: Path to model or HuggingFace repo ID.
        lazy: When False (default), ``mx.eval`` is called on the model
            parameters before returning; when True that step is skipped.

    Returns:
        Loaded LTXModel.
    """
    root = get_model_path(path_or_hf_repo)

    # Build the model from its config, then fill it with sanitized weights.
    model = create_model_from_config(load_config(root))
    state = sanitize_weights(load_safetensors(root))
    model.load_weights(list(state.items()))

    if lazy:
        return model
    mx.eval(model.parameters())
    return model
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the options and run a conversion.
    cli = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
    option_table = (
        ("--hf-path", dict(type=str, default="Lightricks/LTX-2", help="HuggingFace model path or repo ID")),
        ("--mlx-path", dict(type=str, default="mlx_model", help="Output path for MLX model")),
        (
            "--dtype",
            dict(
                type=str,
                choices=["float16", "float32", "bfloat16"],
                default="float16",
                help="Target dtype",
            ),
        ),
        ("--quantize", dict(action="store_true", help="Quantize the model")),
        ("--q-bits", dict(type=int, default=4, help="Quantization bits")),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()
    convert(
        hf_path=opts.hf_path,
        mlx_path=opts.mlx_path,
        dtype=opts.dtype,
        quantize=opts.quantize,
        q_bits=opts.q_bits,
    )