Add audio to video conditioning
This commit is contained in:
@@ -606,6 +606,86 @@ def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
|
||||
mx.save_safetensors(str(path / "model.safetensors"), weights)
|
||||
|
||||
|
||||
def _encoder_config_from_decoder(dec_cfg: dict) -> dict:
    """Build the audio encoder config dict from a parsed decoder config dict.

    Keys missing from ``dec_cfg`` fall back to the defaults listed below.
    """
    return {
        "ch": dec_cfg.get("ch", 128),
        # The encoder's input channel count is read from the decoder's
        # output channel count ("out_ch").
        "in_channels": dec_cfg.get("out_ch", 2),
        "ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
        "num_res_blocks": dec_cfg.get("num_res_blocks", 2),
        "attn_resolutions": dec_cfg.get("attn_resolutions", []),
        "resolution": dec_cfg.get("resolution", 256),
        "z_channels": dec_cfg.get("z_channels", 8),
        # Not taken from the decoder config: always fixed for the encoder.
        "double_z": True,
        "n_fft": 1024,
        "norm_type": dec_cfg.get("norm_type", "pixel"),
        "causality_axis": dec_cfg.get("causality_axis", "height"),
        "dropout": dec_cfg.get("dropout", 0.0),
        "mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
        "sample_rate": dec_cfg.get("sample_rate", 16000),
        "mel_hop_length": dec_cfg.get("mel_hop_length", 160),
        "is_causal": dec_cfg.get("is_causal", True),
        # "or 64" also covers a key that is present but null/0 in the JSON.
        "mel_bins": dec_cfg.get("mel_bins", 64) or 64,
        "resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
        "attn_type": dec_cfg.get("attn_type", "vanilla"),
    }


def convert_audio_encoder(
    model_path: Union[str, Path],
    source_repo: str = "Lightricks/LTX-2",
) -> Path:
    """Convert and save audio encoder weights from the original HF checkpoint.

    The audio VAE safetensors file in the HF repo contains both encoder and
    decoder weights. The raw weights are passed through
    ``AudioEncoder.sanitize`` and the result is saved, together with a
    ``config.json``, to ``<model_path>/audio_vae_encoder`` for
    ``AudioEncoder.from_pretrained()``.

    If ``<model_path>/audio_vae_encoder/model.safetensors`` already exists the
    function returns immediately without downloading or converting anything.

    Args:
        model_path: Local model directory (output location).
        source_repo: HF repo containing
            audio_vae/diffusion_pytorch_model.safetensors.

    Returns:
        Path to the audio_vae_encoder directory.
    """
    model_path = Path(model_path)
    encoder_dir = model_path / "audio_vae_encoder"
    weights_path = encoder_dir / "model.safetensors"

    # The weights file doubles as the "conversion finished" marker.
    if weights_path.exists():
        return encoder_dir

    # Download original audio VAE weights
    from huggingface_hub import hf_hub_download

    vae_path = hf_hub_download(
        source_repo,
        "audio_vae/diffusion_pytorch_model.safetensors",
    )
    raw_weights = mx.load(vae_path)

    # Imported lazily, matching the function-scope import style used above.
    from mlx_video.models.ltx.audio_vae import AudioEncoder
    from mlx_video.models.ltx.config import AudioEncoderModelConfig

    # Build config from the decoder config (same audio VAE architecture)
    decoder_config_path = model_path / "audio_vae" / "config.json"
    if decoder_config_path.exists():
        with open(decoder_config_path, encoding="utf-8") as f:
            enc_config = _encoder_config_from_decoder(json.load(f))
    else:
        # No local decoder config: specify only these keys and leave the rest
        # to AudioEncoderModelConfig.from_dict.
        enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}

    # Sanitize weights
    config = AudioEncoderModelConfig.from_dict(enc_config)
    encoder = AudioEncoder(config)
    sanitized = encoder.sanitize(raw_weights)

    # Save. The config is written first and the weights are written to a
    # temporary name and then renamed, so model.safetensors (the cache marker
    # checked above) only ever appears once the directory is complete.
    encoder_dir.mkdir(parents=True, exist_ok=True)
    with open(encoder_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(enc_config, f, indent=2)

    # The temporary name keeps the ".safetensors" suffix so the saver does not
    # append one and change the path being renamed.
    partial_path = encoder_dir / "model.partial.safetensors"
    mx.save_safetensors(str(partial_path), sanitized)
    partial_path.replace(weights_path)

    print(f"Audio encoder weights saved to {encoder_dir}")
    return encoder_dir
|
||||
|
||||
|
||||
def load_model(
|
||||
path_or_hf_repo: str,
|
||||
lazy: bool = False,
|
||||
|
||||
Reference in New Issue
Block a user