feat(wan): Add DPM++ 2M and UniPC schedulers
This commit is contained in:
@@ -412,10 +412,13 @@ def convert_wan_checkpoint(
|
||||
weights = sanitize_wan22_vae_weights(weights)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
|
||||
# that cannot be recovered by upcasting at load time.
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
out_path = output_dir / "vae.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
|
||||
|
||||
# Quantize transformer weights if requested
|
||||
if quantize:
|
||||
|
||||
Reference in New Issue
Block a user