feat(wan): Add I2V-14B dual-model support

This commit is contained in:
Daniel
2026-02-27 23:43:42 +01:00
parent 2bb95c61ed
commit f4195f0118
14 changed files with 1332 additions and 152 deletions

View File

@@ -316,6 +316,14 @@ def convert_wan_checkpoint(
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
if is_dual:
# Check source config.json for model_type (I2V vs T2V)
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
src_model_type = src_config.get("model_type", "t2v")
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
return WanModelConfig.wan22_i2v_14b()
return WanModelConfig.wan22_t2v_14b()
# Try reading source config.json first (most reliable)
@@ -413,7 +421,7 @@ def convert_wan_checkpoint(
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type == "ti2v"
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
else:
weights = sanitize_wan_vae_weights(weights)