feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -316,6 +316,14 @@ def convert_wan_checkpoint(
|
||||
def _detect_config():
|
||||
"""Detect config from source config.json or transformer weight shapes."""
|
||||
if is_dual:
|
||||
# Check source config.json for model_type (I2V vs T2V)
|
||||
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
|
||||
if src_cfg_path.exists():
|
||||
with open(src_cfg_path) as f:
|
||||
src_config = json.load(f)
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
|
||||
return WanModelConfig.wan22_i2v_14b()
|
||||
return WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
# Try reading source config.json first (most reliable)
|
||||
@@ -413,7 +421,7 @@ def convert_wan_checkpoint(
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
include_encoder = config.model_type == "ti2v"
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user