feat(wan): Add I2V-14B dual-model support

This commit is contained in:
Daniel
2026-02-27 23:43:42 +01:00
parent 2bb95c61ed
commit f4195f0118
14 changed files with 1332 additions and 152 deletions

View File

@@ -87,16 +87,23 @@ def load_vae_decoder(model_path: Path, config=None):
def load_vae_encoder(model_path: Path, config=None):
    """Load the VAE encoder used for I2V image encoding.

    The encoder class is selected from ``config.vae_z_dim``:

    * ``vae_z_dim == 16`` (Wan2.1 / I2V-14B): ``WanVAE`` built with
      ``encoder=True``.
    * anything else (Wan2.2 TI2V, ``vae_z_dim == 48``): ``Wan22VAEEncoder``.
      When ``config`` is not given, ``z_dim`` defaults to 48.

    Args:
        model_path: Path of the weights file; passed to ``mx.load`` as a string.
        config: Optional object exposing a ``vae_z_dim`` attribute.

    Returns:
        The constructed VAE module with the weights loaded.
    """
    # Imports stay local so that only the selected VAE implementation is
    # imported for a given call.
    if config is not None and config.vae_z_dim == 16:
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16, encoder=True)
    else:
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder

        vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)

    weights = mx.load(str(model_path))
    # Every tensor is cast to float32 before being handed to the module.
    weights = {k: v.astype(mx.float32) for k, v in weights.items()}
    # NOTE(review): strict=False presumably tolerates checkpoint keys that do
    # not belong to the encoder (e.g. decoder weights) — confirm against the
    # checkpoint layout, since it can also hide genuinely missing keys.
    vae.load_weights(list(weights.items()), strict=False)
    mx.eval(vae.parameters())
    return vae
def _clean_text(text: str) -> str: