feat(wan): Add DPM++ 2M and UniPC schedulers

This commit is contained in:
Daniel
2026-02-27 10:28:33 +01:00
parent e64483a66a
commit 93da550f65
8 changed files with 1792 additions and 89 deletions

View File

@@ -412,10 +412,13 @@ def convert_wan_checkpoint(
weights = sanitize_wan22_vae_weights(weights) weights = sanitize_wan22_vae_weights(weights)
else: else:
weights = sanitize_wan_vae_weights(weights) weights = sanitize_wan_vae_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()} # Always save VAE in float32 — official Wan2.2 runs VAE decode in
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
# that cannot be recovered by upcasting at load time.
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
out_path = output_dir / "vae.safetensors" out_path = output_dir / "vae.safetensors"
mx.save_safetensors(str(out_path), weights) mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}") print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
# Quantize transformer weights if requested # Quantize transformer weights if requested
if quantize: if quantize:

View File

@@ -56,7 +56,12 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
def load_t5_encoder(model_path: Path, config): def load_t5_encoder(model_path: Path, config):
"""Load T5 text encoder.""" """Load T5 text encoder.
Weights are upcast to float32 for maximum precision — the T5 encoder
only runs once per generation, so performance impact is negligible.
This matches the official which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan.text_encoder import T5Encoder from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder( encoder = T5Encoder(
@@ -70,6 +75,7 @@ def load_t5_encoder(model_path: Path, config):
shared_pos=False, shared_pos=False,
) )
weights = mx.load(str(model_path)) weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items())) encoder.load_weights(list(weights.items()))
mx.eval(encoder.parameters()) mx.eval(encoder.parameters())
return encoder return encoder
@@ -91,11 +97,33 @@ def load_vae_decoder(model_path: Path, config=None):
vae = WanVAE(z_dim=16) vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path)) weights = mx.load(str(model_path))
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
vae.load_weights(list(weights.items()), strict=False) vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters()) mx.eval(vae.parameters())
return vae return vae
def _clean_text(text: str) -> str:
"""Clean text matching official Wan2.2 tokenizer preprocessing.
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
double HTML unescape, and whitespace normalization. Critical for
correct tokenization of the Chinese negative prompt.
"""
import html
import re
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text).strip()
return text
def encode_text( def encode_text(
encoder, encoder,
tokenizer, tokenizer,
@@ -113,6 +141,7 @@ def encode_text(
Returns: Returns:
Text embeddings [L, dim] Text embeddings [L, dim]
""" """
prompt = _clean_text(prompt)
tokens = tokenizer( tokens = tokenizer(
prompt, prompt,
max_length=text_len, max_length=text_len,
@@ -133,7 +162,7 @@ def encode_text(
def generate_video( def generate_video(
model_dir: str, model_dir: str,
prompt: str, prompt: str,
negative_prompt: str = "", negative_prompt: str | None = None,
width: int = 1280, width: int = 1280,
height: int = 720, height: int = 720,
num_frames: int = 81, num_frames: int = 81,
@@ -142,13 +171,14 @@ def generate_video(
shift: float = None, shift: float = None,
seed: int = -1, seed: int = -1,
output_path: str = "output.mp4", output_path: str = "output.mp4",
scheduler: str = "unipc",
): ):
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2). """Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
Args: Args:
model_dir: Path to converted MLX model directory model_dir: Path to converted MLX model directory
prompt: Text prompt prompt: Text prompt
negative_prompt: Negative prompt negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
width: Video width width: Video width
height: Video height height: Video height
num_frames: Number of frames (must be 4n+1) num_frames: Number of frames (must be 4n+1)
@@ -157,11 +187,16 @@ def generate_video(
shift: Noise schedule shift (None = use config default) shift: Noise schedule shift (None = use config default)
seed: Random seed (-1 for random) seed: Random seed (-1 for random)
output_path: Output video path output_path: Output video path
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
""" """
import json import json
from mlx_video.models.wan.config import WanModelConfig from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
model_dir = Path(model_dir) model_dir = Path(model_dir)
@@ -253,12 +288,23 @@ def generate_video(
version_str = f"Wan{config.model_version}" version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model" mode_str = "dual-model" if is_dual else "single-model"
# Resolve negative prompt: explicit user value > config default
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
# that prevents oversaturation, artifacts, and comic look. We use it by default.
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
if negative_prompt is None:
neg_prompt_resolved = config.sample_neg_prompt
else:
neg_prompt_resolved = negative_prompt
print(f"{Colors.CYAN}{'='*60}") print(f"{Colors.CYAN}{'='*60}")
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})") print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
print(f"{'='*60}{Colors.RESET}") print(f"{'='*60}{Colors.RESET}")
print(f"{Colors.DIM} Prompt: {prompt}") print(f"{Colors.DIM} Prompt: {prompt}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}") print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}") print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
print(f"{Colors.RESET}") print(f"{Colors.RESET}")
# Seed # Seed
@@ -298,10 +344,7 @@ def generate_video(
# Encode prompts # Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}") print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len) context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
if negative_prompt: context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
context_null = encode_text(t5_encoder, tokenizer, negative_prompt, config.text_len)
else:
context_null = encode_text(t5_encoder, tokenizer, "", config.text_len)
mx.eval(context, context_null) mx.eval(context, context_null)
# Free T5 from memory # Free T5 from memory
@@ -343,8 +386,14 @@ def generate_video(
mx.eval(cross_kv) mx.eval(cross_kv)
# Setup scheduler # Setup scheduler
scheduler = FlowMatchEulerScheduler(num_train_timesteps=config.num_train_timesteps) _schedulers = {
scheduler.set_timesteps(steps, shift=shift) "euler": FlowMatchEulerScheduler,
"dpm++": FlowDPMPP2MScheduler,
"unipc": FlowUniPCScheduler,
}
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
sched.set_timesteps(steps, shift=shift)
# Generate initial noise # Generate initial noise
noise = mx.random.normal(target_shape) noise = mx.random.normal(target_shape)
@@ -358,7 +407,7 @@ def generate_video(
t3 = time.time() t3 = time.time()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")): for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = scheduler.timesteps[i].item() timestep_val = sched.timesteps[i].item()
# Select model, guide scale, and cached K/V # Select model, guide scale, and cached K/V
if is_dual: if is_dual:
@@ -387,7 +436,7 @@ def generate_video(
# Classifier-free guidance + scheduler step # Classifier-free guidance + scheduler step
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = scheduler.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# Release temporaries before eval to free memory for graph execution # Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds del noise_pred_cond, noise_pred_uncond, noise_pred, preds
@@ -476,7 +525,10 @@ def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)") parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory") parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt") parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt") parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width") parser.add_argument("--width", type=int, default=1280, help="Video width")
parser.add_argument("--height", type=int, default=720, help="Video height") parser.add_argument("--height", type=int, default=720, help="Video height")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)") parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
@@ -485,6 +537,11 @@ def main():
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)") parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed") parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path") parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
parser.add_argument(
"--scheduler", type=str, default="unipc",
choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
)
args = parser.parse_args() args = parser.parse_args()
# Parse guide scale # Parse guide scale
@@ -493,10 +550,15 @@ def main():
parts = [float(x) for x in args.guide_scale.split(",")] parts = [float(x) for x in args.guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0] guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
neg_prompt = args.negative_prompt
if args.no_negative_prompt:
neg_prompt = ""
generate_video( generate_video(
model_dir=args.model_dir, model_dir=args.model_dir,
prompt=args.prompt, prompt=args.prompt,
negative_prompt=args.negative_prompt, negative_prompt=neg_prompt,
width=args.width, width=args.width,
height=args.height, height=args.height,
num_frames=args.num_frames, num_frames=args.num_frames,
@@ -505,6 +567,7 @@ def main():
shift=args.shift, shift=args.shift,
seed=args.seed, seed=args.seed,
output_path=args.output_path, output_path=args.output_path,
scheduler=args.scheduler,
) )

View File

@@ -38,6 +38,12 @@ class WanModelConfig(BaseModelConfig):
num_train_timesteps: int = 1000 num_train_timesteps: int = 1000
sample_fps: int = 16 sample_fps: int = 16
frame_num: int = 81 frame_num: int = 81
sample_neg_prompt: str = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
"最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部"
"画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
"杂乱的背景,三条腿,背景人很多,倒着走"
)
# T5 # T5
t5_vocab_size: int = 256384 t5_vocab_size: int = 256384

View File

@@ -1,16 +1,29 @@
"""Flow matching scheduler for Wan2.2 inference.""" """Flow matching schedulers for Wan2.2 inference.
Provides Euler, DPM++2M, and UniPC solvers for flow matching diffusion.
Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps
for the same quality as Euler.
"""
import math
import numpy as np import numpy as np
import mlx.core as mx import mlx.core as mx
class FlowMatchEulerScheduler: def _compute_sigmas(num_steps: int, shift: float = 1.0) -> np.ndarray:
"""Simple Euler scheduler for flow matching diffusion. """Compute shifted sigma schedule matching official Wan2.2 code.
Implements the flow matching formulation where the model predicts Returns num_steps+1 values (the last being 0.0 for the terminal state).
velocity (flow) and we use Euler steps to denoise.
""" """
sigmas = np.linspace(1.0, 0.0, num_steps + 1)[:num_steps]
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
return np.append(sigmas, 0.0).astype(np.float32)
class FlowMatchEulerScheduler:
"""1st-order Euler scheduler for flow matching diffusion."""
def __init__(self, num_train_timesteps: int = 1000): def __init__(self, num_train_timesteps: int = 1000):
self.num_train_timesteps = num_train_timesteps self.num_train_timesteps = num_train_timesteps
@@ -18,30 +31,9 @@ class FlowMatchEulerScheduler:
self.sigmas = None self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0): def set_timesteps(self, num_steps: int, shift: float = 1.0):
"""Compute sigma schedule with shift. sigmas = _compute_sigmas(num_steps, shift)
self.sigmas = mx.array(sigmas)
Args: self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
num_steps: Number of inference steps.
shift: Noise schedule shift factor.
"""
# Linear spacing from sigma_max to sigma_min
sigmas = np.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps)[::-1]
sigmas = 1.0 - sigmas
# Select evenly spaced subset
indices = np.linspace(0, len(sigmas) - 1, num_steps + 1).astype(int)
sigmas = sigmas[indices[:-1]]
# Apply shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
# Convert to timesteps
timesteps = sigmas * self.num_train_timesteps
self.timesteps = mx.array(timesteps.astype(np.float32))
# Append terminal sigma=0
sigmas = np.append(sigmas, 0.0)
self.sigmas = mx.array(sigmas.astype(np.float32))
self._step_index = 0 self._step_index = 0
def step( def step(
@@ -50,27 +42,387 @@ class FlowMatchEulerScheduler:
timestep, timestep,
sample: mx.array, sample: mx.array,
) -> mx.array: ) -> mx.array:
"""Euler step for flow matching. """Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
dt = float(self.sigmas[self._step_index + 1].item()) - float(
In flow matching, model predicts velocity v, and: self.sigmas[self._step_index].item()
x_{t-1} = sample + (sigma_{t-1} - sigma_t) * v )
Args:
model_output: Predicted velocity [B, C, T, H, W]
timestep: Current timestep (unused, step index is tracked internally)
sample: Current noisy sample [B, C, T, H, W]
Returns:
Updated sample
"""
# Use Python floats to avoid creating mx.array scalars that
# could trigger type promotion (per fast-mlx guide)
dt = float(self.sigmas[self._step_index + 1].item()) - float(self.sigmas[self._step_index].item())
x_next = sample + dt * model_output x_next = sample + dt * model_output
self._step_index += 1
return x_next
def reset(self):
self._step_index = 0
class FlowDPMPP2MScheduler:
"""DPM-Solver++(2M) for flow matching diffusion.
2nd-order multistep solver that reuses the previous step's model output
for a correction term. Falls back to 1st order on the first and
(optionally) last step. Reference: Wan2.2 fm_solvers.py.
"""
def __init__(
self,
num_train_timesteps: int = 1000,
lower_order_final: bool = True,
):
self.num_train_timesteps = num_train_timesteps
self.lower_order_final = lower_order_final
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
# Store sigmas as Python floats for scalar math
self._sigmas_float = sigmas.tolist()
self._step_index = 0
self._num_steps = num_steps
self._prev_x0 = None # previous x0 prediction for 2nd-order correction
@staticmethod
def _lambda(sigma: float) -> float:
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
matching torch.log behavior in the official code.
"""
if sigma >= 1.0:
return -math.inf
if sigma <= 0.0:
return math.inf
return math.log((1.0 - sigma) / sigma)
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""DPM++(2M) step for flow matching.
Converts velocity prediction to x0, then applies 1st or 2nd order
update depending on available history.
"""
i = self._step_index
s = self._sigmas_float
sigma_cur = s[i]
sigma_next = s[i + 1]
# Convert velocity -> x0 prediction: x0 = sample - sigma * v
x0 = sample - sigma_cur * model_output
# Decide order: 1st for first step, last step (if lower_order_final
# and few steps), otherwise 2nd
use_first_order = (
self._prev_x0 is None
or (
self.lower_order_final
and i == self._num_steps - 1
and self._num_steps < 15
)
)
if use_first_order or sigma_next == 0.0:
# 1st order DPM++ (equivalent to DDIM):
# x_next = (σ_next/σ_cur)*x - (α_next*(exp(-h)-1))*x0
if sigma_next == 0.0:
x_next = x0
else:
lambda_cur = self._lambda(sigma_cur)
lambda_next = self._lambda(sigma_next)
h = lambda_next - lambda_cur
alpha_next = 1.0 - sigma_next
coeff_x = sigma_next / sigma_cur
coeff_x0 = alpha_next * math.expm1(-h)
x_next = coeff_x * sample - coeff_x0 * x0
else:
# 2nd order DPM++(2M) with midpoint correction
sigma_prev = s[i - 1]
lambda_prev = self._lambda(sigma_prev)
lambda_cur = self._lambda(sigma_cur)
lambda_next = self._lambda(sigma_next)
h = lambda_next - lambda_cur
h_0 = lambda_cur - lambda_prev
r0 = h_0 / h
# D0 = current x0, D1 = correction from previous x0
D0 = x0
D1 = (1.0 / r0) * (x0 - self._prev_x0)
alpha_next = 1.0 - sigma_next
exp_neg_h_m1 = math.expm1(-h) # exp(-h) - 1
x_next = (
(sigma_next / sigma_cur) * sample
- (alpha_next * exp_neg_h_m1) * D0
- 0.5 * (alpha_next * exp_neg_h_m1) * D1
)
self._prev_x0 = x0
self._step_index += 1
return x_next
def reset(self):
self._step_index = 0
self._prev_x0 = None
class FlowUniPCScheduler:
"""UniPC (Unified Predictor-Corrector) for flow matching diffusion.
Multi-step predictor-corrector solver with configurable order.
The corrector refines each step using the model output that was already
computed, costing no extra model evaluations. Official Wan2.2 default.
Reference: Wan2.2 fm_solvers_unipc.py.
"""
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
lower_order_final: bool = True,
disable_corrector: list | None = None,
use_corrector: bool = False,
):
self.num_train_timesteps = num_train_timesteps
self.solver_order = solver_order
self.lower_order_final = lower_order_final
self._use_corrector = use_corrector
self.disable_corrector = set(disable_corrector or [])
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
self._sigmas_float = sigmas.tolist()
self._step_index = 0
self._num_steps = num_steps
self._lower_order_nums = 0
# Model output (x0) history for multi-step, stored newest-last
self._model_outputs = [None] * self.solver_order
self._last_sample = None # sample before prediction (for corrector)
self._this_order = 1
@staticmethod
def _lambda(sigma: float) -> float:
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
matching torch.log behavior in the official code.
"""
if sigma >= 1.0:
return -math.inf
if sigma <= 0.0:
return math.inf
return math.log((1.0 - sigma) / sigma)
def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array:
"""Convert velocity prediction to x0: x0 = sample - sigma * v."""
sigma = self._sigmas_float[self._step_index]
return sample - sigma * velocity
def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array:
"""UniP predictor with B(h)=expm1(-h) basis (bh2 variant).
Matches official multistep_uni_p_bh_update: computes rhos_p via
linalg.solve for order >= 3; order <= 2 uses analytic rhos_p=[0.5].
"""
i = self._step_index
s = self._sigmas_float
sigma_s0 = s[i]
sigma_t = s[i + 1]
if sigma_t == 0.0:
return x0
lambda_s0 = self._lambda(sigma_s0)
lambda_t = self._lambda(sigma_t)
h = lambda_t - lambda_s0
hh = -h # negated for predict_x0
alpha_t = 1.0 - sigma_t
h_phi_1 = math.expm1(hh)
B_h = h_phi_1
m0 = self._model_outputs[-1]
# Base prediction
x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0
if order >= 2 and m0 is not None:
rks = []
D1s = []
for k in range(1, order):
si_idx = i - k
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
break
mk = self._model_outputs[-(k + 1)]
sigma_sk = s[si_idx]
lambda_sk = self._lambda(sigma_sk)
rk = (lambda_sk - lambda_s0) / h
if math.isinf(rk):
break
rks.append(rk)
D1s.append((mk - m0) / rk)
if D1s:
effective_order = len(D1s) + 1
if effective_order <= 2:
# Analytic solution for order 2
rhos_p = [0.5]
else:
rks_arr = np.array(rks, dtype=np.float64)
h_phi_k = h_phi_1 / hh - 1.0
factorial_i = 1
R_rows = []
b_vals = []
for j in range(1, effective_order):
R_rows.append(rks_arr ** (j - 1))
b_vals.append(float(h_phi_k * factorial_i / B_h))
factorial_i *= j + 1
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
R = np.stack(R_rows)
b = np.array(b_vals)
rhos_p = np.linalg.solve(R, b).tolist()
pred_res = sum(r * d for r, d in zip(rhos_p, D1s))
x_t = x_t - (alpha_t * B_h) * pred_res
return x_t
def _uni_c_bh2(
self,
model_x0: mx.array,
last_sample: mx.array,
this_sample: mx.array,
order: int,
) -> mx.array:
"""UniC corrector with B(h)=expm1(-h) basis (bh2 variant).
Matches official multistep_uni_c_bh_update: computes rhos_c via
linalg.solve for order >= 2 (not hardcoded 0.5).
"""
i = self._step_index
s = self._sigmas_float
sigma_s0 = s[i - 1]
sigma_t = s[i]
if sigma_t == 0.0:
return this_sample
lambda_s0 = self._lambda(sigma_s0)
lambda_t = self._lambda(sigma_t)
h = lambda_t - lambda_s0
hh = -h # negated for predict_x0
alpha_t = 1.0 - sigma_t
h_phi_1 = math.expm1(hh)
B_h = h_phi_1
m0 = self._model_outputs[-1]
# Re-derive base from last_sample
x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0
D1_t = model_x0 - m0
# Gather rks and D1s from history
rks = []
D1s = []
for k in range(1, order):
si_idx = i - (k + 1)
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
break
mk = self._model_outputs[-(k + 1)]
sigma_sk = s[si_idx]
lambda_sk = self._lambda(sigma_sk)
rk = (lambda_sk - lambda_s0) / h
if math.isinf(rk):
break # History references sigma=1.0 boundary; reduce order
rks.append(rk)
D1s.append((mk - m0) / rk)
rks.append(1.0)
effective_order = len(rks) # = len(D1s) + 1
# Compute rhos_c coefficients
if effective_order == 1:
rhos_c = [0.5]
else:
rks_arr = np.array(rks, dtype=np.float64)
h_phi_k = h_phi_1 / hh - 1.0
factorial_i = 1
R_rows = []
b_vals = []
for j in range(1, effective_order + 1):
R_rows.append(rks_arr ** (j - 1))
b_vals.append(float(h_phi_k * factorial_i / B_h))
factorial_i *= j + 1
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
R = np.stack(R_rows)
b = np.array(b_vals)
rhos_c = np.linalg.solve(R, b).tolist()
# Apply correction
corr_res = mx.zeros_like(D1_t)
for k_idx, d1 in enumerate(D1s):
corr_res = corr_res + rhos_c[k_idx] * d1
x_t = x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t)
return x_t
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""UniPC step: correct current, then predict next."""
i = self._step_index
# Convert velocity -> x0
x0 = self._convert_output(model_output, sample)
# 1. Corrector: refine current sample if we have history
use_corrector = (
self._use_corrector
and i > 0
and (i - 1) not in self.disable_corrector
and self._last_sample is not None
)
if use_corrector:
sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order)
# 2. Shift model output history
for k in range(self.solver_order - 1):
self._model_outputs[k] = self._model_outputs[k + 1]
self._model_outputs[-1] = x0
# 3. Determine prediction order
if self.lower_order_final:
this_order = min(self.solver_order, self._num_steps - i)
else:
this_order = self.solver_order
self._this_order = min(this_order, self._lower_order_nums + 1)
# 4. Predict next sample
self._last_sample = sample
x_next = self._uni_p_bh2(x0, sample, self._this_order)
if self._lower_order_nums < self.solver_order:
self._lower_order_nums += 1
self._step_index += 1 self._step_index += 1
return x_next return x_next
def reset(self): def reset(self):
"""Reset step counter for new generation."""
self._step_index = 0 self._step_index = 0
self._lower_order_nums = 0
self._model_outputs = [None] * self.solver_order
self._last_sample = None
self._this_order = 1

View File

@@ -106,29 +106,35 @@ class T5Attention(nn.Module):
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C] k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
v = self.v(context).reshape(b, -1, n, c) v = self.v(context).reshape(b, -1, n, c)
# T5 does not use scaling # T5 uses no scaling — compute attention manually with float32 softmax
# attn = einsum('binc,bjnc->bnij', q, k) # to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
# Using SDPA with bfloat16 inputs causes precision loss in softmax
# since unscaled logits can be very large (no 1/sqrt(d) division).
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C] q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
k = k.transpose(0, 2, 1, 3) k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3)
# Combine position bias and attention mask for SDPA # QK^T (no scaling) — compute in float32 for precision
attn_mask = None attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
# Add position bias
if pos_bias is not None: if pos_bias is not None:
attn_mask = pos_bias.astype(q.dtype) attn = attn + pos_bias.astype(mx.float32)
# Apply attention mask (use dtype min like official, not -1e9)
if mask is not None: if mask is not None:
if mask.ndim == 2: if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, Lk] mask = mask[:, None, None, :] # [B, 1, 1, Lk]
elif mask.ndim == 3: elif mask.ndim == 3:
mask = mask[:, None, :, :] # [B, 1, Lq, Lk] mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype) additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask attn = attn + additive_mask
# T5 uses no scaling (scale=1.0) # Softmax in float32 (matches official), then cast back
out = mx.fast.scaled_dot_product_attention( attn = mx.softmax(attn, axis=-1).astype(q.dtype)
q, k, v, scale=1.0, mask=attn_mask
) # Attention @ V
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c) out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
return self.o(out) return self.o(out)

View File

@@ -288,21 +288,34 @@ class Resample(nn.Module):
B, T, H, W, C = x.shape B, T, H, W, C = x.shape
if self.mode == "upsample3d": if self.mode == "upsample3d":
# Temporal upsample via time_conv if first_chunk and T > 1:
tc_out = self.time_conv(x) # [B, T, H, W, 2C] # Match official chunked behavior: the first frame bypasses
# Split into two interleaved temporal streams # time_conv entirely (only spatial upsample). Remaining frames
tc_out = tc_out.reshape(B, T, H, W, 2, C) # go through time_conv with causal zero-padding, which
# Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C] # naturally gives each frame the same limited temporal context
stream0 = tc_out[:, :, :, :, 0, :] # [B, T, H, W, C] # as the official frame-by-frame decode with caching.
stream1 = tc_out[:, :, :, :, 1, :] # [B, T, H, W, C] first_frame = x[:, 0:1] # [B, 1, H, W, C]
x = mx.stack([stream0, stream1], axis=2) # [B, T, 2, H, W, C] rest = x[:, 1:] # [B, T-1, H, W, C]
x = x.reshape(B, T * 2, H, W, C)
if first_chunk: # time_conv on remaining frames (causal pad gives zero context
# PyTorch skips time_conv for first chunk entirely. In all-at-once # before rest[0], matching the official "Rep" cache path)
# mode, we trim the first frame to match (the first interleaved tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
# frame is from zero-padded causal context and shouldn't be kept). tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
x = x[:, 1:, :, :, :] stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :]
interleaved = mx.stack([stream0, stream1], axis=2)
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
x = mx.concatenate([first_frame, interleaved], axis=1)
elif self.mode == "upsample3d":
# Non-first-chunk or single frame: time_conv all frames
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
tc_out = tc_out.reshape(B, T, H, W, 2, C)
stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :]
x = mx.stack([stream0, stream1], axis=2)
x = x.reshape(B, T * 2, H, W, C)
mx.eval(x) mx.eval(x)
T = x.shape[1] T = x.shape[1]

View File

@@ -22,6 +22,7 @@ dependencies = [
"mlx-vlm", "mlx-vlm",
"imageio>=2.37.2", "imageio>=2.37.2",
"imageio-ffmpeg>=0.6.0", "imageio-ffmpeg>=0.6.0",
"ftfy",
] ]
license = {text="MIT"} license = {text="MIT"}
authors = [ authors = [

File diff suppressed because it is too large Load Diff