feat(wan): Add DPM++ 2M and UniPC schedulers
This commit is contained in:
@@ -56,7 +56,12 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||
|
||||
|
||||
def load_t5_encoder(model_path: Path, config):
|
||||
"""Load T5 text encoder."""
|
||||
"""Load T5 text encoder.
|
||||
|
||||
Weights are upcast to float32 for maximum precision — the T5 encoder
|
||||
only runs once per generation, so performance impact is negligible.
|
||||
This matches the official implementation, which computes softmax in float32 explicitly.
|
||||
"""
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
@@ -70,6 +75,7 @@ def load_t5_encoder(model_path: Path, config):
|
||||
shared_pos=False,
|
||||
)
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
encoder.load_weights(list(weights.items()))
|
||||
mx.eval(encoder.parameters())
|
||||
return encoder
|
||||
@@ -91,11 +97,33 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
"""Clean text matching official Wan2.2 tokenizer preprocessing.
|
||||
|
||||
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
|
||||
double HTML unescape, and whitespace normalization. Critical for
|
||||
correct tokenization of the Chinese negative prompt.
|
||||
"""
|
||||
import html
|
||||
import re
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
text = ftfy.fix_text(text)
|
||||
except ImportError:
|
||||
pass
|
||||
text = html.unescape(html.unescape(text))
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def encode_text(
|
||||
encoder,
|
||||
tokenizer,
|
||||
@@ -113,6 +141,7 @@ def encode_text(
|
||||
Returns:
|
||||
Text embeddings [L, dim]
|
||||
"""
|
||||
prompt = _clean_text(prompt)
|
||||
tokens = tokenizer(
|
||||
prompt,
|
||||
max_length=text_len,
|
||||
@@ -133,7 +162,7 @@ def encode_text(
|
||||
def generate_video(
|
||||
model_dir: str,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
negative_prompt: str | None = None,
|
||||
width: int = 1280,
|
||||
height: int = 720,
|
||||
num_frames: int = 81,
|
||||
@@ -142,13 +171,14 @@ def generate_video(
|
||||
shift: float = None,
|
||||
seed: int = -1,
|
||||
output_path: str = "output.mp4",
|
||||
scheduler: str = "unipc",
|
||||
):
|
||||
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
|
||||
|
||||
Args:
|
||||
model_dir: Path to converted MLX model directory
|
||||
prompt: Text prompt
|
||||
negative_prompt: Negative prompt
|
||||
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
||||
width: Video width
|
||||
height: Video height
|
||||
num_frames: Number of frames (must be 4n+1)
|
||||
@@ -157,11 +187,16 @@ def generate_video(
|
||||
shift: Noise schedule shift (None = use config default)
|
||||
seed: Random seed (-1 for random)
|
||||
output_path: Output video path
|
||||
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
model_dir = Path(model_dir)
|
||||
|
||||
@@ -253,12 +288,23 @@ def generate_video(
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
# Resolve negative prompt: explicit user value > config default
|
||||
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
||||
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
||||
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
|
||||
if negative_prompt is None:
|
||||
neg_prompt_resolved = config.sample_neg_prompt
|
||||
else:
|
||||
neg_prompt_resolved = negative_prompt
|
||||
print(f"{Colors.CYAN}{'='*60}")
|
||||
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
|
||||
print(f"{'='*60}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Prompt: {prompt}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
||||
print(f"{Colors.RESET}")
|
||||
|
||||
# Seed
|
||||
@@ -298,10 +344,7 @@ def generate_video(
|
||||
# Encode prompts
|
||||
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
||||
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
||||
if negative_prompt:
|
||||
context_null = encode_text(t5_encoder, tokenizer, negative_prompt, config.text_len)
|
||||
else:
|
||||
context_null = encode_text(t5_encoder, tokenizer, "", config.text_len)
|
||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
@@ -343,8 +386,14 @@ def generate_video(
|
||||
mx.eval(cross_kv)
|
||||
|
||||
# Setup scheduler
|
||||
scheduler = FlowMatchEulerScheduler(num_train_timesteps=config.num_train_timesteps)
|
||||
scheduler.set_timesteps(steps, shift=shift)
|
||||
_schedulers = {
|
||||
"euler": FlowMatchEulerScheduler,
|
||||
"dpm++": FlowDPMPP2MScheduler,
|
||||
"unipc": FlowUniPCScheduler,
|
||||
}
|
||||
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
|
||||
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
|
||||
sched.set_timesteps(steps, shift=shift)
|
||||
|
||||
# Generate initial noise
|
||||
noise = mx.random.normal(target_shape)
|
||||
@@ -358,7 +407,7 @@ def generate_video(
|
||||
t3 = time.time()
|
||||
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = scheduler.timesteps[i].item()
|
||||
timestep_val = sched.timesteps[i].item()
|
||||
|
||||
# Select model, guide scale, and cached K/V
|
||||
if is_dual:
|
||||
@@ -387,7 +436,7 @@ def generate_video(
|
||||
|
||||
# Classifier-free guidance + scheduler step
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
latents = scheduler.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
||||
@@ -476,7 +525,10 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
|
||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width")
|
||||
parser.add_argument("--height", type=int, default=720, help="Video height")
|
||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
||||
@@ -485,6 +537,11 @@ def main():
|
||||
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
|
||||
parser.add_argument(
|
||||
"--scheduler", type=str, default="unipc",
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse guide scale
|
||||
@@ -493,10 +550,15 @@ def main():
|
||||
parts = [float(x) for x in args.guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
|
||||
neg_prompt = args.negative_prompt
|
||||
if args.no_negative_prompt:
|
||||
neg_prompt = ""
|
||||
|
||||
generate_video(
|
||||
model_dir=args.model_dir,
|
||||
prompt=args.prompt,
|
||||
negative_prompt=args.negative_prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
num_frames=args.num_frames,
|
||||
@@ -505,6 +567,7 @@ def main():
|
||||
shift=args.shift,
|
||||
seed=args.seed,
|
||||
output_path=args.output_path,
|
||||
scheduler=args.scheduler,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user