import json import shutil from pathlib import Path from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download from mlx_video.models.ltx_2.config import LTXModelConfig, LTXModelType from mlx_video.models.ltx_2.ltx import LTXModel def get_model_path( path_or_hf_repo: str, revision: Optional[str] = None, ) -> Path: """Get local path to model, downloading if necessary. Args: path_or_hf_repo: Local path or HuggingFace repo ID revision: Git revision for HF repo Returns: Path to model directory """ model_path = Path(path_or_hf_repo) if model_path.exists(): return model_path # Download from HuggingFace model_path = Path( snapshot_download( repo_id=path_or_hf_repo, revision=revision, allow_patterns=[ "*.safetensors", "*.json", "config.json", ], ) ) return model_path def load_safetensors(path: Path) -> Dict[str, mx.array]: """Load weights from safetensors file(s) using MLX. Args: path: Path to model directory or single safetensors file Returns: Dictionary of weights """ weights = {} if path.is_file(): # Single file - use mx.load directly (handles bfloat16) return mx.load(str(path)) else: # Directory - load all safetensors files safetensor_files = list(path.glob("*.safetensors")) for sf_path in safetensor_files: file_weights = mx.load(str(sf_path)) weights.update(file_weights) return weights def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]: """Load transformer weights from LTX-2 model. Args: model_path: Path to LTX-2 model directory Returns: Dictionary of transformer weights """ # Try distilled model first, then dev weight_files = [ model_path / "ltx-2-19b-distilled.safetensors", model_path / "ltx-2-19b-dev.safetensors", ] for weight_file in weight_files: if weight_file.exists(): print(f"Loading transformer weights from {weight_file.name}...") return mx.load(str(weight_file)) raise FileNotFoundError(f"No transformer weights found in {model_path}") def load_vae_weights(model_path: Path) -> Dict[str, mx.array]: """Load VAE weights from LTX-2 model. Args: model_path: Path to LTX-2 model directory Returns: Dictionary of VAE weights """ vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors" if vae_path.exists(): print(f"Loading VAE weights from {vae_path}...") return mx.load(str(vae_path)) raise FileNotFoundError(f"VAE weights not found at {vae_path}") def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]: """Load audio VAE weights from LTX-2 model. Args: model_path: Path to LTX-2 model directory Returns: Dictionary of audio VAE weights """ # Try different possible paths for audio VAE weights audio_vae_paths = [ model_path / "audio_vae" / "diffusion_pytorch_model.safetensors", model_path / "audio_vae.safetensors", ] # Also check in main model weights main_paths = [ model_path / "ltx-2-19b-distilled.safetensors", model_path / "ltx-2-19b-dev.safetensors", ] for audio_path in audio_vae_paths: if audio_path.exists(): print(f"Loading audio VAE weights from {audio_path}...") return mx.load(str(audio_path)) # Check main model weights for audio_vae keys for main_path in main_paths: if main_path.exists(): print(f"Loading audio VAE weights from {main_path.name}...") all_weights = mx.load(str(main_path)) # Filter to only audio_vae keys audio_weights = {k: v for k, v in all_weights.items() if "audio_vae" in k} if audio_weights: return audio_weights raise FileNotFoundError(f"Audio VAE weights not found in {model_path}") def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]: """Load vocoder weights from LTX-2 model. Args: model_path: Path to LTX-2 model directory Returns: Dictionary of vocoder weights """ # Try different possible paths for vocoder weights vocoder_paths = [ model_path / "vocoder" / "diffusion_pytorch_model.safetensors", model_path / "vocoder.safetensors", ] # Also check in main model weights main_paths = [ model_path / "ltx-2-19b-distilled.safetensors", model_path / "ltx-2-19b-dev.safetensors", ] for vocoder_path in vocoder_paths: if vocoder_path.exists(): print(f"Loading vocoder weights from {vocoder_path}...") return mx.load(str(vocoder_path)) # Check main model weights for vocoder keys for main_path in main_paths: if main_path.exists(): print(f"Loading vocoder weights from {main_path.name}...") all_weights = mx.load(str(main_path)) # Filter to only vocoder keys vocoder_weights = {k: v for k, v in all_weights.items() if "vocoder" in k} if vocoder_weights: return vocoder_weights raise FileNotFoundError(f"Vocoder weights not found in {model_path}") def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize transformer weight names from PyTorch LTX-2 format to MLX format. Args: weights: Dictionary of weights with PyTorch naming Returns: Dictionary with MLX-compatible naming for transformer """ sanitized = {} for key, value in weights.items(): new_key = key # Skip non-transformer weights (VAE, vocoder, audio_vae, connectors) if not key.startswith("model.diffusion_model."): continue # Remove 'model.diffusion_model.' prefix new_key = key.replace("model.diffusion_model.", "") # Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering) new_key = new_key.replace(".to_out.0.", ".to_out.") # Handle feed-forward net naming # PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar) # MLX FeedForward: uses proj_in, proj_out new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") # Handle AdaLN naming - keep emb wrapper, just fix linear naming # PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1 new_key = new_key.replace(".linear_1.", ".linear1.") new_key = new_key.replace(".linear_2.", ".linear2.") # Handle caption projection (keep linear1/linear2 naming for compatibility) # These are already mapped correctly in the sanitization sanitized[new_key] = value return sanitized def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize VAE weight names from PyTorch format to MLX format. Args: weights: Dictionary of weights with PyTorch naming Returns: Dictionary with MLX-compatible naming for VAE decoder """ sanitized = {} for key, value in weights.items(): new_key = key # Skip position_ids (not needed) if "position_ids" in key: continue # Only process VAE decoder weights (skip audio_vae, etc.) if not key.startswith("vae."): continue # Handle per-channel statistics key mapping # PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean # PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std # Be careful: mean-of-stds_over_std-of-means also ends with std-of-means if "vae.per_channel_statistics" in key: if key == "vae.per_channel_statistics.mean-of-means": new_key = "per_channel_statistics.mean" elif key == "vae.per_channel_statistics.std-of-means": new_key = "per_channel_statistics.std" else: # Skip other per_channel_statistics keys (channel, mean-of-stds, etc.) continue elif key.startswith("vae.decoder."): # Strip the vae.decoder. prefix for decoder weights new_key = key.replace("vae.decoder.", "") else: # Skip other vae.* keys that are not decoder weights continue # Handle Conv3d weight shape conversion # PyTorch: (out_channels, in_channels, D, H, W) # MLX: (out_channels, D, H, W, in_channels) if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5: # Transpose from (O, I, D, H, W) to (O, D, H, W, I) value = mx.transpose(value, (0, 2, 3, 4, 1)) # Handle Conv2d weight shape conversion # PyTorch: (out_channels, in_channels, H, W) # MLX: (out_channels, H, W, in_channels) if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: value = mx.transpose(value, (0, 2, 3, 1)) sanitized[new_key] = value return sanitized def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize VAE encoder weight names from PyTorch format to MLX format. Args: weights: Dictionary of weights with PyTorch naming Returns: Dictionary with MLX-compatible naming for VAE encoder """ sanitized = {} for key, value in weights.items(): new_key = key # Skip position_ids (not needed) if "position_ids" in key: continue # Only process VAE encoder weights if not key.startswith("vae."): continue # Handle per-channel statistics key mapping if "vae.per_channel_statistics" in key: if key == "vae.per_channel_statistics.mean-of-means": new_key = "per_channel_statistics._mean_of_means" elif key == "vae.per_channel_statistics.std-of-means": new_key = "per_channel_statistics._std_of_means" else: # Skip other per_channel_statistics keys continue elif key.startswith("vae.encoder."): # Strip the vae.encoder. prefix for encoder weights new_key = key.replace("vae.encoder.", "") else: # Skip other vae.* keys that are not encoder weights continue # Handle Conv3d weight shape conversion # PyTorch: (out_channels, in_channels, D, H, W) # MLX: (out_channels, D, H, W, in_channels) if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5: value = mx.transpose(value, (0, 2, 3, 4, 1)) # Handle Conv2d weight shape conversion if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: value = mx.transpose(value, (0, 2, 3, 1)) sanitized[new_key] = value return sanitized def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize audio VAE weight names from PyTorch format to MLX format. Args: weights: Dictionary of weights with PyTorch naming Returns: Dictionary with MLX-compatible naming for audio VAE decoder """ sanitized = {} if "audio_vae." in weights: return weights for key, value in weights.items(): new_key = key # Handle audio_vae.decoder weights if key.startswith("audio_vae.decoder."): new_key = key.replace("audio_vae.decoder.", "") elif key.startswith("audio_vae.per_channel_statistics."): # Map per-channel statistics if "mean-of-means" in key: new_key = "per_channel_statistics.mean_of_means" elif "std-of-means" in key: new_key = "per_channel_statistics.std_of_means" else: continue # Skip other statistics keys else: continue # Skip non-decoder keys # Handle Conv2d weight shape conversion # PyTorch: (out_channels, in_channels, H, W) # MLX: (out_channels, H, W, in_channels) if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: value = mx.transpose(value, (0, 2, 3, 1)) sanitized[new_key] = value return sanitized def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize vocoder weight names from PyTorch format to MLX format. Args: weights: Dictionary of weights with PyTorch naming Returns: Dictionary with MLX-compatible naming for vocoder """ sanitized = {} for key, value in weights.items(): new_key = key # Handle vocoder weights if key.startswith("vocoder."): new_key = key.replace("vocoder.", "") # Handle ModuleList indices -> dict keys # PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ... # PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ... # Handle Conv1d weight shape conversion # PyTorch: (out_channels, in_channels, kernel) # MLX: (out_channels, kernel, in_channels) if "weight" in new_key and value.ndim == 3: if "ups" in new_key: # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch) value = mx.transpose(value, (1, 2, 0)) else: # Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch) value = mx.transpose(value, (0, 2, 1)) sanitized[new_key] = value return sanitized def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize weight names from PyTorch format to MLX format. Generic function that handles both transformer and VAE weights. Args: weights: Dictionary of weights with PyTorch naming Returns: Dictionary with MLX-compatible naming """ sanitized = {} for key, value in weights.items(): new_key = key # Skip position_ids (not needed) if "position_ids" in key: continue # Handle transformer weights if key.startswith("model.diffusion_model."): new_key = key.replace("model.diffusion_model.", "") new_key = new_key.replace(".to_out.0.", ".to_out.") new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") new_key = new_key.replace(".linear_1.", ".linear1.") new_key = new_key.replace(".linear_2.", ".linear2.") # Handle Conv3d weight shape conversion # PyTorch: (out_channels, in_channels, D, H, W) # MLX: (out_channels, D, H, W, in_channels) if "conv" in key.lower() and "weight" in key and value.ndim == 5: value = mx.transpose(value, (0, 2, 3, 4, 1)) # Handle Conv2d weight shape conversion # PyTorch: (out_channels, in_channels, H, W) # MLX: (out_channels, H, W, in_channels) if "conv" in key.lower() and "weight" in key and value.ndim == 4: value = mx.transpose(value, (0, 2, 3, 1)) sanitized[new_key] = value return sanitized def load_config(model_path: Path) -> Dict[str, Any]: """Load model configuration. Args: model_path: Path to model directory Returns: Configuration dictionary """ config_path = model_path / "config.json" if config_path.exists(): with open(config_path, "r") as f: return json.load(f) # Return default config return {} def create_model_from_config(config: Dict[str, Any]) -> LTXModel: """Create model instance from configuration. Args: config: Configuration dictionary Returns: LTXModel instance """ # Map config to LTXModelConfig model_config = LTXModelConfig( model_type=LTXModelType.AudioVideo, num_attention_heads=config.get("num_attention_heads", 32), attention_head_dim=config.get("attention_head_dim", 128), in_channels=config.get("in_channels", 128), out_channels=config.get("out_channels", 128), num_layers=config.get("num_layers", 48), cross_attention_dim=config.get("cross_attention_dim", 4096), caption_channels=config.get("caption_channels", 3840), audio_num_attention_heads=config.get("audio_num_attention_heads", 32), audio_attention_head_dim=config.get("audio_attention_head_dim", 64), audio_in_channels=config.get("audio_in_channels", 128), audio_out_channels=config.get("audio_out_channels", 128), audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048), positional_embedding_theta=config.get("positional_embedding_theta", 10000.0), positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]), audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]), timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000), av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1000), norm_eps=config.get("norm_eps", 1e-6), ) return LTXModel(model_config) def convert( hf_path: str, mlx_path: str = "mlx_model", dtype: Optional[str] = None, quantize: bool = False, q_bits: int = 4, q_group_size: int = 64, ) -> Path: """Convert HuggingFace model to MLX format. Args: hf_path: HuggingFace model path or repo ID mlx_path: Output path for MLX model dtype: Target dtype (float16, float32, bfloat16) quantize: Whether to quantize the model q_bits: Quantization bits q_group_size: Quantization group size Returns: Path to converted model """ print(f"Loading model from {hf_path}...") model_path = get_model_path(hf_path) # Load config config = load_config(model_path) # Load weights print("Loading weights...") weights = load_safetensors(model_path) # Sanitize weights print("Sanitizing weights...") weights = sanitize_weights(weights) # Convert dtype if specified if dtype is not None: dtype_map = { "float16": mx.float16, "float32": mx.float32, "bfloat16": mx.bfloat16, } target_dtype = dtype_map.get(dtype, mx.float16) print(f"Converting to {dtype}...") weights = { k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v for k, v in weights.items() } # Create output directory output_path = Path(mlx_path) output_path.mkdir(parents=True, exist_ok=True) # Save weights print(f"Saving weights to {output_path}...") save_weights(output_path, weights) # Save config config_out_path = output_path / "config.json" with open(config_out_path, "w") as f: json.dump(config, f, indent=2) print(f"Model converted successfully to {output_path}") return output_path def save_weights(path: Path, weights: Dict[str, mx.array]) -> None: """Save weights in safetensors format. Uses mx.save_safetensors to preserve exact dtype (especially bfloat16). Converting through numpy loses bfloat16 fidelity since numpy lacks native bfloat16 support. Args: path: Output directory weights: Dictionary of weights """ mx.save_safetensors(str(path / "model.safetensors"), weights) def convert_audio_encoder( model_path: Union[str, Path], source_repo: str = "Lightricks/LTX-2", ) -> Path: """Convert and save audio encoder weights from original HF checkpoint. The audio VAE safetensors in the HF repo contains both encoder and decoder weights. This extracts encoder weights, transposes Conv2d for MLX, and saves them to a separate directory for AudioEncoder.from_pretrained(). Args: model_path: Local model directory (output location). source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors. Returns: Path to the audio_vae_encoder directory. """ model_path = Path(model_path) encoder_dir = model_path / "audio_vae_encoder" if (encoder_dir / "model.safetensors").exists(): return encoder_dir # Download original audio VAE weights from huggingface_hub import hf_hub_download vae_path = hf_hub_download( source_repo, "audio_vae/diffusion_pytorch_model.safetensors", ) raw_weights = mx.load(vae_path) # Extract encoder weights and per-channel statistics from mlx_video.models.ltx_2.audio_vae import AudioEncoder from mlx_video.models.ltx_2.config import AudioEncoderModelConfig # Build config from the decoder config (same audio VAE architecture) decoder_config_path = model_path / "audio_vae" / "config.json" if decoder_config_path.exists(): with open(decoder_config_path) as f: dec_cfg = json.load(f) enc_config = { "ch": dec_cfg.get("ch", 128), "in_channels": dec_cfg.get("out_ch", 2), "ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]), "num_res_blocks": dec_cfg.get("num_res_blocks", 2), "attn_resolutions": dec_cfg.get("attn_resolutions", []), "resolution": dec_cfg.get("resolution", 256), "z_channels": dec_cfg.get("z_channels", 8), "double_z": True, "n_fft": 1024, "norm_type": dec_cfg.get("norm_type", "pixel"), "causality_axis": dec_cfg.get("causality_axis", "height"), "dropout": dec_cfg.get("dropout", 0.0), "mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False), "sample_rate": dec_cfg.get("sample_rate", 16000), "mel_hop_length": dec_cfg.get("mel_hop_length", 160), "is_causal": dec_cfg.get("is_causal", True), "mel_bins": dec_cfg.get("mel_bins", 64) or 64, "resamp_with_conv": dec_cfg.get("resamp_with_conv", True), "attn_type": dec_cfg.get("attn_type", "vanilla"), } else: enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64} # Sanitize weights config = AudioEncoderModelConfig.from_dict(enc_config) encoder = AudioEncoder(config) sanitized = encoder.sanitize(raw_weights) # Save encoder_dir.mkdir(parents=True, exist_ok=True) mx.save_safetensors(str(encoder_dir / "model.safetensors"), sanitized) with open(encoder_dir / "config.json", "w") as f: json.dump(enc_config, f, indent=2) print(f"Audio encoder weights saved to {encoder_dir}") return encoder_dir def load_model( path_or_hf_repo: str, lazy: bool = False, ) -> LTXModel: """Load LTX model from path or HuggingFace. Args: path_or_hf_repo: Path to model or HuggingFace repo ID lazy: Whether to use lazy loading Returns: Loaded LTXModel """ model_path = get_model_path(path_or_hf_repo) # Load config config = load_config(model_path) # Create model model = create_model_from_config(config) # Load weights weights = load_safetensors(model_path) # Sanitize if needed weights = sanitize_weights(weights) # Load weights into model model.load_weights(list(weights.items())) if not lazy: mx.eval(model.parameters()) return model if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format") parser.add_argument( "--hf-path", type=str, default="Lightricks/LTX-2", help="HuggingFace model path or repo ID", ) parser.add_argument( "--mlx-path", type=str, default="mlx_model", help="Output path for MLX model", ) parser.add_argument( "--dtype", type=str, choices=["float16", "float32", "bfloat16"], default="float16", help="Target dtype", ) parser.add_argument( "--quantize", action="store_true", help="Quantize the model", ) parser.add_argument( "--q-bits", type=int, default=4, help="Quantization bits", ) args = parser.parse_args() convert( hf_path=args.hf_path, mlx_path=args.mlx_path, dtype=args.dtype, quantize=args.quantize, q_bits=args.q_bits, )