feat(wan): Add Wan2.2 I2V support
This commit is contained in:
@@ -338,6 +338,10 @@ def convert_wan_checkpoint(
|
||||
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}")
|
||||
|
||||
# Use preset for known TI2V 5B configuration
|
||||
if src_model_type == "ti2v" and src_dim == 3072:
|
||||
return WanModelConfig.wan22_ti2v_5b()
|
||||
|
||||
is_22 = model_version == "2.2"
|
||||
|
||||
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
|
||||
@@ -409,7 +413,8 @@ def convert_wan_checkpoint(
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
weights = sanitize_wan22_vae_weights(weights)
|
||||
include_encoder = config.model_type == "ti2v"
|
||||
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
|
||||
Reference in New Issue
Block a user