Add audio to video conditioning

This commit is contained in:
Prince Canuma
2026-03-16 01:42:11 +01:00
parent f53b9e0807
commit 6f6105b715
7 changed files with 623 additions and 62 deletions

View File

@@ -606,6 +606,86 @@ def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
mx.save_safetensors(str(path / "model.safetensors"), weights)
def convert_audio_encoder(
model_path: Union[str, Path],
source_repo: str = "Lightricks/LTX-2",
) -> Path:
"""Convert and save audio encoder weights from original HF checkpoint.
The audio VAE safetensors in the HF repo contains both encoder and decoder
weights. This extracts encoder weights, transposes Conv2d for MLX, and saves
them to a separate directory for AudioEncoder.from_pretrained().
Args:
model_path: Local model directory (output location).
source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.
Returns:
Path to the audio_vae_encoder directory.
"""
model_path = Path(model_path)
encoder_dir = model_path / "audio_vae_encoder"
if (encoder_dir / "model.safetensors").exists():
return encoder_dir
# Download original audio VAE weights
from huggingface_hub import hf_hub_download
vae_path = hf_hub_download(
source_repo,
"audio_vae/diffusion_pytorch_model.safetensors",
)
raw_weights = mx.load(vae_path)
# Extract encoder weights and per-channel statistics
from mlx_video.models.ltx.audio_vae import AudioEncoder
from mlx_video.models.ltx.config import AudioEncoderModelConfig
# Build config from the decoder config (same audio VAE architecture)
decoder_config_path = model_path / "audio_vae" / "config.json"
if decoder_config_path.exists():
with open(decoder_config_path) as f:
dec_cfg = json.load(f)
enc_config = {
"ch": dec_cfg.get("ch", 128),
"in_channels": dec_cfg.get("out_ch", 2),
"ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
"num_res_blocks": dec_cfg.get("num_res_blocks", 2),
"attn_resolutions": dec_cfg.get("attn_resolutions", []),
"resolution": dec_cfg.get("resolution", 256),
"z_channels": dec_cfg.get("z_channels", 8),
"double_z": True,
"n_fft": 1024,
"norm_type": dec_cfg.get("norm_type", "pixel"),
"causality_axis": dec_cfg.get("causality_axis", "height"),
"dropout": dec_cfg.get("dropout", 0.0),
"mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
"sample_rate": dec_cfg.get("sample_rate", 16000),
"mel_hop_length": dec_cfg.get("mel_hop_length", 160),
"is_causal": dec_cfg.get("is_causal", True),
"mel_bins": dec_cfg.get("mel_bins", 64) or 64,
"resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
"attn_type": dec_cfg.get("attn_type", "vanilla"),
}
else:
enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}
# Sanitize weights
config = AudioEncoderModelConfig.from_dict(enc_config)
encoder = AudioEncoder(config)
sanitized = encoder.sanitize(raw_weights)
# Save
encoder_dir.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(encoder_dir / "model.safetensors"), sanitized)
with open(encoder_dir / "config.json", "w") as f:
json.dump(enc_config, f, indent=2)
print(f"Audio encoder weights saved to {encoder_dir}")
return encoder_dir
def load_model(
path_or_hf_repo: str,
lazy: bool = False,