feat(wan): Add DPM++ 2M and UniPC schedulers

This commit is contained in:
Daniel
2026-02-27 10:28:33 +01:00
parent e64483a66a
commit 93da550f65
8 changed files with 1792 additions and 89 deletions

View File

@@ -56,7 +56,12 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
def load_t5_encoder(model_path: Path, config):
"""Load T5 text encoder."""
"""Load T5 text encoder.
Weights are upcast to float32 for maximum precision — the T5 encoder
only runs once per generation, so performance impact is negligible.
This matches the official implementation, which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
@@ -70,6 +75,7 @@ def load_t5_encoder(model_path: Path, config):
shared_pos=False,
)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()))
mx.eval(encoder.parameters())
return encoder
@@ -91,11 +97,33 @@ def load_vae_decoder(model_path: Path, config=None):
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters())
return vae
def _clean_text(text: str) -> str:
"""Clean text matching official Wan2.2 tokenizer preprocessing.
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
double HTML unescape, and whitespace normalization. Critical for
correct tokenization of the Chinese negative prompt.
"""
import html
import re
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text).strip()
return text
def encode_text(
encoder,
tokenizer,
@@ -113,6 +141,7 @@ def encode_text(
Returns:
Text embeddings [L, dim]
"""
prompt = _clean_text(prompt)
tokens = tokenizer(
prompt,
max_length=text_len,
@@ -133,7 +162,7 @@ def encode_text(
def generate_video(
model_dir: str,
prompt: str,
negative_prompt: str = "",
negative_prompt: str | None = None,
width: int = 1280,
height: int = 720,
num_frames: int = 81,
@@ -142,13 +171,14 @@ def generate_video(
shift: float = None,
seed: int = -1,
output_path: str = "output.mp4",
scheduler: str = "unipc",
):
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
Args:
model_dir: Path to converted MLX model directory
prompt: Text prompt
negative_prompt: Negative prompt
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
width: Video width
height: Video height
num_frames: Number of frames (must be 4n+1)
@@ -157,11 +187,16 @@ def generate_video(
shift: Noise schedule shift (None = use config default)
seed: Random seed (-1 for random)
output_path: Output video path
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
"""
import json
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
model_dir = Path(model_dir)
@@ -253,12 +288,23 @@ def generate_video(
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
# Resolve negative prompt: explicit user value > config default
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
# that prevents oversaturation, artifacts, and comic look. We use it by default.
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
if negative_prompt is None:
neg_prompt_resolved = config.sample_neg_prompt
else:
neg_prompt_resolved = negative_prompt
print(f"{Colors.CYAN}{'='*60}")
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
print(f"{'='*60}{Colors.RESET}")
print(f"{Colors.DIM} Prompt: {prompt}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
print(f"{Colors.RESET}")
# Seed
@@ -298,10 +344,7 @@ def generate_video(
# Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
if negative_prompt:
context_null = encode_text(t5_encoder, tokenizer, negative_prompt, config.text_len)
else:
context_null = encode_text(t5_encoder, tokenizer, "", config.text_len)
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
mx.eval(context, context_null)
# Free T5 from memory
@@ -343,8 +386,14 @@ def generate_video(
mx.eval(cross_kv)
# Setup scheduler
scheduler = FlowMatchEulerScheduler(num_train_timesteps=config.num_train_timesteps)
scheduler.set_timesteps(steps, shift=shift)
_schedulers = {
"euler": FlowMatchEulerScheduler,
"dpm++": FlowDPMPP2MScheduler,
"unipc": FlowUniPCScheduler,
}
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
sched.set_timesteps(steps, shift=shift)
# Generate initial noise
noise = mx.random.normal(target_shape)
@@ -358,7 +407,7 @@ def generate_video(
t3 = time.time()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = scheduler.timesteps[i].item()
timestep_val = sched.timesteps[i].item()
# Select model, guide scale, and cached K/V
if is_dual:
@@ -387,7 +436,7 @@ def generate_video(
# Classifier-free guidance + scheduler step
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = scheduler.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
@@ -476,7 +525,10 @@ def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width")
parser.add_argument("--height", type=int, default=720, help="Video height")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
@@ -485,6 +537,11 @@ def main():
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
parser.add_argument(
"--scheduler", type=str, default="unipc",
choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
)
args = parser.parse_args()
# Parse guide scale
@@ -493,10 +550,15 @@ def main():
parts = [float(x) for x in args.guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
neg_prompt = args.negative_prompt
if args.no_negative_prompt:
neg_prompt = ""
generate_video(
model_dir=args.model_dir,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
negative_prompt=neg_prompt,
width=args.width,
height=args.height,
num_frames=args.num_frames,
@@ -505,6 +567,7 @@ def main():
shift=args.shift,
seed=args.seed,
output_path=args.output_path,
scheduler=args.scheduler,
)