feat(wan): Add Wan2.2 I2V support

This commit is contained in:
Daniel
2026-02-27 13:46:23 +01:00
parent 93da550f65
commit 2bb95c61ed
26 changed files with 4401 additions and 2968 deletions

View File

@@ -338,6 +338,10 @@ def convert_wan_checkpoint(
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
# Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072:
return WanModelConfig.wan22_ti2v_5b()
is_22 = model_version == "2.2"
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
@@ -409,7 +413,8 @@ def convert_wan_checkpoint(
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = sanitize_wan22_vae_weights(weights)
include_encoder = config.model_type == "ti2v"
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
else:
weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in