diff --git a/mlx_video/__init__.py b/mlx_video/__init__.py index cea80ec..38e276b 100644 --- a/mlx_video/__init__.py +++ b/mlx_video/__init__.py @@ -1,16 +1,9 @@ from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig -from mlx_video.convert import ( - load_transformer_weights, - load_vae_weights, - load_audio_vae_weights, - load_vocoder_weights, - sanitize_audio_vae_weights, - sanitize_vocoder_weights, -) # Audio VAE components from mlx_video.models.ltx_2.audio_vae import ( AudioDecoder, + AudioEncoder, Vocoder, decode_audio, AudioPatchifier, @@ -23,19 +16,22 @@ from mlx_video.models.ltx_2.conditioning import ( VideoConditionByLatentIndex, ) +# Utilities +from mlx_video.models.ltx_2.utils import ( + convert_audio_encoder, + get_model_path, + load_safetensors, + load_config, + save_weights, +) + __all__ = [ # Models "LTXModel", "LTXModelConfig", - # Weight loading - "load_transformer_weights", - "load_vae_weights", - "load_audio_vae_weights", - "load_vocoder_weights", - "sanitize_audio_vae_weights", - "sanitize_vocoder_weights", # Audio VAE "AudioDecoder", + "AudioEncoder", "Vocoder", "decode_audio", "AudioPatchifier", @@ -43,4 +39,10 @@ __all__ = [ "PerChannelStatistics", # Conditioning "VideoConditionByLatentIndex", -] \ No newline at end of file + # Utilities + "convert_audio_encoder", + "get_model_path", + "load_safetensors", + "load_config", + "save_weights", +] diff --git a/mlx_video/convert.py b/mlx_video/convert.py index 1947eea..05c736b 100644 --- a/mlx_video/convert.py +++ b/mlx_video/convert.py @@ -1,2 +1,2 @@ -"""Stub — delegates to mlx_video.models.ltx_2.weight_loading.""" -from mlx_video.models.ltx_2.weight_loading import * # noqa: F401,F403 +"""Stub — delegates to mlx_video.models.ltx_2.utils.""" +from mlx_video.models.ltx_2.utils import * # noqa: F401,F403 diff --git a/mlx_video/models/ltx_2/generate.py b/mlx_video/models/ltx_2/generate.py index 6f49147..c7df2dc 100644 --- a/mlx_video/models/ltx_2/generate.py +++ b/mlx_video/models/ltx_2/generate.py @@ -1633,7 +1633,7 @@ def generate_video( a2v_sr = None if is_a2v: from mlx_video.models.ltx_2.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel - from mlx_video.convert import convert_audio_encoder + from mlx_video.models.ltx_2.utils import convert_audio_encoder from mlx_video.models.ltx_2.audio_vae import AudioEncoder with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"): diff --git a/mlx_video/models/ltx_2/utils.py b/mlx_video/models/ltx_2/utils.py new file mode 100644 index 0000000..a603539 --- /dev/null +++ b/mlx_video/models/ltx_2/utils.py @@ -0,0 +1,161 @@ +"""Shared utilities for LTX-2 model loading and conversion.""" + +import json +from pathlib import Path +from typing import Any, Dict, Optional + +import mlx.core as mx +from huggingface_hub import snapshot_download + + +def get_model_path( + path_or_hf_repo: str, + revision: Optional[str] = None, +) -> Path: + """Get local path to model, downloading if necessary. + + Args: + path_or_hf_repo: Local path or HuggingFace repo ID + revision: Git revision for HF repo + + Returns: + Path to model directory + """ + model_path = Path(path_or_hf_repo) + + if model_path.exists(): + return model_path + + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + revision=revision, + allow_patterns=[ + "*.safetensors", + "*.json", + "config.json", + ], + ) + ) + + return model_path + + +def load_safetensors(path: Path) -> Dict[str, mx.array]: + """Load weights from safetensors file(s) using MLX. + + Args: + path: Path to model directory or single safetensors file + + Returns: + Dictionary of weights + """ + if path.is_file(): + return mx.load(str(path)) + + weights = {} + for sf_path in path.glob("*.safetensors"): + weights.update(mx.load(str(sf_path))) + return weights + + +def load_config(model_path: Path) -> Dict[str, Any]: + """Load model configuration from config.json. + + Args: + model_path: Path to model directory + + Returns: + Configuration dictionary + """ + config_path = model_path / "config.json" + if config_path.exists(): + with open(config_path, "r") as f: + return json.load(f) + return {} + + +def save_weights(path: Path, weights: Dict[str, mx.array]) -> None: + """Save weights in safetensors format. + + Args: + path: Output directory + weights: Dictionary of weights + """ + path.mkdir(parents=True, exist_ok=True) + mx.save_safetensors(str(path / "model.safetensors"), weights) + + +def convert_audio_encoder( + model_path, + source_repo: str = "Lightricks/LTX-2", +) -> Path: + """Convert and save audio encoder weights from original HF checkpoint. + + Extracts encoder weights from the combined audio VAE safetensors, + transposes Conv2d for MLX, and saves for AudioEncoder.from_pretrained(). + + Args: + model_path: Local model directory (output location). + source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors. + + Returns: + Path to the audio_vae/encoder directory. + """ + model_path = Path(model_path) + encoder_dir = model_path / "audio_vae" / "encoder" + + if (encoder_dir / "model.safetensors").exists(): + return encoder_dir + + from huggingface_hub import hf_hub_download + vae_path = hf_hub_download( + source_repo, + "audio_vae/diffusion_pytorch_model.safetensors", + ) + + raw_weights = mx.load(vae_path) + + from mlx_video.models.ltx_2.audio_vae import AudioEncoder + from mlx_video.models.ltx_2.config import AudioEncoderModelConfig + + # Build config from the decoder config (same audio VAE architecture) + decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json" + if decoder_config_path.exists(): + with open(decoder_config_path) as f: + dec_cfg = json.load(f) + enc_config = { + "ch": dec_cfg.get("ch", 128), + "in_channels": dec_cfg.get("out_ch", 2), + "ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]), + "num_res_blocks": dec_cfg.get("num_res_blocks", 2), + "attn_resolutions": dec_cfg.get("attn_resolutions", []), + "resolution": dec_cfg.get("resolution", 256), + "z_channels": dec_cfg.get("z_channels", 8), + "double_z": True, + "n_fft": 1024, + "norm_type": dec_cfg.get("norm_type", "pixel"), + "causality_axis": dec_cfg.get("causality_axis", "height"), + "dropout": dec_cfg.get("dropout", 0.0), + "mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False), + "sample_rate": dec_cfg.get("sample_rate", 16000), + "mel_hop_length": dec_cfg.get("mel_hop_length", 160), + "is_causal": dec_cfg.get("is_causal", True), + "mel_bins": dec_cfg.get("mel_bins", 64) or 64, + "resamp_with_conv": dec_cfg.get("resamp_with_conv", True), + "attn_type": dec_cfg.get("attn_type", "vanilla"), + } + else: + enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64} + + config = AudioEncoderModelConfig.from_dict(enc_config) + encoder = AudioEncoder(config) + sanitized = encoder.sanitize(raw_weights) + + encoder_dir.mkdir(parents=True, exist_ok=True) + mx.save_safetensors(str(encoder_dir / "model.safetensors"), sanitized) + with open(encoder_dir / "config.json", "w") as f: + json.dump(enc_config, f, indent=2) + + print(f"Audio encoder weights saved to {encoder_dir}") + return encoder_dir diff --git a/mlx_video/models/ltx_2/weight_loading.py b/mlx_video/models/ltx_2/weight_loading.py deleted file mode 100644 index 234bb08..0000000 --- a/mlx_video/models/ltx_2/weight_loading.py +++ /dev/null @@ -1,770 +0,0 @@ -import json -import shutil -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import mlx.core as mx -import mlx.nn as nn -from huggingface_hub import snapshot_download - -from mlx_video.models.ltx_2.config import LTXModelConfig, LTXModelType -from mlx_video.models.ltx_2.ltx import LTXModel - - -def get_model_path( - path_or_hf_repo: str, - revision: Optional[str] = None, -) -> Path: - """Get local path to model, downloading if necessary. - - Args: - path_or_hf_repo: Local path or HuggingFace repo ID - revision: Git revision for HF repo - - Returns: - Path to model directory - """ - model_path = Path(path_or_hf_repo) - - if model_path.exists(): - return model_path - - # Download from HuggingFace - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - revision=revision, - allow_patterns=[ - "*.safetensors", - "*.json", - "config.json", - ], - ) - ) - - return model_path - - -def load_safetensors(path: Path) -> Dict[str, mx.array]: - """Load weights from safetensors file(s) using MLX. - - Args: - path: Path to model directory or single safetensors file - - Returns: - Dictionary of weights - """ - weights = {} - - if path.is_file(): - # Single file - use mx.load directly (handles bfloat16) - return mx.load(str(path)) - else: - # Directory - load all safetensors files - safetensor_files = list(path.glob("*.safetensors")) - for sf_path in safetensor_files: - file_weights = mx.load(str(sf_path)) - weights.update(file_weights) - - return weights - - -def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]: - """Load transformer weights from LTX-2 model. - - Args: - model_path: Path to LTX-2 model directory - - Returns: - Dictionary of transformer weights - """ - # Try distilled model first, then dev - weight_files = [ - model_path / "ltx-2-19b-distilled.safetensors", - model_path / "ltx-2-19b-dev.safetensors", - ] - - for weight_file in weight_files: - if weight_file.exists(): - print(f"Loading transformer weights from {weight_file.name}...") - return mx.load(str(weight_file)) - - raise FileNotFoundError(f"No transformer weights found in {model_path}") - - -def load_vae_weights(model_path: Path) -> Dict[str, mx.array]: - """Load VAE weights from LTX-2 model. - - Args: - model_path: Path to LTX-2 model directory - - Returns: - Dictionary of VAE weights - """ - vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors" - if vae_path.exists(): - print(f"Loading VAE weights from {vae_path}...") - return mx.load(str(vae_path)) - - raise FileNotFoundError(f"VAE weights not found at {vae_path}") - - -def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]: - """Load audio VAE weights from LTX-2 model. - - Args: - model_path: Path to LTX-2 model directory - - Returns: - Dictionary of audio VAE weights - """ - # Try different possible paths for audio VAE weights - audio_vae_paths = [ - model_path / "audio_vae" / "decoder" / "model.safetensors", - model_path / "audio_vae" / "decoder" / "diffusion_pytorch_model.safetensors", - model_path / "audio_vae" / "diffusion_pytorch_model.safetensors", - model_path / "audio_vae.safetensors", - ] - - # Also check in main model weights - main_paths = [ - model_path / "ltx-2-19b-distilled.safetensors", - model_path / "ltx-2-19b-dev.safetensors", - ] - - for audio_path in audio_vae_paths: - if audio_path.exists(): - print(f"Loading audio VAE weights from {audio_path}...") - return mx.load(str(audio_path)) - - # Check main model weights for audio_vae keys - for main_path in main_paths: - if main_path.exists(): - print(f"Loading audio VAE weights from {main_path.name}...") - all_weights = mx.load(str(main_path)) - # Filter to only audio_vae keys - audio_weights = {k: v for k, v in all_weights.items() if "audio_vae" in k} - if audio_weights: - return audio_weights - - raise FileNotFoundError(f"Audio VAE weights not found in {model_path}") - - -def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]: - """Load vocoder weights from LTX-2 model. - - Args: - model_path: Path to LTX-2 model directory - - Returns: - Dictionary of vocoder weights - """ - # Try different possible paths for vocoder weights - vocoder_paths = [ - model_path / "vocoder" / "diffusion_pytorch_model.safetensors", - model_path / "vocoder.safetensors", - ] - - # Also check in main model weights - main_paths = [ - model_path / "ltx-2-19b-distilled.safetensors", - model_path / "ltx-2-19b-dev.safetensors", - ] - - for vocoder_path in vocoder_paths: - if vocoder_path.exists(): - print(f"Loading vocoder weights from {vocoder_path}...") - return mx.load(str(vocoder_path)) - - # Check main model weights for vocoder keys - for main_path in main_paths: - if main_path.exists(): - print(f"Loading vocoder weights from {main_path.name}...") - all_weights = mx.load(str(main_path)) - # Filter to only vocoder keys - vocoder_weights = {k: v for k, v in all_weights.items() if "vocoder" in k} - if vocoder_weights: - return vocoder_weights - - raise FileNotFoundError(f"Vocoder weights not found in {model_path}") - - -def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - """Sanitize transformer weight names from PyTorch LTX-2 format to MLX format. - - Args: - weights: Dictionary of weights with PyTorch naming - - Returns: - Dictionary with MLX-compatible naming for transformer - """ - sanitized = {} - - for key, value in weights.items(): - new_key = key - - # Skip non-transformer weights (VAE, vocoder, audio_vae, connectors) - if not key.startswith("model.diffusion_model."): - continue - - # Remove 'model.diffusion_model.' prefix - new_key = key.replace("model.diffusion_model.", "") - - # Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering) - new_key = new_key.replace(".to_out.0.", ".to_out.") - - # Handle feed-forward net naming - # PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar) - # MLX FeedForward: uses proj_in, proj_out - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") - new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") - - # Handle AdaLN naming - keep emb wrapper, just fix linear naming - # PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1 - new_key = new_key.replace(".linear_1.", ".linear1.") - new_key = new_key.replace(".linear_2.", ".linear2.") - - # Handle caption projection (keep linear1/linear2 naming for compatibility) - # These are already mapped correctly in the sanitization - - sanitized[new_key] = value - - return sanitized - - -def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - """Sanitize VAE weight names from PyTorch format to MLX format. - - Args: - weights: Dictionary of weights with PyTorch naming - - Returns: - Dictionary with MLX-compatible naming for VAE decoder - """ - sanitized = {} - - for key, value in weights.items(): - new_key = key - - # Skip position_ids (not needed) - if "position_ids" in key: - continue - - # Only process VAE decoder weights (skip audio_vae, etc.) - if not key.startswith("vae."): - continue - - # Handle per-channel statistics key mapping - # PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean - # PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std - # Be careful: mean-of-stds_over_std-of-means also ends with std-of-means - if "vae.per_channel_statistics" in key: - if key == "vae.per_channel_statistics.mean-of-means": - new_key = "per_channel_statistics.mean" - elif key == "vae.per_channel_statistics.std-of-means": - new_key = "per_channel_statistics.std" - else: - # Skip other per_channel_statistics keys (channel, mean-of-stds, etc.) - continue - elif key.startswith("vae.decoder."): - # Strip the vae.decoder. prefix for decoder weights - new_key = key.replace("vae.decoder.", "") - else: - # Skip other vae.* keys that are not decoder weights - continue - - # Handle Conv3d weight shape conversion - # PyTorch: (out_channels, in_channels, D, H, W) - # MLX: (out_channels, D, H, W, in_channels) - if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5: - # Transpose from (O, I, D, H, W) to (O, D, H, W, I) - value = mx.transpose(value, (0, 2, 3, 4, 1)) - - # Handle Conv2d weight shape conversion - # PyTorch: (out_channels, in_channels, H, W) - # MLX: (out_channels, H, W, in_channels) - if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: - value = mx.transpose(value, (0, 2, 3, 1)) - - sanitized[new_key] = value - - return sanitized - - -def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - """Sanitize VAE encoder weight names from PyTorch format to MLX format. - - Args: - weights: Dictionary of weights with PyTorch naming - - Returns: - Dictionary with MLX-compatible naming for VAE encoder - """ - sanitized = {} - - for key, value in weights.items(): - new_key = key - - # Skip position_ids (not needed) - if "position_ids" in key: - continue - - # Only process VAE encoder weights - if not key.startswith("vae."): - continue - - # Handle per-channel statistics key mapping - if "vae.per_channel_statistics" in key: - if key == "vae.per_channel_statistics.mean-of-means": - new_key = "per_channel_statistics._mean_of_means" - elif key == "vae.per_channel_statistics.std-of-means": - new_key = "per_channel_statistics._std_of_means" - else: - # Skip other per_channel_statistics keys - continue - elif key.startswith("vae.encoder."): - # Strip the vae.encoder. prefix for encoder weights - new_key = key.replace("vae.encoder.", "") - else: - # Skip other vae.* keys that are not encoder weights - continue - - # Handle Conv3d weight shape conversion - # PyTorch: (out_channels, in_channels, D, H, W) - # MLX: (out_channels, D, H, W, in_channels) - if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5: - value = mx.transpose(value, (0, 2, 3, 4, 1)) - - # Handle Conv2d weight shape conversion - if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: - value = mx.transpose(value, (0, 2, 3, 1)) - - sanitized[new_key] = value - - return sanitized - - -def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - """Sanitize audio VAE weight names from PyTorch format to MLX format. - - Args: - weights: Dictionary of weights with PyTorch naming - - Returns: - Dictionary with MLX-compatible naming for audio VAE decoder - """ - sanitized = {} - - if "audio_vae." in weights: - return weights - - for key, value in weights.items(): - new_key = key - - # Handle audio_vae.decoder weights - if key.startswith("audio_vae.decoder."): - new_key = key.replace("audio_vae.decoder.", "") - elif key.startswith("audio_vae.per_channel_statistics."): - # Map per-channel statistics - if "mean-of-means" in key: - new_key = "per_channel_statistics.mean_of_means" - elif "std-of-means" in key: - new_key = "per_channel_statistics.std_of_means" - else: - continue # Skip other statistics keys - else: - continue # Skip non-decoder keys - - # Handle Conv2d weight shape conversion - # PyTorch: (out_channels, in_channels, H, W) - # MLX: (out_channels, H, W, in_channels) - if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: - value = mx.transpose(value, (0, 2, 3, 1)) - - sanitized[new_key] = value - - return sanitized - - -def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - """Sanitize vocoder weight names from PyTorch format to MLX format. - - Args: - weights: Dictionary of weights with PyTorch naming - - Returns: - Dictionary with MLX-compatible naming for vocoder - """ - sanitized = {} - - for key, value in weights.items(): - new_key = key - - # Handle vocoder weights - if key.startswith("vocoder."): - new_key = key.replace("vocoder.", "") - - # Handle ModuleList indices -> dict keys - # PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ... - # PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ... - - # Handle Conv1d weight shape conversion - # PyTorch: (out_channels, in_channels, kernel) - # MLX: (out_channels, kernel, in_channels) - if "weight" in new_key and value.ndim == 3: - if "ups" in new_key: - # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch) - value = mx.transpose(value, (1, 2, 0)) - else: - # Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch) - value = mx.transpose(value, (0, 2, 1)) - - sanitized[new_key] = value - - return sanitized - - -def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: - """Sanitize weight names from PyTorch format to MLX format. - - Generic function that handles both transformer and VAE weights. - - Args: - weights: Dictionary of weights with PyTorch naming - - Returns: - Dictionary with MLX-compatible naming - """ - sanitized = {} - - for key, value in weights.items(): - new_key = key - - # Skip position_ids (not needed) - if "position_ids" in key: - continue - - # Handle transformer weights - if key.startswith("model.diffusion_model."): - new_key = key.replace("model.diffusion_model.", "") - new_key = new_key.replace(".to_out.0.", ".to_out.") - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") - new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") - new_key = new_key.replace(".linear_1.", ".linear1.") - new_key = new_key.replace(".linear_2.", ".linear2.") - - # Handle Conv3d weight shape conversion - # PyTorch: (out_channels, in_channels, D, H, W) - # MLX: (out_channels, D, H, W, in_channels) - if "conv" in key.lower() and "weight" in key and value.ndim == 5: - value = mx.transpose(value, (0, 2, 3, 4, 1)) - - # Handle Conv2d weight shape conversion - # PyTorch: (out_channels, in_channels, H, W) - # MLX: (out_channels, H, W, in_channels) - if "conv" in key.lower() and "weight" in key and value.ndim == 4: - value = mx.transpose(value, (0, 2, 3, 1)) - - sanitized[new_key] = value - - return sanitized - - -def load_config(model_path: Path) -> Dict[str, Any]: - """Load model configuration. - - Args: - model_path: Path to model directory - - Returns: - Configuration dictionary - """ - config_path = model_path / "config.json" - - if config_path.exists(): - with open(config_path, "r") as f: - return json.load(f) - - # Return default config - return {} - - -def create_model_from_config(config: Dict[str, Any]) -> LTXModel: - """Create model instance from configuration. - - Args: - config: Configuration dictionary - - Returns: - LTXModel instance - """ - # Map config to LTXModelConfig - model_config = LTXModelConfig( - model_type=LTXModelType.AudioVideo, - num_attention_heads=config.get("num_attention_heads", 32), - attention_head_dim=config.get("attention_head_dim", 128), - in_channels=config.get("in_channels", 128), - out_channels=config.get("out_channels", 128), - num_layers=config.get("num_layers", 48), - cross_attention_dim=config.get("cross_attention_dim", 4096), - caption_channels=config.get("caption_channels", 3840), - audio_num_attention_heads=config.get("audio_num_attention_heads", 32), - audio_attention_head_dim=config.get("audio_attention_head_dim", 64), - audio_in_channels=config.get("audio_in_channels", 128), - audio_out_channels=config.get("audio_out_channels", 128), - audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048), - positional_embedding_theta=config.get("positional_embedding_theta", 10000.0), - positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]), - audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]), - timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000), - av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1000), - norm_eps=config.get("norm_eps", 1e-6), - ) - - return LTXModel(model_config) - - -def convert( - hf_path: str, - mlx_path: str = "mlx_model", - dtype: Optional[str] = None, - quantize: bool = False, - q_bits: int = 4, - q_group_size: int = 64, -) -> Path: - """Convert HuggingFace model to MLX format. - - Args: - hf_path: HuggingFace model path or repo ID - mlx_path: Output path for MLX model - dtype: Target dtype (float16, float32, bfloat16) - quantize: Whether to quantize the model - q_bits: Quantization bits - q_group_size: Quantization group size - - Returns: - Path to converted model - """ - print(f"Loading model from {hf_path}...") - model_path = get_model_path(hf_path) - - # Load config - config = load_config(model_path) - - # Load weights - print("Loading weights...") - weights = load_safetensors(model_path) - - # Sanitize weights - print("Sanitizing weights...") - weights = sanitize_weights(weights) - - # Convert dtype if specified - if dtype is not None: - dtype_map = { - "float16": mx.float16, - "float32": mx.float32, - "bfloat16": mx.bfloat16, - } - target_dtype = dtype_map.get(dtype, mx.float16) - print(f"Converting to {dtype}...") - weights = { - k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v - for k, v in weights.items() - } - - # Create output directory - output_path = Path(mlx_path) - output_path.mkdir(parents=True, exist_ok=True) - - # Save weights - print(f"Saving weights to {output_path}...") - save_weights(output_path, weights) - - # Save config - config_out_path = output_path / "config.json" - with open(config_out_path, "w") as f: - json.dump(config, f, indent=2) - - print(f"Model converted successfully to {output_path}") - return output_path - - -def save_weights(path: Path, weights: Dict[str, mx.array]) -> None: - """Save weights in safetensors format. - - Uses mx.save_safetensors to preserve exact dtype (especially bfloat16). - Converting through numpy loses bfloat16 fidelity since numpy lacks native - bfloat16 support. - - Args: - path: Output directory - weights: Dictionary of weights - """ - mx.save_safetensors(str(path / "model.safetensors"), weights) - - -def convert_audio_encoder( - model_path: Union[str, Path], - source_repo: str = "Lightricks/LTX-2", -) -> Path: - """Convert and save audio encoder weights from original HF checkpoint. - - The audio VAE safetensors in the HF repo contains both encoder and decoder - weights. This extracts encoder weights, transposes Conv2d for MLX, and saves - them to a separate directory for AudioEncoder.from_pretrained(). - - Args: - model_path: Local model directory (output location). - source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors. - - Returns: - Path to the audio_vae/encoder directory. - """ - model_path = Path(model_path) - encoder_dir = model_path / "audio_vae" / "encoder" - - if (encoder_dir / "model.safetensors").exists(): - return encoder_dir - - # Download original audio VAE weights - from huggingface_hub import hf_hub_download - vae_path = hf_hub_download( - source_repo, - "audio_vae/diffusion_pytorch_model.safetensors", - ) - - raw_weights = mx.load(vae_path) - - # Extract encoder weights and per-channel statistics - from mlx_video.models.ltx_2.audio_vae import AudioEncoder - from mlx_video.models.ltx_2.config import AudioEncoderModelConfig - - # Build config from the decoder config (same audio VAE architecture) - decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json" - if decoder_config_path.exists(): - with open(decoder_config_path) as f: - dec_cfg = json.load(f) - enc_config = { - "ch": dec_cfg.get("ch", 128), - "in_channels": dec_cfg.get("out_ch", 2), - "ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]), - "num_res_blocks": dec_cfg.get("num_res_blocks", 2), - "attn_resolutions": dec_cfg.get("attn_resolutions", []), - "resolution": dec_cfg.get("resolution", 256), - "z_channels": dec_cfg.get("z_channels", 8), - "double_z": True, - "n_fft": 1024, - "norm_type": dec_cfg.get("norm_type", "pixel"), - "causality_axis": dec_cfg.get("causality_axis", "height"), - "dropout": dec_cfg.get("dropout", 0.0), - "mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False), - "sample_rate": dec_cfg.get("sample_rate", 16000), - "mel_hop_length": dec_cfg.get("mel_hop_length", 160), - "is_causal": dec_cfg.get("is_causal", True), - "mel_bins": dec_cfg.get("mel_bins", 64) or 64, - "resamp_with_conv": dec_cfg.get("resamp_with_conv", True), - "attn_type": dec_cfg.get("attn_type", "vanilla"), - } - else: - enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64} - - # Sanitize weights - config = AudioEncoderModelConfig.from_dict(enc_config) - encoder = AudioEncoder(config) - sanitized = encoder.sanitize(raw_weights) - - # Save - encoder_dir.mkdir(parents=True, exist_ok=True) - mx.save_safetensors(str(encoder_dir / "model.safetensors"), sanitized) - with open(encoder_dir / "config.json", "w") as f: - json.dump(enc_config, f, indent=2) - - print(f"Audio encoder weights saved to {encoder_dir}") - return encoder_dir - - -def load_model( - path_or_hf_repo: str, - lazy: bool = False, -) -> LTXModel: - """Load LTX model from path or HuggingFace. - - Args: - path_or_hf_repo: Path to model or HuggingFace repo ID - lazy: Whether to use lazy loading - - Returns: - Loaded LTXModel - """ - model_path = get_model_path(path_or_hf_repo) - - # Load config - config = load_config(model_path) - - # Create model - model = create_model_from_config(config) - - # Load weights - weights = load_safetensors(model_path) - - # Sanitize if needed - weights = sanitize_weights(weights) - - # Load weights into model - model.load_weights(list(weights.items())) - - if not lazy: - mx.eval(model.parameters()) - - return model - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format") - parser.add_argument( - "--hf-path", - type=str, - default="Lightricks/LTX-2", - help="HuggingFace model path or repo ID", - ) - parser.add_argument( - "--mlx-path", - type=str, - default="mlx_model", - help="Output path for MLX model", - ) - parser.add_argument( - "--dtype", - type=str, - choices=["float16", "float32", "bfloat16"], - default="float16", - help="Target dtype", - ) - parser.add_argument( - "--quantize", - action="store_true", - help="Quantize the model", - ) - parser.add_argument( - "--q-bits", - type=int, - default=4, - help="Quantization bits", - ) - - args = parser.parse_args() - - convert( - hf_path=args.hf_path, - mlx_path=args.mlx_path, - dtype=args.dtype, - quantize=args.quantize, - q_bits=args.q_bits, - )