"""Flow matching schedulers for Wan2.2 inference.

Provides Euler, DPM++(2M), and UniPC solvers for flow matching diffusion.
Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps
for the same quality as Euler.
"""
|
||
|
||
import math
|
||
|
||
import numpy as np
|
||
|
||
import mlx.core as mx
|
||
|
||
|
||
def _compute_sigmas(num_steps: int, shift: float = 1.0) -> np.ndarray:
|
||
"""Compute shifted sigma schedule matching official Wan2.2 code.
|
||
|
||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||
"""
|
||
sigmas = np.linspace(1.0, 0.0, num_steps + 1)[:num_steps]
|
||
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||
return np.append(sigmas, 0.0).astype(np.float32)
|
||
|
||
|
||
class FlowMatchEulerScheduler:
    """1st-order Euler scheduler for flow matching diffusion.

    Each step moves the sample along the predicted velocity:
    ``x_next = x + (sigma_next - sigma_cur) * v``, using one model
    evaluation per step.
    """

    def __init__(self, num_train_timesteps: int = 1000):
        self.num_train_timesteps = num_train_timesteps
        self.timesteps = None
        self.sigmas = None
        # Python-float copy of ``sigmas``; filled by set_timesteps().
        self._sigmas_float = None
        self._step_index = 0

    def set_timesteps(self, num_steps: int, shift: float = 1.0):
        """Build the schedule for ``num_steps`` steps and rewind to step 0.

        Args:
            num_steps: Number of denoising steps.
            shift: Timestep shift forwarded to ``_compute_sigmas``.
        """
        sigmas = _compute_sigmas(num_steps, shift)
        self.sigmas = mx.array(sigmas)
        self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
        # Python floats for the scalar dt computation in step(), so it does not
        # index ``self.sigmas`` and call ``.item()`` twice per step (the DPM++
        # and UniPC schedulers keep the same ``_sigmas_float`` list).
        self._sigmas_float = sigmas.tolist()
        self._step_index = 0

    def step(
        self,
        model_output: mx.array,
        timestep,
        sample: mx.array,
    ) -> mx.array:
        """Advance one Euler step: ``x_next = x + (sigma_next - sigma_cur) * v``.

        Args:
            model_output: Velocity prediction ``v`` for ``sample``.
            timestep: Unused; the scheduler tracks its own step index.
            sample: Current latent at ``sigmas[step_index]``.

        Returns:
            The latent at the next sigma of the schedule.

        Raises:
            RuntimeError: If ``set_timesteps`` has not been called.
        """
        if self._sigmas_float is None:
            raise RuntimeError("set_timesteps() must be called before step()")
        i = self._step_index
        dt = self._sigmas_float[i + 1] - self._sigmas_float[i]
        x_next = sample + dt * model_output
        self._step_index = i + 1
        return x_next

    def reset(self):
        """Rewind to the first step of the current schedule."""
        self._step_index = 0


class FlowDPMPP2MScheduler:
    """DPM-Solver++(2M) sampler for flow matching diffusion.

    A second-order multistep method: every step combines the current data
    (x0) prediction with the one kept from the previous step, so no extra
    model evaluations are needed. The first-order update is used instead on
    the very first step, and on the final step when ``lower_order_final`` is
    set and fewer than 15 steps are scheduled. A step that lands on sigma=0
    returns the x0 prediction directly. Reference: Wan2.2 fm_solvers.py.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        lower_order_final: bool = True,
    ):
        self.lower_order_final = lower_order_final
        self.num_train_timesteps = num_train_timesteps
        self.sigmas = None
        self.timesteps = None

    def set_timesteps(self, num_steps: int, shift: float = 1.0):
        """Prepare the schedule for ``num_steps`` steps and clear history.

        Args:
            num_steps: Number of denoising steps.
            shift: Timestep shift forwarded to ``_compute_sigmas``.
        """
        schedule = _compute_sigmas(num_steps, shift)
        self.sigmas = mx.array(schedule)
        self.timesteps = mx.array(schedule[:-1] * self.num_train_timesteps)
        # Plain Python floats for the scalar coefficient math in step().
        self._sigmas_float = schedule.tolist()
        self._num_steps = num_steps
        self._step_index = 0
        self._prev_x0 = None  # x0 prediction from the previous step

    @staticmethod
    def _lambda(sigma: float) -> float:
        """Log-SNR of a noise level: ``log((1 - sigma) / sigma)``.

        The two ends of the range map to infinities instead of raising:
        +inf for sigma <= 0 (clean) and -inf for sigma >= 1 (pure noise),
        mirroring what torch.log yields in the reference implementation.
        """
        if sigma <= 0.0:
            return math.inf
        if sigma >= 1.0:
            return -math.inf
        return math.log((1.0 - sigma) / sigma)

    def step(
        self,
        model_output: mx.array,
        timestep,
        sample: mx.array,
    ) -> mx.array:
        """Advance one DPM++(2M) step.

        Args:
            model_output: Velocity prediction ``v`` for ``sample``.
            timestep: Unused; the scheduler tracks its own step index.
            sample: Current latent at ``sigmas[step_index]``.

        Returns:
            The latent at the next sigma of the schedule.
        """
        idx = self._step_index
        sigma_s = self._sigmas_float[idx]
        sigma_t = self._sigmas_float[idx + 1]

        # Data (x0) prediction from the velocity: x0 = x - sigma * v.
        x0_pred = sample - sigma_s * model_output

        last_step_low_order = (
            self.lower_order_final
            and idx == self._num_steps - 1
            and self._num_steps < 15
        )

        if sigma_t == 0.0:
            # Landing on sigma=0: the update collapses to the x0 prediction.
            x_next = x0_pred
        else:
            lam_s = self._lambda(sigma_s)
            lam_t = self._lambda(sigma_t)
            h = lam_t - lam_s
            # alpha_t * (exp(-h) - 1) with alpha_t = 1 - sigma_t.
            phi = (1.0 - sigma_t) * math.expm1(-h)

            # First-order (DDIM-like) part, shared by both orders.
            x_next = (sigma_t / sigma_s) * sample - phi * x0_pred

            if self._prev_x0 is not None and not last_step_low_order:
                # Midpoint correction from the previous x0 prediction.
                lam_prev = self._lambda(self._sigmas_float[idx - 1])
                r0 = (lam_s - lam_prev) / h
                d1 = (1.0 / r0) * (x0_pred - self._prev_x0)
                x_next = x_next - (0.5 * phi) * d1

        self._prev_x0 = x0_pred
        self._step_index = idx + 1
        return x_next

    def reset(self):
        """Rewind to the first step and forget the stored x0 prediction."""
        self._prev_x0 = None
        self._step_index = 0


class FlowUniPCScheduler:
    """UniPC (Unified Predictor-Corrector) sampler for flow matching diffusion.

    Multistep solver of configurable order using the B(h) = expm1(-h)
    ("bh2") basis; the official Wan2.2 default solver. Each ``step`` call:

    1. converts the velocity prediction into an x0 prediction,
    2. optionally applies the UniC corrector to the incoming sample, reusing
       the model output that was just computed (no extra model evaluation),
    3. applies the UniP predictor to produce the next sample.

    The corrector is opt-in here: it only runs when ``use_corrector=True``
    (default ``False``), so by default only the predictor is applied.
    Reference: Wan2.2 fm_solvers_unipc.py.

    Args:
        num_train_timesteps: Scale used to turn sigmas into timesteps.
        solver_order: Maximum multistep order (must be >= 1).
        lower_order_final: Cap the order at the number of remaining steps.
        disable_corrector: Step indices ``k`` for which the corrector is
            skipped at step ``k + 1`` (i.e. the sample produced by step ``k``
            is not corrected).
        use_corrector: Enable the UniC corrector.

    Raises:
        ValueError: If ``solver_order`` is less than 1.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        lower_order_final: bool = True,
        disable_corrector: list | None = None,
        use_corrector: bool = False,
    ):
        if solver_order < 1:
            # A zero-length history would make step() fail with IndexError.
            raise ValueError(f"solver_order must be >= 1, got {solver_order}")
        self.num_train_timesteps = num_train_timesteps
        self.solver_order = solver_order
        self.lower_order_final = lower_order_final
        self._use_corrector = use_corrector
        # Compared against ``step_index - 1`` in step().
        self.disable_corrector = set(disable_corrector or [])
        self.timesteps = None
        self.sigmas = None

    def set_timesteps(self, num_steps: int, shift: float = 1.0):
        """Prepare the schedule for ``num_steps`` steps and clear history.

        Args:
            num_steps: Number of denoising steps.
            shift: Timestep shift forwarded to ``_compute_sigmas``.
        """
        sigmas = _compute_sigmas(num_steps, shift)
        self.sigmas = mx.array(sigmas)
        self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
        # Python floats for the scalar coefficient math.
        self._sigmas_float = sigmas.tolist()
        self._step_index = 0
        self._num_steps = num_steps
        # Number of steps taken so far, capped at solver_order (order warmup).
        self._lower_order_nums = 0
        # x0 prediction history for the multistep updates, newest last.
        self._model_outputs = [None] * self.solver_order
        self._last_sample = None  # sample fed to the last predictor call
        self._this_order = 1  # order used by the last predictor call

    @staticmethod
    def _lambda(sigma: float) -> float:
        """log-SNR: lambda(sigma) = log((1-sigma)/sigma).

        Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
        matching torch.log behavior in the official code.
        """
        if sigma >= 1.0:
            return -math.inf
        if sigma <= 0.0:
            return math.inf
        return math.log((1.0 - sigma) / sigma)

    @staticmethod
    def _solve_rhos(
        rks: list[float], hh: float, h_phi_1: float, B_h: float
    ) -> list[float]:
        """Solve the UniPC linear system ``R @ rhos = b`` (bh2 variant).

        ``R[j-1] = rks ** (j - 1)`` and ``b[j-1] = h_phi_j * j! / B_h`` for
        ``j = 1..len(rks)``, with ``h_phi_j`` built by the phi-function
        recurrence. Shared by the predictor (rks = history ratios) and the
        corrector (history ratios plus a trailing 1.0).
        """
        rks_arr = np.array(rks, dtype=np.float64)
        h_phi_k = h_phi_1 / hh - 1.0
        factorial_i = 1
        R_rows = []
        b_vals = []
        for j in range(1, len(rks) + 1):
            R_rows.append(rks_arr ** (j - 1))
            b_vals.append(float(h_phi_k * factorial_i / B_h))
            factorial_i *= j + 1
            h_phi_k = h_phi_k / hh - 1.0 / factorial_i
        return np.linalg.solve(np.stack(R_rows), np.array(b_vals)).tolist()

    def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array:
        """Convert velocity prediction to x0: x0 = sample - sigma * v."""
        sigma = self._sigmas_float[self._step_index]
        return sample - sigma * velocity

    def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array:
        """UniP predictor with B(h)=expm1(-h) basis (bh2 variant).

        Follows the official multistep_uni_p_bh_update: order 2 uses the
        analytic rhos_p=[0.5]; higher orders solve the linear system. The
        order silently drops when history is missing or reaches back to
        sigma=1.0 (where lambda is -inf).

        Args:
            x0: x0 prediction for the current step (returned as-is when the
                next sigma is 0).
            sample: Current (possibly corrected) latent.
            order: Requested predictor order.
        """
        i = self._step_index
        s = self._sigmas_float

        sigma_s0 = s[i]
        sigma_t = s[i + 1]
        if sigma_t == 0.0:
            # Terminal step (sigma=0): return the x0 prediction directly.
            return x0

        lambda_s0 = self._lambda(sigma_s0)
        lambda_t = self._lambda(sigma_t)
        h = lambda_t - lambda_s0
        hh = -h  # negated for x0 (data) prediction
        alpha_t = 1.0 - sigma_t
        h_phi_1 = math.expm1(hh)
        B_h = h_phi_1  # bh2 basis

        m0 = self._model_outputs[-1]
        # First-order base prediction.
        x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0
        if order < 2:
            return x_t

        rks = []
        D1s = []
        for k in range(1, order):
            si_idx = i - k
            mk = self._model_outputs[-(k + 1)]
            if si_idx < 0 or mk is None:
                break
            rk = (self._lambda(s[si_idx]) - lambda_s0) / h
            if math.isinf(rk):
                break  # History references sigma=1.0 boundary; reduce order
            rks.append(rk)
            D1s.append((mk - m0) / rk)

        if not D1s:
            return x_t
        if len(D1s) == 1:
            rhos_p = [0.5]  # analytic solution for order 2
        else:
            rhos_p = self._solve_rhos(rks, hh, h_phi_1, B_h)

        pred_res = sum(r * d for r, d in zip(rhos_p, D1s))
        return x_t - (alpha_t * B_h) * pred_res

    def _uni_c_bh2(
        self,
        model_x0: mx.array,
        last_sample: mx.array,
        this_sample: mx.array,
        order: int,
    ) -> mx.array:
        """UniC corrector with B(h)=expm1(-h) basis (bh2 variant).

        Follows the official multistep_uni_c_bh_update: order 1 uses
        rhos_c=[0.5]; otherwise rhos_c comes from the linear system.
        Must be called before the history is shifted, so
        ``_model_outputs[-1]`` is still the previous step's x0.

        Args:
            model_x0: x0 prediction computed for ``this_sample``.
            last_sample: Sample the previous predictor started from.
            this_sample: Sample the previous predictor produced.
            order: Order used by the previous predictor call.
        """
        i = self._step_index
        s = self._sigmas_float

        sigma_s0 = s[i - 1]
        sigma_t = s[i]
        if sigma_t == 0.0:
            return this_sample

        lambda_s0 = self._lambda(sigma_s0)
        lambda_t = self._lambda(sigma_t)
        h = lambda_t - lambda_s0
        hh = -h  # negated for x0 (data) prediction
        alpha_t = 1.0 - sigma_t
        h_phi_1 = math.expm1(hh)
        B_h = h_phi_1  # bh2 basis

        m0 = self._model_outputs[-1]
        # Re-derive the first-order base from last_sample.
        x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0
        D1_t = model_x0 - m0

        rks = []
        D1s = []
        for k in range(1, order):
            si_idx = i - (k + 1)
            mk = self._model_outputs[-(k + 1)]
            if si_idx < 0 or mk is None:
                break
            rk = (self._lambda(s[si_idx]) - lambda_s0) / h
            if math.isinf(rk):
                break  # History references sigma=1.0 boundary; reduce order
            rks.append(rk)
            D1s.append((mk - m0) / rk)
        rks.append(1.0)  # the current model output sits at ratio 1

        if len(rks) == 1:
            rhos_c = [0.5]
        else:
            rhos_c = self._solve_rhos(rks, hh, h_phi_1, B_h)

        # zip() pairs rhos_c[:-1] with the history terms; rhos_c[-1] weights D1_t.
        corr_res = sum(r * d for r, d in zip(rhos_c, D1s))
        return x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t)

    def step(
        self,
        model_output: mx.array,
        timestep,
        sample: mx.array,
    ) -> mx.array:
        """UniPC step: optionally correct the current sample, then predict.

        Args:
            model_output: Velocity prediction for ``sample``.
            timestep: Unused; the scheduler tracks its own step index.
            sample: Current latent at ``sigmas[step_index]``.

        Returns:
            The latent at the next sigma of the schedule.
        """
        i = self._step_index
        x0 = self._convert_output(model_output, sample)

        # 1. Corrector (opt-in): refine the incoming sample with the model
        #    output just computed for it.
        use_corrector = (
            self._use_corrector
            and i > 0
            and (i - 1) not in self.disable_corrector
            and self._last_sample is not None
        )
        if use_corrector:
            sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order)

        # 2. Push x0 into the fixed-length history (oldest dropped).
        del self._model_outputs[0]
        self._model_outputs.append(x0)

        # 3. Order: warm up one order per step; optionally never exceed the
        #    number of remaining steps.
        this_order = self.solver_order
        if self.lower_order_final:
            this_order = min(this_order, self._num_steps - i)
        self._this_order = min(this_order, self._lower_order_nums + 1)

        # 4. Predict; keep the sample for the next step's corrector.
        self._last_sample = sample
        x_next = self._uni_p_bh2(x0, sample, self._this_order)

        self._lower_order_nums = min(self._lower_order_nums + 1, self.solver_order)
        self._step_index += 1
        return x_next

    def reset(self):
        """Rewind to the first step and clear all solver history."""
        self._step_index = 0
        self._lower_order_nums = 0
        self._model_outputs = [None] * self.solver_order
        self._last_sample = None
        self._this_order = 1