feat(wan): Add DPM++ 2M and UniPC schedulers

This commit is contained in:
Daniel
2026-02-27 10:28:33 +01:00
parent e64483a66a
commit 93da550f65
8 changed files with 1792 additions and 89 deletions

View File

@@ -412,10 +412,13 @@ def convert_wan_checkpoint(
weights = sanitize_wan22_vae_weights(weights)
else:
weights = sanitize_wan_vae_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
# that cannot be recovered by upcasting at load time.
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
out_path = output_dir / "vae.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
# Quantize transformer weights if requested
if quantize: