feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -87,16 +87,23 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
def load_vae_encoder(model_path: Path, config=None):
|
||||
"""Load VAE encoder for I2V image encoding.
|
||||
|
||||
Only supports Wan2.2 (vae_z_dim=48).
|
||||
For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder.
|
||||
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
|
||||
"""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
if config is not None and config.vae_z_dim == 16:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
else:
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
|
||||
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
|
||||
|
||||
encoder = Wan22VAEEncoder(z_dim=config.vae_z_dim)
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
encoder.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(encoder.parameters())
|
||||
return encoder
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user