feat(wan): Add DPM++ 2M and UniPC schedulers
This commit is contained in:
@@ -412,10 +412,13 @@ def convert_wan_checkpoint(
|
||||
weights = sanitize_wan22_vae_weights(weights)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
|
||||
# that cannot be recovered by upcasting at load time.
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
out_path = output_dir / "vae.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
|
||||
|
||||
# Quantize transformer weights if requested
|
||||
if quantize:
|
||||
|
||||
@@ -56,7 +56,12 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||
|
||||
|
||||
def load_t5_encoder(model_path: Path, config):
|
||||
"""Load T5 text encoder."""
|
||||
"""Load T5 text encoder.
|
||||
|
||||
Weights are upcast to float32 for maximum precision — the T5 encoder
|
||||
only runs once per generation, so performance impact is negligible.
|
||||
This matches the official which computes softmax in float32 explicitly.
|
||||
"""
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
@@ -70,6 +75,7 @@ def load_t5_encoder(model_path: Path, config):
|
||||
shared_pos=False,
|
||||
)
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
encoder.load_weights(list(weights.items()))
|
||||
mx.eval(encoder.parameters())
|
||||
return encoder
|
||||
@@ -91,11 +97,33 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
"""Clean text matching official Wan2.2 tokenizer preprocessing.
|
||||
|
||||
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
|
||||
double HTML unescape, and whitespace normalization. Critical for
|
||||
correct tokenization of the Chinese negative prompt.
|
||||
"""
|
||||
import html
|
||||
import re
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
text = ftfy.fix_text(text)
|
||||
except ImportError:
|
||||
pass
|
||||
text = html.unescape(html.unescape(text))
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def encode_text(
|
||||
encoder,
|
||||
tokenizer,
|
||||
@@ -113,6 +141,7 @@ def encode_text(
|
||||
Returns:
|
||||
Text embeddings [L, dim]
|
||||
"""
|
||||
prompt = _clean_text(prompt)
|
||||
tokens = tokenizer(
|
||||
prompt,
|
||||
max_length=text_len,
|
||||
@@ -133,7 +162,7 @@ def encode_text(
|
||||
def generate_video(
|
||||
model_dir: str,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
negative_prompt: str | None = None,
|
||||
width: int = 1280,
|
||||
height: int = 720,
|
||||
num_frames: int = 81,
|
||||
@@ -142,13 +171,14 @@ def generate_video(
|
||||
shift: float = None,
|
||||
seed: int = -1,
|
||||
output_path: str = "output.mp4",
|
||||
scheduler: str = "unipc",
|
||||
):
|
||||
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
|
||||
|
||||
Args:
|
||||
model_dir: Path to converted MLX model directory
|
||||
prompt: Text prompt
|
||||
negative_prompt: Negative prompt
|
||||
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
||||
width: Video width
|
||||
height: Video height
|
||||
num_frames: Number of frames (must be 4n+1)
|
||||
@@ -157,11 +187,16 @@ def generate_video(
|
||||
shift: Noise schedule shift (None = use config default)
|
||||
seed: Random seed (-1 for random)
|
||||
output_path: Output video path
|
||||
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
model_dir = Path(model_dir)
|
||||
|
||||
@@ -253,12 +288,23 @@ def generate_video(
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
# Resolve negative prompt: explicit user value > config default
|
||||
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
||||
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
||||
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
|
||||
if negative_prompt is None:
|
||||
neg_prompt_resolved = config.sample_neg_prompt
|
||||
else:
|
||||
neg_prompt_resolved = negative_prompt
|
||||
print(f"{Colors.CYAN}{'='*60}")
|
||||
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
|
||||
print(f"{'='*60}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Prompt: {prompt}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
||||
print(f"{Colors.RESET}")
|
||||
|
||||
# Seed
|
||||
@@ -298,10 +344,7 @@ def generate_video(
|
||||
# Encode prompts
|
||||
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
||||
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
||||
if negative_prompt:
|
||||
context_null = encode_text(t5_encoder, tokenizer, negative_prompt, config.text_len)
|
||||
else:
|
||||
context_null = encode_text(t5_encoder, tokenizer, "", config.text_len)
|
||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
@@ -343,8 +386,14 @@ def generate_video(
|
||||
mx.eval(cross_kv)
|
||||
|
||||
# Setup scheduler
|
||||
scheduler = FlowMatchEulerScheduler(num_train_timesteps=config.num_train_timesteps)
|
||||
scheduler.set_timesteps(steps, shift=shift)
|
||||
_schedulers = {
|
||||
"euler": FlowMatchEulerScheduler,
|
||||
"dpm++": FlowDPMPP2MScheduler,
|
||||
"unipc": FlowUniPCScheduler,
|
||||
}
|
||||
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
|
||||
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
|
||||
sched.set_timesteps(steps, shift=shift)
|
||||
|
||||
# Generate initial noise
|
||||
noise = mx.random.normal(target_shape)
|
||||
@@ -358,7 +407,7 @@ def generate_video(
|
||||
t3 = time.time()
|
||||
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = scheduler.timesteps[i].item()
|
||||
timestep_val = sched.timesteps[i].item()
|
||||
|
||||
# Select model, guide scale, and cached K/V
|
||||
if is_dual:
|
||||
@@ -387,7 +436,7 @@ def generate_video(
|
||||
|
||||
# Classifier-free guidance + scheduler step
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
latents = scheduler.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
||||
@@ -476,7 +525,10 @@ def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
|
||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width")
|
||||
parser.add_argument("--height", type=int, default=720, help="Video height")
|
||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
||||
@@ -485,6 +537,11 @@ def main():
|
||||
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
|
||||
parser.add_argument(
|
||||
"--scheduler", type=str, default="unipc",
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse guide scale
|
||||
@@ -493,10 +550,15 @@ def main():
|
||||
parts = [float(x) for x in args.guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
|
||||
neg_prompt = args.negative_prompt
|
||||
if args.no_negative_prompt:
|
||||
neg_prompt = ""
|
||||
|
||||
generate_video(
|
||||
model_dir=args.model_dir,
|
||||
prompt=args.prompt,
|
||||
negative_prompt=args.negative_prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
num_frames=args.num_frames,
|
||||
@@ -505,6 +567,7 @@ def main():
|
||||
shift=args.shift,
|
||||
seed=args.seed,
|
||||
output_path=args.output_path,
|
||||
scheduler=args.scheduler,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,12 @@ class WanModelConfig(BaseModelConfig):
|
||||
num_train_timesteps: int = 1000
|
||||
sample_fps: int = 16
|
||||
frame_num: int = 81
|
||||
sample_neg_prompt: str = (
|
||||
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
|
||||
"最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
|
||||
"画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
|
||||
"杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
)
|
||||
|
||||
# T5
|
||||
t5_vocab_size: int = 256384
|
||||
|
||||
@@ -1,16 +1,29 @@
|
||||
"""Flow matching scheduler for Wan2.2 inference."""
|
||||
"""Flow matching schedulers for Wan2.2 inference.
|
||||
|
||||
Provides Euler, DPM++2M, and UniPC solvers for flow matching diffusion.
|
||||
Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps
|
||||
for the same quality as Euler.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class FlowMatchEulerScheduler:
|
||||
"""Simple Euler scheduler for flow matching diffusion.
|
||||
def _compute_sigmas(num_steps: int, shift: float = 1.0) -> np.ndarray:
|
||||
"""Compute shifted sigma schedule matching official Wan2.2 code.
|
||||
|
||||
Implements the flow matching formulation where the model predicts
|
||||
velocity (flow) and we use Euler steps to denoise.
|
||||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||||
"""
|
||||
sigmas = np.linspace(1.0, 0.0, num_steps + 1)[:num_steps]
|
||||
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||||
return np.append(sigmas, 0.0).astype(np.float32)
|
||||
|
||||
|
||||
class FlowMatchEulerScheduler:
|
||||
"""1st-order Euler scheduler for flow matching diffusion."""
|
||||
|
||||
def __init__(self, num_train_timesteps: int = 1000):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
@@ -18,30 +31,9 @@ class FlowMatchEulerScheduler:
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
"""Compute sigma schedule with shift.
|
||||
|
||||
Args:
|
||||
num_steps: Number of inference steps.
|
||||
shift: Noise schedule shift factor.
|
||||
"""
|
||||
# Linear spacing from sigma_max to sigma_min
|
||||
sigmas = np.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps)[::-1]
|
||||
sigmas = 1.0 - sigmas
|
||||
|
||||
# Select evenly spaced subset
|
||||
indices = np.linspace(0, len(sigmas) - 1, num_steps + 1).astype(int)
|
||||
sigmas = sigmas[indices[:-1]]
|
||||
|
||||
# Apply shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
|
||||
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||||
|
||||
# Convert to timesteps
|
||||
timesteps = sigmas * self.num_train_timesteps
|
||||
self.timesteps = mx.array(timesteps.astype(np.float32))
|
||||
|
||||
# Append terminal sigma=0
|
||||
sigmas = np.append(sigmas, 0.0)
|
||||
self.sigmas = mx.array(sigmas.astype(np.float32))
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
self._step_index = 0
|
||||
|
||||
def step(
|
||||
@@ -50,27 +42,387 @@ class FlowMatchEulerScheduler:
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step for flow matching.
|
||||
|
||||
In flow matching, model predicts velocity v, and:
|
||||
x_{t-1} = sample + (sigma_{t-1} - sigma_t) * v
|
||||
|
||||
Args:
|
||||
model_output: Predicted velocity [B, C, T, H, W]
|
||||
timestep: Current timestep (unused, step index is tracked internally)
|
||||
sample: Current noisy sample [B, C, T, H, W]
|
||||
|
||||
Returns:
|
||||
Updated sample
|
||||
"""
|
||||
# Use Python floats to avoid creating mx.array scalars that
|
||||
# could trigger type promotion (per fast-mlx guide)
|
||||
dt = float(self.sigmas[self._step_index + 1].item()) - float(self.sigmas[self._step_index].item())
|
||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||
dt = float(self.sigmas[self._step_index + 1].item()) - float(
|
||||
self.sigmas[self._step_index].item()
|
||||
)
|
||||
x_next = sample + dt * model_output
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
|
||||
|
||||
class FlowDPMPP2MScheduler:
|
||||
"""DPM-Solver++(2M) for flow matching diffusion.
|
||||
|
||||
2nd-order multistep solver that reuses the previous step's model output
|
||||
for a correction term. Falls back to 1st order on the first and
|
||||
(optionally) last step. Reference: Wan2.2 fm_solvers.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
lower_order_final: bool = True,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.lower_order_final = lower_order_final
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
# Store sigmas as Python floats for scalar math
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
self._num_steps = num_steps
|
||||
self._prev_x0 = None # previous x0 prediction for 2nd-order correction
|
||||
|
||||
@staticmethod
|
||||
def _lambda(sigma: float) -> float:
|
||||
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
|
||||
|
||||
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
|
||||
matching torch.log behavior in the official code.
|
||||
"""
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log((1.0 - sigma) / sigma)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""DPM++(2M) step for flow matching.
|
||||
|
||||
Converts velocity prediction to x0, then applies 1st or 2nd order
|
||||
update depending on available history.
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_cur = s[i]
|
||||
sigma_next = s[i + 1]
|
||||
|
||||
# Convert velocity -> x0 prediction: x0 = sample - sigma * v
|
||||
x0 = sample - sigma_cur * model_output
|
||||
|
||||
# Decide order: 1st for first step, last step (if lower_order_final
|
||||
# and few steps), otherwise 2nd
|
||||
use_first_order = (
|
||||
self._prev_x0 is None
|
||||
or (
|
||||
self.lower_order_final
|
||||
and i == self._num_steps - 1
|
||||
and self._num_steps < 15
|
||||
)
|
||||
)
|
||||
|
||||
if use_first_order or sigma_next == 0.0:
|
||||
# 1st order DPM++ (equivalent to DDIM):
|
||||
# x_next = (σ_next/σ_cur)*x - (α_next*(exp(-h)-1))*x0
|
||||
if sigma_next == 0.0:
|
||||
x_next = x0
|
||||
else:
|
||||
lambda_cur = self._lambda(sigma_cur)
|
||||
lambda_next = self._lambda(sigma_next)
|
||||
h = lambda_next - lambda_cur
|
||||
alpha_next = 1.0 - sigma_next
|
||||
coeff_x = sigma_next / sigma_cur
|
||||
coeff_x0 = alpha_next * math.expm1(-h)
|
||||
x_next = coeff_x * sample - coeff_x0 * x0
|
||||
else:
|
||||
# 2nd order DPM++(2M) with midpoint correction
|
||||
sigma_prev = s[i - 1]
|
||||
lambda_prev = self._lambda(sigma_prev)
|
||||
lambda_cur = self._lambda(sigma_cur)
|
||||
lambda_next = self._lambda(sigma_next)
|
||||
|
||||
h = lambda_next - lambda_cur
|
||||
h_0 = lambda_cur - lambda_prev
|
||||
r0 = h_0 / h
|
||||
|
||||
# D0 = current x0, D1 = correction from previous x0
|
||||
D0 = x0
|
||||
D1 = (1.0 / r0) * (x0 - self._prev_x0)
|
||||
|
||||
alpha_next = 1.0 - sigma_next
|
||||
exp_neg_h_m1 = math.expm1(-h) # exp(-h) - 1
|
||||
|
||||
x_next = (
|
||||
(sigma_next / sigma_cur) * sample
|
||||
- (alpha_next * exp_neg_h_m1) * D0
|
||||
- 0.5 * (alpha_next * exp_neg_h_m1) * D1
|
||||
)
|
||||
|
||||
self._prev_x0 = x0
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
self._prev_x0 = None
|
||||
|
||||
|
||||
class FlowUniPCScheduler:
|
||||
"""UniPC (Unified Predictor-Corrector) for flow matching diffusion.
|
||||
|
||||
Multi-step predictor-corrector solver with configurable order.
|
||||
The corrector refines each step using the model output that was already
|
||||
computed, costing no extra model evaluations. Official Wan2.2 default.
|
||||
Reference: Wan2.2 fm_solvers_unipc.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
solver_order: int = 2,
|
||||
lower_order_final: bool = True,
|
||||
disable_corrector: list | None = None,
|
||||
use_corrector: bool = False,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.solver_order = solver_order
|
||||
self.lower_order_final = lower_order_final
|
||||
self._use_corrector = use_corrector
|
||||
self.disable_corrector = set(disable_corrector or [])
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
self._num_steps = num_steps
|
||||
self._lower_order_nums = 0
|
||||
# Model output (x0) history for multi-step, stored newest-last
|
||||
self._model_outputs = [None] * self.solver_order
|
||||
self._last_sample = None # sample before prediction (for corrector)
|
||||
self._this_order = 1
|
||||
|
||||
@staticmethod
|
||||
def _lambda(sigma: float) -> float:
|
||||
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
|
||||
|
||||
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
|
||||
matching torch.log behavior in the official code.
|
||||
"""
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log((1.0 - sigma) / sigma)
|
||||
|
||||
def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array:
|
||||
"""Convert velocity prediction to x0: x0 = sample - sigma * v."""
|
||||
sigma = self._sigmas_float[self._step_index]
|
||||
return sample - sigma * velocity
|
||||
|
||||
def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array:
|
||||
"""UniP predictor with B(h)=expm1(-h) basis (bh2 variant).
|
||||
|
||||
Matches official multistep_uni_p_bh_update: computes rhos_p via
|
||||
linalg.solve for order >= 3; order <= 2 uses analytic rhos_p=[0.5].
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_s0 = s[i]
|
||||
sigma_t = s[i + 1]
|
||||
|
||||
if sigma_t == 0.0:
|
||||
return x0
|
||||
|
||||
lambda_s0 = self._lambda(sigma_s0)
|
||||
lambda_t = self._lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h # negated for predict_x0
|
||||
|
||||
alpha_t = 1.0 - sigma_t
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
|
||||
m0 = self._model_outputs[-1]
|
||||
# Base prediction
|
||||
x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0
|
||||
|
||||
if order >= 2 and m0 is not None:
|
||||
rks = []
|
||||
D1s = []
|
||||
for k in range(1, order):
|
||||
si_idx = i - k
|
||||
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
|
||||
break
|
||||
mk = self._model_outputs[-(k + 1)]
|
||||
sigma_sk = s[si_idx]
|
||||
lambda_sk = self._lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
if math.isinf(rk):
|
||||
break
|
||||
rks.append(rk)
|
||||
D1s.append((mk - m0) / rk)
|
||||
|
||||
if D1s:
|
||||
effective_order = len(D1s) + 1
|
||||
if effective_order <= 2:
|
||||
# Analytic solution for order 2
|
||||
rhos_p = [0.5]
|
||||
else:
|
||||
rks_arr = np.array(rks, dtype=np.float64)
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows = []
|
||||
b_vals = []
|
||||
for j in range(1, effective_order):
|
||||
R_rows.append(rks_arr ** (j - 1))
|
||||
b_vals.append(float(h_phi_k * factorial_i / B_h))
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_p = np.linalg.solve(R, b).tolist()
|
||||
|
||||
pred_res = sum(r * d for r, d in zip(rhos_p, D1s))
|
||||
x_t = x_t - (alpha_t * B_h) * pred_res
|
||||
|
||||
return x_t
|
||||
|
||||
def _uni_c_bh2(
|
||||
self,
|
||||
model_x0: mx.array,
|
||||
last_sample: mx.array,
|
||||
this_sample: mx.array,
|
||||
order: int,
|
||||
) -> mx.array:
|
||||
"""UniC corrector with B(h)=expm1(-h) basis (bh2 variant).
|
||||
|
||||
Matches official multistep_uni_c_bh_update: computes rhos_c via
|
||||
linalg.solve for order >= 2 (not hardcoded 0.5).
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_s0 = s[i - 1]
|
||||
sigma_t = s[i]
|
||||
|
||||
if sigma_t == 0.0:
|
||||
return this_sample
|
||||
|
||||
lambda_s0 = self._lambda(sigma_s0)
|
||||
lambda_t = self._lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h # negated for predict_x0
|
||||
|
||||
alpha_t = 1.0 - sigma_t
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
|
||||
m0 = self._model_outputs[-1]
|
||||
# Re-derive base from last_sample
|
||||
x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0
|
||||
|
||||
D1_t = model_x0 - m0
|
||||
|
||||
# Gather rks and D1s from history
|
||||
rks = []
|
||||
D1s = []
|
||||
for k in range(1, order):
|
||||
si_idx = i - (k + 1)
|
||||
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
|
||||
break
|
||||
mk = self._model_outputs[-(k + 1)]
|
||||
sigma_sk = s[si_idx]
|
||||
lambda_sk = self._lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
if math.isinf(rk):
|
||||
break # History references sigma=1.0 boundary; reduce order
|
||||
rks.append(rk)
|
||||
D1s.append((mk - m0) / rk)
|
||||
rks.append(1.0)
|
||||
effective_order = len(rks) # = len(D1s) + 1
|
||||
|
||||
# Compute rhos_c coefficients
|
||||
if effective_order == 1:
|
||||
rhos_c = [0.5]
|
||||
else:
|
||||
rks_arr = np.array(rks, dtype=np.float64)
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows = []
|
||||
b_vals = []
|
||||
for j in range(1, effective_order + 1):
|
||||
R_rows.append(rks_arr ** (j - 1))
|
||||
b_vals.append(float(h_phi_k * factorial_i / B_h))
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_c = np.linalg.solve(R, b).tolist()
|
||||
|
||||
# Apply correction
|
||||
corr_res = mx.zeros_like(D1_t)
|
||||
for k_idx, d1 in enumerate(D1s):
|
||||
corr_res = corr_res + rhos_c[k_idx] * d1
|
||||
x_t = x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t)
|
||||
return x_t
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""UniPC step: correct current, then predict next."""
|
||||
i = self._step_index
|
||||
|
||||
# Convert velocity -> x0
|
||||
x0 = self._convert_output(model_output, sample)
|
||||
|
||||
# 1. Corrector: refine current sample if we have history
|
||||
use_corrector = (
|
||||
self._use_corrector
|
||||
and i > 0
|
||||
and (i - 1) not in self.disable_corrector
|
||||
and self._last_sample is not None
|
||||
)
|
||||
if use_corrector:
|
||||
sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order)
|
||||
|
||||
# 2. Shift model output history
|
||||
for k in range(self.solver_order - 1):
|
||||
self._model_outputs[k] = self._model_outputs[k + 1]
|
||||
self._model_outputs[-1] = x0
|
||||
|
||||
# 3. Determine prediction order
|
||||
if self.lower_order_final:
|
||||
this_order = min(self.solver_order, self._num_steps - i)
|
||||
else:
|
||||
this_order = self.solver_order
|
||||
self._this_order = min(this_order, self._lower_order_nums + 1)
|
||||
|
||||
# 4. Predict next sample
|
||||
self._last_sample = sample
|
||||
x_next = self._uni_p_bh2(x0, sample, self._this_order)
|
||||
|
||||
if self._lower_order_nums < self.solver_order:
|
||||
self._lower_order_nums += 1
|
||||
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
"""Reset step counter for new generation."""
|
||||
self._step_index = 0
|
||||
self._lower_order_nums = 0
|
||||
self._model_outputs = [None] * self.solver_order
|
||||
self._last_sample = None
|
||||
self._this_order = 1
|
||||
|
||||
@@ -106,29 +106,35 @@ class T5Attention(nn.Module):
|
||||
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
|
||||
v = self.v(context).reshape(b, -1, n, c)
|
||||
|
||||
# T5 does not use scaling
|
||||
# attn = einsum('binc,bjnc->bnij', q, k)
|
||||
# T5 uses no scaling — compute attention manually with float32 softmax
|
||||
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
# Using SDPA with bfloat16 inputs causes precision loss in softmax
|
||||
# since unscaled logits can be very large (no 1/sqrt(d) division).
|
||||
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Combine position bias and attention mask for SDPA
|
||||
attn_mask = None
|
||||
# QK^T (no scaling) — compute in float32 for precision
|
||||
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
|
||||
|
||||
# Add position bias
|
||||
if pos_bias is not None:
|
||||
attn_mask = pos_bias.astype(q.dtype)
|
||||
attn = attn + pos_bias.astype(mx.float32)
|
||||
|
||||
# Apply attention mask (use dtype min like official, not -1e9)
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
|
||||
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
|
||||
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask
|
||||
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
|
||||
attn = attn + additive_mask
|
||||
|
||||
# T5 uses no scaling (scale=1.0)
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=1.0, mask=attn_mask
|
||||
)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
# Softmax in float32 (matches official), then cast back
|
||||
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
|
||||
|
||||
# Attention @ V
|
||||
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
|
||||
@@ -288,21 +288,34 @@ class Resample(nn.Module):
|
||||
B, T, H, W, C = x.shape
|
||||
|
||||
if self.mode == "upsample3d":
|
||||
# Temporal upsample via time_conv
|
||||
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
|
||||
# Split into two interleaved temporal streams
|
||||
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
||||
# Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C]
|
||||
stream0 = tc_out[:, :, :, :, 0, :] # [B, T, H, W, C]
|
||||
stream1 = tc_out[:, :, :, :, 1, :] # [B, T, H, W, C]
|
||||
x = mx.stack([stream0, stream1], axis=2) # [B, T, 2, H, W, C]
|
||||
x = x.reshape(B, T * 2, H, W, C)
|
||||
if first_chunk and T > 1:
|
||||
# Match official chunked behavior: the first frame bypasses
|
||||
# time_conv entirely (only spatial upsample). Remaining frames
|
||||
# go through time_conv with causal zero-padding, which
|
||||
# naturally gives each frame the same limited temporal context
|
||||
# as the official frame-by-frame decode with caching.
|
||||
first_frame = x[:, 0:1] # [B, 1, H, W, C]
|
||||
rest = x[:, 1:] # [B, T-1, H, W, C]
|
||||
|
||||
if first_chunk:
|
||||
# PyTorch skips time_conv for first chunk entirely. In all-at-once
|
||||
# mode, we trim the first frame to match (the first interleaved
|
||||
# frame is from zero-padded causal context and shouldn't be kept).
|
||||
x = x[:, 1:, :, :, :]
|
||||
# time_conv on remaining frames (causal pad gives zero context
|
||||
# before rest[0], matching the official "Rep" cache path)
|
||||
tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
|
||||
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
|
||||
stream0 = tc_out[:, :, :, :, 0, :]
|
||||
stream1 = tc_out[:, :, :, :, 1, :]
|
||||
interleaved = mx.stack([stream0, stream1], axis=2)
|
||||
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
|
||||
|
||||
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
|
||||
x = mx.concatenate([first_frame, interleaved], axis=1)
|
||||
elif self.mode == "upsample3d":
|
||||
# Non-first-chunk or single frame: time_conv all frames
|
||||
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
|
||||
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
||||
stream0 = tc_out[:, :, :, :, 0, :]
|
||||
stream1 = tc_out[:, :, :, :, 1, :]
|
||||
x = mx.stack([stream0, stream1], axis=2)
|
||||
x = x.reshape(B, T * 2, H, W, C)
|
||||
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
@@ -22,6 +22,7 @@ dependencies = [
|
||||
"mlx-vlm",
|
||||
"imageio>=2.37.2",
|
||||
"imageio-ffmpeg>=0.6.0",
|
||||
"ftfy",
|
||||
]
|
||||
license = {text="MIT"}
|
||||
authors = [
|
||||
|
||||
1259
tests/test_wan.py
1259
tests/test_wan.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user