From 93da550f6530ebb4f8b5ce9fbb6328717a3345ec Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 27 Feb 2026 10:28:33 +0100 Subject: [PATCH] feat(wan): Add DPM++ 2M and UniPC schedulers --- mlx_video/convert_wan.py | 7 +- mlx_video/generate_wan.py | 93 +- mlx_video/models/wan/config.py | 6 + mlx_video/models/wan/scheduler.py | 444 ++++++++- mlx_video/models/wan/text_encoder.py | 30 +- mlx_video/models/wan/vae22.py | 41 +- pyproject.toml | 1 + tests/test_wan.py | 1259 ++++++++++++++++++++++++++ 8 files changed, 1792 insertions(+), 89 deletions(-) diff --git a/mlx_video/convert_wan.py b/mlx_video/convert_wan.py index 7d89d55..33be9ac 100644 --- a/mlx_video/convert_wan.py +++ b/mlx_video/convert_wan.py @@ -412,10 +412,13 @@ def convert_wan_checkpoint( weights = sanitize_wan22_vae_weights(weights) else: weights = sanitize_wan_vae_weights(weights) - weights = {k: v.astype(target_dtype) for k, v in weights.items()} + # Always save VAE in float32 — official Wan2.2 runs VAE decode in + # float32 (dtype=torch.float). Saving in bfloat16 loses precision + # that cannot be recovered by upcasting at load time. + weights = {k: v.astype(mx.float32) for k, v in weights.items()} out_path = output_dir / "vae.safetensors" mx.save_safetensors(str(out_path), weights) - print(f" Saved {len(weights)} weight tensors to {out_path}") + print(f" Saved {len(weights)} weight tensors to {out_path} (float32)") # Quantize transformer weights if requested if quantize: diff --git a/mlx_video/generate_wan.py b/mlx_video/generate_wan.py index 69d5723..7d1e601 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -56,7 +56,12 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None): def load_t5_encoder(model_path: Path, config): - """Load T5 text encoder.""" + """Load T5 text encoder. + + Weights are upcast to float32 for maximum precision — the T5 encoder + only runs once per generation, so performance impact is negligible. + This matches the official which computes softmax in float32 explicitly. + """ from mlx_video.models.wan.text_encoder import T5Encoder encoder = T5Encoder( @@ -70,6 +75,7 @@ def load_t5_encoder(model_path: Path, config): shared_pos=False, ) weights = mx.load(str(model_path)) + weights = {k: v.astype(mx.float32) for k, v in weights.items()} encoder.load_weights(list(weights.items())) mx.eval(encoder.parameters()) return encoder @@ -91,11 +97,33 @@ def load_vae_decoder(model_path: Path, config=None): vae = WanVAE(z_dim=16) weights = mx.load(str(model_path)) + # Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32 + weights = {k: v.astype(mx.float32) for k, v in weights.items()} vae.load_weights(list(weights.items()), strict=False) mx.eval(vae.parameters()) return vae +def _clean_text(text: str) -> str: + """Clean text matching official Wan2.2 tokenizer preprocessing. + + Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars), + double HTML unescape, and whitespace normalization. Critical for + correct tokenization of the Chinese negative prompt. + """ + import html + import re + + try: + import ftfy + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + text = re.sub(r"\s+", " ", text).strip() + return text + + def encode_text( encoder, tokenizer, @@ -113,6 +141,7 @@ def encode_text( Returns: Text embeddings [L, dim] """ + prompt = _clean_text(prompt) tokens = tokenizer( prompt, max_length=text_len, @@ -133,7 +162,7 @@ def encode_text( def generate_video( model_dir: str, prompt: str, - negative_prompt: str = "", + negative_prompt: str | None = None, width: int = 1280, height: int = 720, num_frames: int = 81, @@ -142,13 +171,14 @@ def generate_video( shift: float = None, seed: int = -1, output_path: str = "output.mp4", + scheduler: str = "unipc", ): """Generate video using Wan T2V pipeline (supports 2.1 and 2.2). Args: model_dir: Path to converted MLX model directory prompt: Text prompt - negative_prompt: Negative prompt + negative_prompt: Negative prompt (None = use config default, "" = no negative prompt) width: Video width height: Video height num_frames: Number of frames (must be 4n+1) @@ -157,11 +187,16 @@ def generate_video( shift: Noise schedule shift (None = use config default) seed: Random seed (-1 for random) output_path: Output video path + scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default) """ import json from mlx_video.models.wan.config import WanModelConfig - from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler + from mlx_video.models.wan.scheduler import ( + FlowDPMPP2MScheduler, + FlowMatchEulerScheduler, + FlowUniPCScheduler, + ) model_dir = Path(model_dir) @@ -253,12 +288,23 @@ def generate_video( version_str = f"Wan{config.model_version}" mode_str = "dual-model" if is_dual else "single-model" + # Resolve negative prompt: explicit user value > config default + # The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt) + # that prevents oversaturation, artifacts, and comic look. We use it by default. + # Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization. + if negative_prompt is None: + neg_prompt_resolved = config.sample_neg_prompt + else: + neg_prompt_resolved = negative_prompt print(f"{Colors.CYAN}{'='*60}") print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})") print(f"{'='*60}{Colors.RESET}") print(f"{Colors.DIM} Prompt: {prompt}") + if neg_prompt_resolved and neg_prompt_resolved.strip(): + neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved + print(f" Neg prompt: {neg_display}") print(f" Size: {width}x{height}, Frames: {num_frames}") - print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}") + print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}") print(f"{Colors.RESET}") # Seed @@ -298,10 +344,7 @@ def generate_video( # Encode prompts print(f"{Colors.BLUE}Encoding text...{Colors.RESET}") context = encode_text(t5_encoder, tokenizer, prompt, config.text_len) - if negative_prompt: - context_null = encode_text(t5_encoder, tokenizer, negative_prompt, config.text_len) - else: - context_null = encode_text(t5_encoder, tokenizer, "", config.text_len) + context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len) mx.eval(context, context_null) # Free T5 from memory @@ -343,8 +386,14 @@ def generate_video( mx.eval(cross_kv) # Setup scheduler - scheduler = FlowMatchEulerScheduler(num_train_timesteps=config.num_train_timesteps) - scheduler.set_timesteps(steps, shift=shift) + _schedulers = { + "euler": FlowMatchEulerScheduler, + "dpm++": FlowDPMPP2MScheduler, + "unipc": FlowUniPCScheduler, + } + sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler) + sched = sched_cls(num_train_timesteps=config.num_train_timesteps) + sched.set_timesteps(steps, shift=shift) # Generate initial noise noise = mx.random.normal(target_shape) @@ -358,7 +407,7 @@ def generate_video( t3 = time.time() for i, t in enumerate(tqdm(range(steps), desc="Diffusion")): - timestep_val = scheduler.timesteps[i].item() + timestep_val = sched.timesteps[i].item() # Select model, guide scale, and cached K/V if is_dual: @@ -387,7 +436,7 @@ def generate_video( # Classifier-free guidance + scheduler step noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) - latents = scheduler.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) + latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) # Release temporaries before eval to free memory for graph execution del noise_pred_cond, noise_pred_uncond, noise_pred, preds @@ -476,7 +525,10 @@ def main(): parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)") parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory") parser.add_argument("--prompt", type=str, required=True, help="Text prompt") - parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt") + parser.add_argument("--negative-prompt", type=str, default=None, + help="Negative prompt for CFG (default: official Chinese prompt from config)") + parser.add_argument("--no-negative-prompt", action="store_true", + help="Disable negative prompt (use empty string instead of config default)") parser.add_argument("--width", type=int, default=1280, help="Video width") parser.add_argument("--height", type=int, default=720, help="Video height") parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)") @@ -485,6 +537,11 @@ def main(): parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)") parser.add_argument("--seed", type=int, default=-1, help="Random seed") parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path") + parser.add_argument( + "--scheduler", type=str, default="unipc", + choices=["euler", "dpm++", "unipc"], + help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)", + ) args = parser.parse_args() # Parse guide scale @@ -493,10 +550,15 @@ def main(): parts = [float(x) for x in args.guide_scale.split(",")] guide_scale = tuple(parts) if len(parts) > 1 else parts[0] + # Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through + neg_prompt = args.negative_prompt + if args.no_negative_prompt: + neg_prompt = "" + generate_video( model_dir=args.model_dir, prompt=args.prompt, - negative_prompt=args.negative_prompt, + negative_prompt=neg_prompt, width=args.width, height=args.height, num_frames=args.num_frames, @@ -505,6 +567,7 @@ def main(): shift=args.shift, seed=args.seed, output_path=args.output_path, + scheduler=args.scheduler, ) diff --git a/mlx_video/models/wan/config.py b/mlx_video/models/wan/config.py index 1be3374..e4bf900 100644 --- a/mlx_video/models/wan/config.py +++ b/mlx_video/models/wan/config.py @@ -38,6 +38,12 @@ class WanModelConfig(BaseModelConfig): num_train_timesteps: int = 1000 sample_fps: int = 16 frame_num: int = 81 + sample_neg_prompt: str = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰," + "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面," + "杂乱的背景,三条腿,背景人很多,倒着走" + ) # T5 t5_vocab_size: int = 256384 diff --git a/mlx_video/models/wan/scheduler.py b/mlx_video/models/wan/scheduler.py index 377f058..49ef10f 100644 --- a/mlx_video/models/wan/scheduler.py +++ b/mlx_video/models/wan/scheduler.py @@ -1,16 +1,29 @@ -"""Flow matching scheduler for Wan2.2 inference.""" +"""Flow matching schedulers for Wan2.2 inference. + +Provides Euler, DPM++2M, and UniPC solvers for flow matching diffusion. +Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps +for the same quality as Euler. +""" + +import math import numpy as np import mlx.core as mx -class FlowMatchEulerScheduler: - """Simple Euler scheduler for flow matching diffusion. +def _compute_sigmas(num_steps: int, shift: float = 1.0) -> np.ndarray: + """Compute shifted sigma schedule matching official Wan2.2 code. - Implements the flow matching formulation where the model predicts - velocity (flow) and we use Euler steps to denoise. + Returns num_steps+1 values (the last being 0.0 for the terminal state). """ + sigmas = np.linspace(1.0, 0.0, num_steps + 1)[:num_steps] + sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas) + return np.append(sigmas, 0.0).astype(np.float32) + + +class FlowMatchEulerScheduler: + """1st-order Euler scheduler for flow matching diffusion.""" def __init__(self, num_train_timesteps: int = 1000): self.num_train_timesteps = num_train_timesteps @@ -18,30 +31,9 @@ class FlowMatchEulerScheduler: self.sigmas = None def set_timesteps(self, num_steps: int, shift: float = 1.0): - """Compute sigma schedule with shift. - - Args: - num_steps: Number of inference steps. - shift: Noise schedule shift factor. - """ - # Linear spacing from sigma_max to sigma_min - sigmas = np.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps)[::-1] - sigmas = 1.0 - sigmas - - # Select evenly spaced subset - indices = np.linspace(0, len(sigmas) - 1, num_steps + 1).astype(int) - sigmas = sigmas[indices[:-1]] - - # Apply shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma) - sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas) - - # Convert to timesteps - timesteps = sigmas * self.num_train_timesteps - self.timesteps = mx.array(timesteps.astype(np.float32)) - - # Append terminal sigma=0 - sigmas = np.append(sigmas, 0.0) - self.sigmas = mx.array(sigmas.astype(np.float32)) + sigmas = _compute_sigmas(num_steps, shift) + self.sigmas = mx.array(sigmas) + self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps) self._step_index = 0 def step( @@ -50,27 +42,387 @@ class FlowMatchEulerScheduler: timestep, sample: mx.array, ) -> mx.array: - """Euler step for flow matching. - - In flow matching, model predicts velocity v, and: - x_{t-1} = sample + (sigma_{t-1} - sigma_t) * v - - Args: - model_output: Predicted velocity [B, C, T, H, W] - timestep: Current timestep (unused, step index is tracked internally) - sample: Current noisy sample [B, C, T, H, W] - - Returns: - Updated sample - """ - # Use Python floats to avoid creating mx.array scalars that - # could trigger type promotion (per fast-mlx guide) - dt = float(self.sigmas[self._step_index + 1].item()) - float(self.sigmas[self._step_index].item()) + """Euler step: x_next = x + (sigma_next - sigma_cur) * v.""" + dt = float(self.sigmas[self._step_index + 1].item()) - float( + self.sigmas[self._step_index].item() + ) x_next = sample + dt * model_output + self._step_index += 1 + return x_next + + def reset(self): + self._step_index = 0 + + +class FlowDPMPP2MScheduler: + """DPM-Solver++(2M) for flow matching diffusion. + + 2nd-order multistep solver that reuses the previous step's model output + for a correction term. Falls back to 1st order on the first and + (optionally) last step. Reference: Wan2.2 fm_solvers.py. + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + lower_order_final: bool = True, + ): + self.num_train_timesteps = num_train_timesteps + self.lower_order_final = lower_order_final + self.timesteps = None + self.sigmas = None + + def set_timesteps(self, num_steps: int, shift: float = 1.0): + sigmas = _compute_sigmas(num_steps, shift) + self.sigmas = mx.array(sigmas) + self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps) + # Store sigmas as Python floats for scalar math + self._sigmas_float = sigmas.tolist() + self._step_index = 0 + self._num_steps = num_steps + self._prev_x0 = None # previous x0 prediction for 2nd-order correction + + @staticmethod + def _lambda(sigma: float) -> float: + """log-SNR: lambda(sigma) = log((1-sigma)/sigma). + + Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean), + matching torch.log behavior in the official code. + """ + if sigma >= 1.0: + return -math.inf + if sigma <= 0.0: + return math.inf + return math.log((1.0 - sigma) / sigma) + + def step( + self, + model_output: mx.array, + timestep, + sample: mx.array, + ) -> mx.array: + """DPM++(2M) step for flow matching. + + Converts velocity prediction to x0, then applies 1st or 2nd order + update depending on available history. + """ + i = self._step_index + s = self._sigmas_float + + sigma_cur = s[i] + sigma_next = s[i + 1] + + # Convert velocity -> x0 prediction: x0 = sample - sigma * v + x0 = sample - sigma_cur * model_output + + # Decide order: 1st for first step, last step (if lower_order_final + # and few steps), otherwise 2nd + use_first_order = ( + self._prev_x0 is None + or ( + self.lower_order_final + and i == self._num_steps - 1 + and self._num_steps < 15 + ) + ) + + if use_first_order or sigma_next == 0.0: + # 1st order DPM++ (equivalent to DDIM): + # x_next = (σ_next/σ_cur)*x - (α_next*(exp(-h)-1))*x0 + if sigma_next == 0.0: + x_next = x0 + else: + lambda_cur = self._lambda(sigma_cur) + lambda_next = self._lambda(sigma_next) + h = lambda_next - lambda_cur + alpha_next = 1.0 - sigma_next + coeff_x = sigma_next / sigma_cur + coeff_x0 = alpha_next * math.expm1(-h) + x_next = coeff_x * sample - coeff_x0 * x0 + else: + # 2nd order DPM++(2M) with midpoint correction + sigma_prev = s[i - 1] + lambda_prev = self._lambda(sigma_prev) + lambda_cur = self._lambda(sigma_cur) + lambda_next = self._lambda(sigma_next) + + h = lambda_next - lambda_cur + h_0 = lambda_cur - lambda_prev + r0 = h_0 / h + + # D0 = current x0, D1 = correction from previous x0 + D0 = x0 + D1 = (1.0 / r0) * (x0 - self._prev_x0) + + alpha_next = 1.0 - sigma_next + exp_neg_h_m1 = math.expm1(-h) # exp(-h) - 1 + + x_next = ( + (sigma_next / sigma_cur) * sample + - (alpha_next * exp_neg_h_m1) * D0 + - 0.5 * (alpha_next * exp_neg_h_m1) * D1 + ) + + self._prev_x0 = x0 + self._step_index += 1 + return x_next + + def reset(self): + self._step_index = 0 + self._prev_x0 = None + + +class FlowUniPCScheduler: + """UniPC (Unified Predictor-Corrector) for flow matching diffusion. + + Multi-step predictor-corrector solver with configurable order. + The corrector refines each step using the model output that was already + computed, costing no extra model evaluations. Official Wan2.2 default. + Reference: Wan2.2 fm_solvers_unipc.py. + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + lower_order_final: bool = True, + disable_corrector: list | None = None, + use_corrector: bool = False, + ): + self.num_train_timesteps = num_train_timesteps + self.solver_order = solver_order + self.lower_order_final = lower_order_final + self._use_corrector = use_corrector + self.disable_corrector = set(disable_corrector or []) + self.timesteps = None + self.sigmas = None + + def set_timesteps(self, num_steps: int, shift: float = 1.0): + sigmas = _compute_sigmas(num_steps, shift) + self.sigmas = mx.array(sigmas) + self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps) + self._sigmas_float = sigmas.tolist() + self._step_index = 0 + self._num_steps = num_steps + self._lower_order_nums = 0 + # Model output (x0) history for multi-step, stored newest-last + self._model_outputs = [None] * self.solver_order + self._last_sample = None # sample before prediction (for corrector) + self._this_order = 1 + + @staticmethod + def _lambda(sigma: float) -> float: + """log-SNR: lambda(sigma) = log((1-sigma)/sigma). + + Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean), + matching torch.log behavior in the official code. + """ + if sigma >= 1.0: + return -math.inf + if sigma <= 0.0: + return math.inf + return math.log((1.0 - sigma) / sigma) + + def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array: + """Convert velocity prediction to x0: x0 = sample - sigma * v.""" + sigma = self._sigmas_float[self._step_index] + return sample - sigma * velocity + + def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array: + """UniP predictor with B(h)=expm1(-h) basis (bh2 variant). + + Matches official multistep_uni_p_bh_update: computes rhos_p via + linalg.solve for order >= 3; order <= 2 uses analytic rhos_p=[0.5]. + """ + i = self._step_index + s = self._sigmas_float + + sigma_s0 = s[i] + sigma_t = s[i + 1] + + if sigma_t == 0.0: + return x0 + + lambda_s0 = self._lambda(sigma_s0) + lambda_t = self._lambda(sigma_t) + h = lambda_t - lambda_s0 + hh = -h # negated for predict_x0 + + alpha_t = 1.0 - sigma_t + h_phi_1 = math.expm1(hh) + B_h = h_phi_1 + + m0 = self._model_outputs[-1] + # Base prediction + x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0 + + if order >= 2 and m0 is not None: + rks = [] + D1s = [] + for k in range(1, order): + si_idx = i - k + if si_idx < 0 or self._model_outputs[-(k + 1)] is None: + break + mk = self._model_outputs[-(k + 1)] + sigma_sk = s[si_idx] + lambda_sk = self._lambda(sigma_sk) + rk = (lambda_sk - lambda_s0) / h + if math.isinf(rk): + break + rks.append(rk) + D1s.append((mk - m0) / rk) + + if D1s: + effective_order = len(D1s) + 1 + if effective_order <= 2: + # Analytic solution for order 2 + rhos_p = [0.5] + else: + rks_arr = np.array(rks, dtype=np.float64) + h_phi_k = h_phi_1 / hh - 1.0 + factorial_i = 1 + R_rows = [] + b_vals = [] + for j in range(1, effective_order): + R_rows.append(rks_arr ** (j - 1)) + b_vals.append(float(h_phi_k * factorial_i / B_h)) + factorial_i *= j + 1 + h_phi_k = h_phi_k / hh - 1.0 / factorial_i + R = np.stack(R_rows) + b = np.array(b_vals) + rhos_p = np.linalg.solve(R, b).tolist() + + pred_res = sum(r * d for r, d in zip(rhos_p, D1s)) + x_t = x_t - (alpha_t * B_h) * pred_res + + return x_t + + def _uni_c_bh2( + self, + model_x0: mx.array, + last_sample: mx.array, + this_sample: mx.array, + order: int, + ) -> mx.array: + """UniC corrector with B(h)=expm1(-h) basis (bh2 variant). + + Matches official multistep_uni_c_bh_update: computes rhos_c via + linalg.solve for order >= 2 (not hardcoded 0.5). + """ + i = self._step_index + s = self._sigmas_float + + sigma_s0 = s[i - 1] + sigma_t = s[i] + + if sigma_t == 0.0: + return this_sample + + lambda_s0 = self._lambda(sigma_s0) + lambda_t = self._lambda(sigma_t) + h = lambda_t - lambda_s0 + hh = -h # negated for predict_x0 + + alpha_t = 1.0 - sigma_t + h_phi_1 = math.expm1(hh) + B_h = h_phi_1 + + m0 = self._model_outputs[-1] + # Re-derive base from last_sample + x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0 + + D1_t = model_x0 - m0 + + # Gather rks and D1s from history + rks = [] + D1s = [] + for k in range(1, order): + si_idx = i - (k + 1) + if si_idx < 0 or self._model_outputs[-(k + 1)] is None: + break + mk = self._model_outputs[-(k + 1)] + sigma_sk = s[si_idx] + lambda_sk = self._lambda(sigma_sk) + rk = (lambda_sk - lambda_s0) / h + if math.isinf(rk): + break # History references sigma=1.0 boundary; reduce order + rks.append(rk) + D1s.append((mk - m0) / rk) + rks.append(1.0) + effective_order = len(rks) # = len(D1s) + 1 + + # Compute rhos_c coefficients + if effective_order == 1: + rhos_c = [0.5] + else: + rks_arr = np.array(rks, dtype=np.float64) + h_phi_k = h_phi_1 / hh - 1.0 + factorial_i = 1 + R_rows = [] + b_vals = [] + for j in range(1, effective_order + 1): + R_rows.append(rks_arr ** (j - 1)) + b_vals.append(float(h_phi_k * factorial_i / B_h)) + factorial_i *= j + 1 + h_phi_k = h_phi_k / hh - 1.0 / factorial_i + R = np.stack(R_rows) + b = np.array(b_vals) + rhos_c = np.linalg.solve(R, b).tolist() + + # Apply correction + corr_res = mx.zeros_like(D1_t) + for k_idx, d1 in enumerate(D1s): + corr_res = corr_res + rhos_c[k_idx] * d1 + x_t = x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t) + return x_t + + def step( + self, + model_output: mx.array, + timestep, + sample: mx.array, + ) -> mx.array: + """UniPC step: correct current, then predict next.""" + i = self._step_index + + # Convert velocity -> x0 + x0 = self._convert_output(model_output, sample) + + # 1. Corrector: refine current sample if we have history + use_corrector = ( + self._use_corrector + and i > 0 + and (i - 1) not in self.disable_corrector + and self._last_sample is not None + ) + if use_corrector: + sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order) + + # 2. Shift model output history + for k in range(self.solver_order - 1): + self._model_outputs[k] = self._model_outputs[k + 1] + self._model_outputs[-1] = x0 + + # 3. Determine prediction order + if self.lower_order_final: + this_order = min(self.solver_order, self._num_steps - i) + else: + this_order = self.solver_order + self._this_order = min(this_order, self._lower_order_nums + 1) + + # 4. Predict next sample + self._last_sample = sample + x_next = self._uni_p_bh2(x0, sample, self._this_order) + + if self._lower_order_nums < self.solver_order: + self._lower_order_nums += 1 self._step_index += 1 return x_next def reset(self): - """Reset step counter for new generation.""" self._step_index = 0 + self._lower_order_nums = 0 + self._model_outputs = [None] * self.solver_order + self._last_sample = None + self._this_order = 1 diff --git a/mlx_video/models/wan/text_encoder.py b/mlx_video/models/wan/text_encoder.py index 8eaac36..d325ed5 100644 --- a/mlx_video/models/wan/text_encoder.py +++ b/mlx_video/models/wan/text_encoder.py @@ -106,29 +106,35 @@ class T5Attention(nn.Module): k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C] v = self.v(context).reshape(b, -1, n, c) - # T5 does not use scaling - # attn = einsum('binc,bjnc->bnij', q, k) + # T5 uses no scaling — compute attention manually with float32 softmax + # to match official: F.softmax(attn.float(), dim=-1).type_as(attn) + # Using SDPA with bfloat16 inputs causes precision loss in softmax + # since unscaled logits can be very large (no 1/sqrt(d) division). q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C] k = k.transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3) - # Combine position bias and attention mask for SDPA - attn_mask = None + # QK^T (no scaling) — compute in float32 for precision + attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)) + + # Add position bias if pos_bias is not None: - attn_mask = pos_bias.astype(q.dtype) + attn = attn + pos_bias.astype(mx.float32) + + # Apply attention mask (use dtype min like official, not -1e9) if mask is not None: if mask.ndim == 2: mask = mask[:, None, None, :] # [B, 1, 1, Lk] elif mask.ndim == 3: mask = mask[:, None, :, :] # [B, 1, Lq, Lk] - additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype) - attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask + additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32) + attn = attn + additive_mask - # T5 uses no scaling (scale=1.0) - out = mx.fast.scaled_dot_product_attention( - q, k, v, scale=1.0, mask=attn_mask - ) - out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c) + # Softmax in float32 (matches official), then cast back + attn = mx.softmax(attn, axis=-1).astype(q.dtype) + + # Attention @ V + out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c) return self.o(out) diff --git a/mlx_video/models/wan/vae22.py b/mlx_video/models/wan/vae22.py index c48b7cb..8865542 100644 --- a/mlx_video/models/wan/vae22.py +++ b/mlx_video/models/wan/vae22.py @@ -288,21 +288,34 @@ class Resample(nn.Module): B, T, H, W, C = x.shape if self.mode == "upsample3d": - # Temporal upsample via time_conv - tc_out = self.time_conv(x) # [B, T, H, W, 2C] - # Split into two interleaved temporal streams - tc_out = tc_out.reshape(B, T, H, W, 2, C) - # Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C] - stream0 = tc_out[:, :, :, :, 0, :] # [B, T, H, W, C] - stream1 = tc_out[:, :, :, :, 1, :] # [B, T, H, W, C] - x = mx.stack([stream0, stream1], axis=2) # [B, T, 2, H, W, C] - x = x.reshape(B, T * 2, H, W, C) + if first_chunk and T > 1: + # Match official chunked behavior: the first frame bypasses + # time_conv entirely (only spatial upsample). Remaining frames + # go through time_conv with causal zero-padding, which + # naturally gives each frame the same limited temporal context + # as the official frame-by-frame decode with caching. + first_frame = x[:, 0:1] # [B, 1, H, W, C] + rest = x[:, 1:] # [B, T-1, H, W, C] - if first_chunk: - # PyTorch skips time_conv for first chunk entirely. In all-at-once - # mode, we trim the first frame to match (the first interleaved - # frame is from zero-padded causal context and shouldn't be kept). - x = x[:, 1:, :, :, :] + # time_conv on remaining frames (causal pad gives zero context + # before rest[0], matching the official "Rep" cache path) + tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C] + tc_out = tc_out.reshape(B, T - 1, H, W, 2, C) + stream0 = tc_out[:, :, :, :, 0, :] + stream1 = tc_out[:, :, :, :, 1, :] + interleaved = mx.stack([stream0, stream1], axis=2) + interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C) + + # first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames + x = mx.concatenate([first_frame, interleaved], axis=1) + elif self.mode == "upsample3d": + # Non-first-chunk or single frame: time_conv all frames + tc_out = self.time_conv(x) # [B, T, H, W, 2C] + tc_out = tc_out.reshape(B, T, H, W, 2, C) + stream0 = tc_out[:, :, :, :, 0, :] + stream1 = tc_out[:, :, :, :, 1, :] + x = mx.stack([stream0, stream1], axis=2) + x = x.reshape(B, T * 2, H, W, C) mx.eval(x) T = x.shape[1] diff --git a/pyproject.toml b/pyproject.toml index 843eb5e..198956d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "mlx-vlm", "imageio>=2.37.2", "imageio-ffmpeg>=0.6.0", + "ftfy", ] license = {text="MIT"} authors = [ diff --git a/tests/test_wan.py b/tests/test_wan.py index 0c6976d..9a35bbd 100644 --- a/tests/test_wan.py +++ b/tests/test_wan.py @@ -1449,5 +1449,1264 @@ class TestWan21Convert: assert d["sample_shift"] == 5.0 +# --------------------------------------------------------------------------- +# Shared Sigma Schedule Tests +# --------------------------------------------------------------------------- + + +class TestComputeSigmas: + """Tests for the shared _compute_sigmas helper.""" + + def test_length(self): + from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(20, shift=5.0) + assert len(sigmas) == 21 # num_steps + terminal + + def test_terminal_zero(self): + from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(10, shift=1.0) + assert sigmas[-1] == 0.0 + + def test_starts_at_one(self): + from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(20, shift=5.0) + np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-6) + + def test_decreasing(self): + from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(20, shift=5.0) + assert np.all(np.diff(sigmas) <= 0) + + def test_matches_official_wan22(self): + """Sigma schedule should match the official Wan2.2 get_sampling_sigmas.""" + from mlx_video.models.wan.scheduler import _compute_sigmas + steps, shift = 50, 5.0 + sigmas = _compute_sigmas(steps, shift) + # Official: sigma = linspace(1, 0, steps+1)[:steps]; sigma = shift*sigma/(1+(shift-1)*sigma) + official = np.linspace(1, 0, steps + 1)[:steps] + official = shift * official / (1 + (shift - 1) * official) + official = np.append(official, 0.0).astype(np.float32) + np.testing.assert_allclose(sigmas, official, atol=1e-6) + + def test_shift_one_is_linear(self): + from mlx_video.models.wan.scheduler import _compute_sigmas + sigmas = _compute_sigmas(10, shift=1.0) + # With shift=1, f(sigma)=sigma, so schedule is linear from 1 to 0 + expected = np.linspace(1, 0, 11).astype(np.float32) + np.testing.assert_allclose(sigmas, expected, atol=1e-6) + + def test_all_schedulers_same_sigmas(self): + """All three schedulers should produce identical sigma schedules.""" + from mlx_video.models.wan.scheduler import ( + FlowDPMPP2MScheduler, + FlowMatchEulerScheduler, + FlowUniPCScheduler, + ) + scheds = [ + FlowMatchEulerScheduler(1000), + FlowDPMPP2MScheduler(1000), + FlowUniPCScheduler(1000), + ] + for s in scheds: + s.set_timesteps(20, shift=5.0) + mx.eval(*[s.sigmas for s in scheds]) + ref = np.array(scheds[0].sigmas) + for s in scheds[1:]: + np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6) + + def test_all_schedulers_same_timesteps(self): + from mlx_video.models.wan.scheduler import ( + FlowDPMPP2MScheduler, + FlowMatchEulerScheduler, + FlowUniPCScheduler, + ) + scheds = [ + FlowMatchEulerScheduler(1000), + FlowDPMPP2MScheduler(1000), + FlowUniPCScheduler(1000), + ] + for s in scheds: + s.set_timesteps(30, shift=12.0) + mx.eval(*[s.timesteps for s in scheds]) + ref = np.array(scheds[0].timesteps) + for s in scheds[1:]: + np.testing.assert_allclose(np.array(s.timesteps), ref, atol=1e-3) + + +# --------------------------------------------------------------------------- +# DPM++ 2M Scheduler Tests +# --------------------------------------------------------------------------- + + +class TestFlowDPMPP2MScheduler: + def test_initialization(self): + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + assert sched.num_train_timesteps == 1000 + assert sched.lower_order_final is True + + def test_set_timesteps(self): + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(20, shift=5.0) + mx.eval(sched.timesteps, sched.sigmas) + assert sched.timesteps.shape == (20,) + assert sched.sigmas.shape == (21,) + + def test_step_index_increments(self): + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(5, shift=1.0) + sample = mx.ones((1, 4, 1, 2, 2)) + vel = mx.zeros_like(sample) + assert sched._step_index == 0 + sched.step(vel, sched.timesteps[0], sample) + assert sched._step_index == 1 + sched.step(vel, sched.timesteps[1], sample) + assert sched._step_index == 2 + + def test_reset(self): + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(5, shift=1.0) + sample = mx.ones((1, 1, 1, 1, 1)) + sched.step(mx.zeros_like(sample), 0, sample) + sched.reset() + assert sched._step_index == 0 + assert sched._prev_x0 is None + + def test_full_loop_finite(self): + """Full loop with constant velocity should produce finite output.""" + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(10, shift=1.0) + sample = mx.ones((1, 2, 1, 2, 2)) + for i in range(10): + vel = mx.ones_like(sample) * 0.1 + sample = sched.step(vel, sched.timesteps[i], sample) + mx.eval(sample) + assert np.isfinite(np.array(sample)).all() + + def test_first_step_is_first_order(self): + """First step should use 1st-order (no prev_x0 available).""" + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(10, shift=5.0) + sample = mx.random.normal((1, 4, 2, 4, 4)) + vel = mx.random.normal(sample.shape) + # Before first step, no prev_x0 + assert sched._prev_x0 is None + result = sched.step(vel, sched.timesteps[0], sample) + mx.eval(result) + # After first step, prev_x0 should be set + assert sched._prev_x0 is not None + + def test_second_step_uses_correction(self): + """After first step, DPM++ should have stored prev_x0 for correction.""" + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(10, shift=5.0) + sample = mx.random.normal((1, 4, 1, 2, 2)) + vel = mx.random.normal(sample.shape) + # Step 1 + sample = sched.step(vel, sched.timesteps[0], sample) + mx.eval(sample) + x0_after_first = sched._prev_x0 + # Step 2 + vel = mx.random.normal(sample.shape) + sample = sched.step(vel, sched.timesteps[1], sample) + mx.eval(sample) + # prev_x0 should have been updated + x0_after_second = sched._prev_x0 + assert x0_after_second is not None + # The stored x0 should differ from the first step's + assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6) + + def test_denoise_to_target(self): + """Perfect oracle should denoise to target with any solver.""" + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(20, shift=5.0) + target = mx.zeros((1, 2, 1, 4, 4)) + latents = mx.random.normal(target.shape) + for i in range(20): + sigma = float(sched.sigmas[i].item()) + v = latents / max(sigma, 1e-6) # perfect velocity for target=0 + latents = sched.step(v, sched.timesteps[i], latents) + mx.eval(latents) + np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3) + + @pytest.mark.parametrize("steps", [5, 10, 20, 50]) + def test_various_step_counts(self, steps): + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(steps, shift=5.0) + mx.eval(sched.timesteps, sched.sigmas) + assert sched.timesteps.shape == (steps,) + assert sched.sigmas.shape == (steps + 1,) + + def test_terminal_sigma_produces_x0(self): + """When sigma_next=0 the scheduler should return x0 directly.""" + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + sched = FlowDPMPP2MScheduler() + sched.set_timesteps(5, shift=1.0) + sample = mx.ones((1, 1, 1, 1, 1)) * 3.0 + vel = mx.ones_like(sample) * 2.0 + # Run through all steps; the last step has sigma_next=0 + for i in range(5): + sample = sched.step(vel, sched.timesteps[i], sample) + mx.eval(sample) + # Final value should be finite + assert np.isfinite(np.array(sample)).all() + + +# --------------------------------------------------------------------------- +# UniPC Scheduler Tests +# --------------------------------------------------------------------------- + + +class TestFlowUniPCScheduler: + def test_initialization(self): + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() + assert sched.num_train_timesteps == 1000 + assert sched.solver_order == 2 + assert sched.lower_order_final is True + + def test_set_timesteps(self): + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() + sched.set_timesteps(30, shift=12.0) + mx.eval(sched.timesteps, sched.sigmas) + assert sched.timesteps.shape == (30,) + assert sched.sigmas.shape == (31,) + + def test_step_index_increments(self): + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() + sched.set_timesteps(5, shift=1.0) + sample = mx.ones((1, 1, 1, 1, 1)) + vel = mx.zeros_like(sample) + assert sched._step_index == 0 + sched.step(vel, 0, sample) + assert sched._step_index == 1 + + def test_reset(self): + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() + sched.set_timesteps(5, shift=1.0) + sample = mx.ones((1, 1, 1, 1, 1)) + sched.step(mx.zeros_like(sample), 0, sample) + sched.reset() + assert sched._step_index == 0 + assert sched._lower_order_nums == 0 + assert sched._last_sample is None + assert all(m is None for m in sched._model_outputs) + + def test_full_loop_finite(self): + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() + sched.set_timesteps(10, shift=1.0) + sample = mx.ones((1, 2, 1, 2, 2)) + for i in range(10): + vel = mx.ones_like(sample) * 0.1 + sample = sched.step(vel, sched.timesteps[i], sample) + mx.eval(sample) + assert np.isfinite(np.array(sample)).all() + + def test_corrector_not_applied_first_step(self): + """First step should skip the corrector (no history).""" + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(use_corrector=True) + sched.set_timesteps(10, shift=5.0) + sample = mx.random.normal((1, 4, 1, 2, 2)) + vel = mx.random.normal(sample.shape) + # Before step 0: no last_sample + assert sched._last_sample is None + sched.step(vel, sched.timesteps[0], sample) + # After step 0: last_sample should be set for corrector on step 1 + assert sched._last_sample is not None + + def test_corrector_applied_after_first_step(self): + """Steps after the first should use the corrector when enabled.""" + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(use_corrector=True) + sched.set_timesteps(10, shift=5.0) + sample = mx.random.normal((1, 2, 1, 4, 4)) + for i in range(3): + vel = mx.random.normal(sample.shape) + sample = sched.step(vel, sched.timesteps[i], sample) + mx.eval(sample) + # lower_order_nums should have increased + assert sched._lower_order_nums >= 2 + + def test_denoise_to_target(self): + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() + sched.set_timesteps(20, shift=5.0) + target = mx.zeros((1, 2, 1, 4, 4)) + latents = mx.random.normal(target.shape) + for i in range(20): + sigma = float(sched.sigmas[i].item()) + v = latents / max(sigma, 1e-6) + latents = sched.step(v, sched.timesteps[i], latents) + mx.eval(latents) + np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3) + + @pytest.mark.parametrize("steps", [5, 10, 20, 50]) + def test_various_step_counts(self, steps): + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler() + sched.set_timesteps(steps, shift=5.0) + mx.eval(sched.timesteps, sched.sigmas) + assert sched.timesteps.shape == (steps,) + assert sched.sigmas.shape == (steps + 1,) + + def test_disable_corrector(self): + """Disabling corrector on step 0 should still work without error.""" + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0]) + sched.set_timesteps(5, shift=1.0) + sample = mx.ones((1, 1, 1, 2, 2)) + for i in range(5): + vel = mx.ones_like(sample) * 0.1 + sample = sched.step(vel, sched.timesteps[i], sample) + mx.eval(sample) + assert np.isfinite(np.array(sample)).all() + + def test_solver_order_3(self): + """Order 3 should work without error.""" + from mlx_video.models.wan.scheduler import FlowUniPCScheduler + sched = FlowUniPCScheduler(solver_order=3, use_corrector=True) + sched.set_timesteps(10, shift=5.0) + sample = mx.random.normal((1, 2, 1, 2, 2)) + for i in range(10): + vel = mx.random.normal(sample.shape) + sample = sched.step(vel, sched.timesteps[i], sample) + mx.eval(sample) + assert np.isfinite(np.array(sample)).all() + + def test_corrector_rhos_c_not_hardcoded(self): + """Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5.""" + import math + # For 50-step schedule with shift=5.0, order 2 corrector at step 5: + # rhos_c[0] (history) should be ~0.07, NOT 0.5 + # rhos_c[1] (D1_t) should be ~0.45, NOT 0.5 + from mlx_video.models.wan.scheduler import _compute_sigmas + + sigmas = _compute_sigmas(50, shift=5.0) + + def _lambda(sigma): + if sigma >= 1.0: + return -math.inf + if sigma <= 0.0: + return math.inf + return math.log(1 - sigma) - math.log(sigma) + + for step_idx in [5, 10, 25, 45]: + sigma_s0 = sigmas[step_idx - 1] + sigma_t = sigmas[step_idx] + lambda_s0 = _lambda(sigma_s0) + lambda_t = _lambda(sigma_t) + h = lambda_t - lambda_s0 + hh = -h + + sigma_sk = sigmas[step_idx - 2] + lambda_sk = _lambda(sigma_sk) + rk = (lambda_sk - lambda_s0) / h + rks = np.array([rk, 1.0]) + + h_phi_1 = math.expm1(hh) + B_h = h_phi_1 + h_phi_k = h_phi_1 / hh - 1.0 + factorial_i = 1 + R_rows, b_vals = [], [] + for j in range(1, 3): + R_rows.append(rks ** (j - 1)) + b_vals.append(h_phi_k * factorial_i / B_h) + factorial_i *= j + 1 + h_phi_k = h_phi_k / hh - 1.0 / factorial_i + R = np.stack(R_rows) + b = np.array(b_vals) + rhos_c = np.linalg.solve(R, b) + + # History weight should be small (~0.07-0.09), not 0.5 + assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large" + assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive" + # D1_t weight should be ~0.42-0.45, not 0.5 + assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range" + + +class TestSchedulerCoherence: + """Tests that Euler, DPM++, and UniPC schedulers produce coherent results. + + All three schedulers should agree on shared structure (sigma schedules, + first-step behavior) and converge to the same result given perfect + velocity oracles, even though they use different update rules. + """ + + @staticmethod + def _make_schedulers(steps=10, shift=5.0): + from mlx_video.models.wan.scheduler import ( + FlowDPMPP2MScheduler, + FlowMatchEulerScheduler, + FlowUniPCScheduler, + ) + + scheds = { + "euler": FlowMatchEulerScheduler(), + "dpm++": FlowDPMPP2MScheduler(), + "unipc": FlowUniPCScheduler(), + } + for s in scheds.values(): + s.set_timesteps(steps, shift=shift) + return scheds + + def test_identical_sigma_schedules(self): + """All schedulers must use the same sigma schedule.""" + scheds = self._make_schedulers(20, shift=5.0) + ref = np.array(scheds["euler"].sigmas) + for name in ("dpm++", "unipc"): + np.testing.assert_allclose( + np.array(scheds[name].sigmas), + ref, + atol=1e-6, + err_msg=f"{name} sigma schedule differs from Euler", + ) + + def test_identical_timesteps(self): + """All schedulers must produce the same timestep sequence.""" + scheds = self._make_schedulers(20, shift=5.0) + ref = np.array(scheds["euler"].timesteps) + for name in ("dpm++", "unipc"): + np.testing.assert_allclose( + np.array(scheds[name].timesteps), + ref, + atol=1e-6, + err_msg=f"{name} timesteps differ from Euler", + ) + + def test_first_step_matches_euler(self): + """Step 0 (1st-order for all solvers) should match Euler exactly.""" + mx.random.seed(42) + shape = (1, 4, 1, 4, 4) + noise = mx.random.normal(shape) + vel = mx.random.normal(shape) + + scheds = self._make_schedulers(10, shift=5.0) + results = {} + for name, sched in scheds.items(): + r = sched.step(vel, sched.timesteps[0], noise) + mx.eval(r) + results[name] = np.array(r) + + np.testing.assert_allclose( + results["dpm++"], results["euler"], atol=1e-5, + err_msg="DPM++ step 0 should match Euler", + ) + np.testing.assert_allclose( + results["unipc"], results["euler"], atol=1e-5, + err_msg="UniPC step 0 should match Euler", + ) + + def test_first_step_matches_across_shifts(self): + """Step 0 should match Euler for different shift values.""" + mx.random.seed(99) + shape = (1, 2, 1, 2, 2) + noise = mx.random.normal(shape) + vel = mx.random.normal(shape) + + for shift in (1.0, 5.0, 12.0): + scheds = self._make_schedulers(10, shift=shift) + euler_r = scheds["euler"].step(vel, scheds["euler"].timesteps[0], noise) + dpm_r = scheds["dpm++"].step(vel, scheds["dpm++"].timesteps[0], noise) + unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise) + mx.eval(euler_r, dpm_r, unipc_r) + np.testing.assert_allclose( + np.array(dpm_r), np.array(euler_r), atol=1e-5, + err_msg=f"DPM++ step 0 differs from Euler at shift={shift}", + ) + np.testing.assert_allclose( + np.array(unipc_r), np.array(euler_r), atol=1e-5, + err_msg=f"UniPC step 0 differs from Euler at shift={shift}", + ) + + def test_oracle_all_converge_to_target(self): + """Given a perfect velocity oracle v=x/sigma, all solvers should + denoise to approximately zero (the target).""" + mx.random.seed(7) + shape = (1, 2, 1, 4, 4) + noise = mx.random.normal(shape) + + for name, sched in self._make_schedulers(20, shift=5.0).items(): + latents = noise + for i in range(20): + sigma = float(sched.sigmas[i].item()) + v = latents / max(sigma, 1e-8) + latents = sched.step(v, sched.timesteps[i], latents) + mx.eval(latents) + np.testing.assert_allclose( + np.array(latents), 0.0, atol=1e-3, + err_msg=f"{name} did not converge to target with oracle", + ) + + def test_oracle_higher_order_closer_to_target(self): + """With few steps and a perfect oracle, higher-order solvers should + be at least as accurate as Euler.""" + mx.random.seed(12) + shape = (1, 2, 1, 4, 4) + noise = mx.random.normal(shape) + steps = 5 + + errors = {} + for name, sched in self._make_schedulers(steps, shift=5.0).items(): + latents = noise + for i in range(steps): + sigma = float(sched.sigmas[i].item()) + v = latents / max(sigma, 1e-8) + latents = sched.step(v, sched.timesteps[i], latents) + mx.eval(latents) + errors[name] = float(mx.mean(mx.abs(latents)).item()) + + # Higher-order solvers should not be significantly worse than Euler + assert errors["dpm++"] <= errors["euler"] * 1.5, ( + f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}" + ) + assert errors["unipc"] <= errors["euler"] * 1.5, ( + f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}" + ) + + def test_multistep_trajectory_similar_magnitude(self): + """Over a full denoising loop with constant velocity, all solvers + should produce outputs of similar magnitude (not diverging).""" + mx.random.seed(42) + shape = (1, 4, 1, 4, 4) + noise = mx.random.normal(shape) + steps = 20 + + final_means = {} + for name, sched in self._make_schedulers(steps, shift=5.0).items(): + latents = noise + for i in range(steps): + vel = latents * 0.1 + latents = sched.step(vel, sched.timesteps[i], latents) + mx.eval(latents) + final_means[name] = float(mx.mean(mx.abs(latents)).item()) + + # All solvers should produce results within the same order of magnitude + vals = list(final_means.values()) + ratio = max(vals) / max(min(vals), 1e-10) + assert ratio < 10.0, ( + f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}" + ) + + def test_intermediate_values_finite(self): + """Every intermediate latent value must be finite for all solvers.""" + mx.random.seed(0) + shape = (1, 2, 1, 2, 2) + noise = mx.random.normal(shape) + + for name, sched in self._make_schedulers(15, shift=5.0).items(): + latents = noise + for i in range(15): + vel = mx.random.normal(shape) + latents = sched.step(vel, sched.timesteps[i], latents) + mx.eval(latents) + assert np.isfinite(np.array(latents)).all(), ( + f"{name} produced non-finite values at step {i}" + ) + + def test_lambda_boundary_values(self): + """_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0.""" + from mlx_video.models.wan.scheduler import ( + FlowDPMPP2MScheduler, + FlowUniPCScheduler, + ) + + for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler): + assert cls._lambda(1.0) == -math.inf, ( + f"{cls.__name__}._lambda(1.0) should be -inf" + ) + assert cls._lambda(0.0) == math.inf, ( + f"{cls.__name__}._lambda(0.0) should be +inf" + ) + # Interior values should be finite + lam = cls._lambda(0.5) + assert math.isfinite(lam) and lam == 0.0, ( + f"{cls.__name__}._lambda(0.5) should be 0.0" + ) + + def test_lambda_monotonically_decreasing(self): + """_lambda(sigma) should decrease as sigma increases (more noise → lower SNR).""" + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler + + sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] + lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas] + for i in range(len(lambdas) - 1): + assert lambdas[i] > lambdas[i + 1], ( + f"_lambda not decreasing: _lambda({sigmas[i]})={lambdas[i]} " + f"vs _lambda({sigmas[i+1]})={lambdas[i+1]}" + ) + + def test_step0_is_ddim_formula(self): + """At sigma=1.0, the DPM++/UniPC first step should reduce to the + DDIM formula: x_next = sigma_next * x + (1 - sigma_next) * x0.""" + mx.random.seed(55) + shape = (1, 2, 1, 2, 2) + sample = mx.random.normal(shape) + vel = mx.random.normal(shape) + + for steps, shift in [(10, 5.0), (20, 12.0)]: + scheds = self._make_schedulers(steps, shift=shift) + sigma_next = float(scheds["euler"].sigmas[1].item()) + sigma_cur = float(scheds["euler"].sigmas[0].item()) + assert abs(sigma_cur - 1.0) < 1e-6, "First sigma should be ~1.0" + + x0 = sample - sigma_cur * vel + expected = sigma_next * sample + (1.0 - sigma_next) * x0 + mx.eval(expected) + + for name in ("dpm++", "unipc"): + result = scheds[name].step(vel, scheds[name].timesteps[0], sample) + mx.eval(result) + np.testing.assert_allclose( + np.array(result), np.array(expected), atol=1e-5, + err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})", + ) + + @pytest.mark.parametrize("steps", [5, 10, 20, 50]) + def test_coherent_across_step_counts(self, steps): + """All solvers should agree on step 0 regardless of total step count.""" + mx.random.seed(77) + shape = (1, 2, 1, 2, 2) + noise = mx.random.normal(shape) + vel = mx.random.normal(shape) + + scheds = self._make_schedulers(steps, shift=5.0) + results = {} + for name, sched in scheds.items(): + r = sched.step(vel, sched.timesteps[0], noise) + mx.eval(r) + results[name] = np.array(r) + + np.testing.assert_allclose( + results["dpm++"], results["euler"], atol=1e-5, + ) + np.testing.assert_allclose( + results["unipc"], results["euler"], atol=1e-5, + ) + + def test_dpmpp_unipc_agree_on_step1(self): + """After warmup, DPM++ and UniPC step 1 should be similar + (both use 2nd-order corrections based on the same model outputs).""" + mx.random.seed(42) + shape = (1, 4, 1, 4, 4) + noise = mx.random.normal(shape) + + scheds = self._make_schedulers(10, shift=5.0) + # Run step 0 with same velocity + vel0 = mx.random.normal(shape) + for sched in scheds.values(): + sched.step(vel0, sched.timesteps[0], noise) + + # Run step 1 from same sample with same velocity + sample1 = scheds["euler"].step(vel0, scheds["euler"].timesteps[0], noise) + mx.eval(sample1) + vel1 = mx.random.normal(shape) + + r_dpm = scheds["dpm++"].step(vel1, scheds["dpm++"].timesteps[1], sample1) + r_unipc = scheds["unipc"].step(vel1, scheds["unipc"].timesteps[1], sample1) + mx.eval(r_dpm, r_unipc) + + # They won't be identical (different correction formulas) but should + # be in the same ballpark (within 50% of each other's magnitude) + mean_dpm = float(mx.mean(mx.abs(r_dpm)).item()) + mean_unipc = float(mx.mean(mx.abs(r_unipc)).item()) + ratio = max(mean_dpm, mean_unipc) / max(min(mean_dpm, mean_unipc), 1e-10) + assert ratio < 2.0, ( + f"DPM++ and UniPC step 1 differ too much: " + f"DPM++={mean_dpm:.4f}, UniPC={mean_unipc:.4f}" + ) + + def test_reset_makes_solvers_reproducible(self): + """After reset(), running the same loop should produce identical output.""" + mx.random.seed(42) + shape = (1, 2, 1, 2, 2) + noise = mx.random.normal(shape) + + from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler + + for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler): + sched = cls() + sched.set_timesteps(5, shift=5.0) + + # First run + latents = noise + for i in range(5): + vel = latents * 0.1 + latents = sched.step(vel, sched.timesteps[i], latents) + mx.eval(latents) + result1 = np.array(latents) + + # Reset and run again + sched.reset() + latents = noise + for i in range(5): + vel = latents * 0.1 + latents = sched.step(vel, sched.timesteps[i], latents) + mx.eval(latents) + result2 = np.array(latents) + + np.testing.assert_allclose(result1, result2, atol=1e-5, + err_msg=f"{cls.__name__} not reproducible after reset()") + + +# --------------------------------------------------------------------------- +# Wan2.2 VAE Component Tests +# --------------------------------------------------------------------------- + + +class TestVAE22CausalConv3d: + """Tests for vae22.CausalConv3d (channels-last).""" + + def test_output_shape_k3(self): + from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(8, 16, kernel_size=3, padding=1) + x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C] + out = conv(x) + mx.eval(out) + assert out.shape == (1, 4, 8, 8, 16) + + def test_output_shape_k1(self): + from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(8, 16, kernel_size=1) + x = mx.random.normal((1, 2, 4, 4, 8)) + out = conv(x) + mx.eval(out) + assert out.shape == (1, 2, 4, 4, 16) + + def test_temporal_causal(self): + """Output at t=0 should not depend on t>0.""" + from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(2, 2, kernel_size=3, padding=1) + conv.weight = mx.random.normal(conv.weight.shape) * 0.1 + conv.bias = mx.zeros(conv.bias.shape) + + x = mx.zeros((1, 4, 4, 4, 2)) + out_zero = conv(x) + mx.eval(out_zero) + t0_ref = np.array(out_zero[0, 0]) + + # Modify t=2..3; output at t=0 should be unchanged + x_mod = mx.concatenate([ + x[:, :2], + mx.ones((1, 2, 4, 4, 2)), + ], axis=1) + out_mod = conv(x_mod) + mx.eval(out_mod) + t0_mod = np.array(out_mod[0, 0]) + np.testing.assert_allclose(t0_ref, t0_mod, atol=1e-5) + + def test_channels_last_format(self): + """Verify input/output are channels-last [B, T, H, W, C].""" + from mlx_video.models.wan.vae22 import CausalConv3d + conv = CausalConv3d(4, 8, kernel_size=3, padding=1) + x = mx.random.normal((2, 3, 6, 6, 4)) + out = conv(x) + mx.eval(out) + assert out.shape[-1] == 8 # last dim = out_channels + + +class TestRMSNorm: + """Tests for vae22.RMS_norm (actually L2 normalization).""" + + def test_output_shape(self): + from mlx_video.models.wan.vae22 import RMS_norm + norm = RMS_norm(16) + x = mx.random.normal((2, 4, 4, 4, 16)) + out = norm(x) + mx.eval(out) + assert out.shape == x.shape + + def test_l2_normalization(self): + """RMS_norm should normalize to unit L2 norm * sqrt(dim).""" + from mlx_video.models.wan.vae22 import RMS_norm + dim = 32 + norm = RMS_norm(dim) + x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values + out = norm(x) + mx.eval(out) + # After L2 norm * scale(=sqrt(dim)) * gamma(=1): ||out|| = sqrt(dim) + out_np = np.array(out).flatten() + l2 = np.linalg.norm(out_np) + np.testing.assert_allclose(l2, math.sqrt(dim), rtol=1e-3) + + def test_scale_invariant(self): + """Scaling input by constant should not change output (L2 norm property).""" + from mlx_video.models.wan.vae22 import RMS_norm + norm = RMS_norm(8) + x = mx.random.normal((1, 1, 1, 1, 8)) + out1 = norm(x) + out2 = norm(x * 10.0) + mx.eval(out1, out2) + np.testing.assert_allclose(np.array(out1), np.array(out2), atol=1e-4) + + def test_gamma_effect(self): + """Non-unit gamma should scale output.""" + from mlx_video.models.wan.vae22 import RMS_norm + norm = RMS_norm(4) + norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0]) + x = mx.ones((1, 1, 1, 1, 4)) + out = norm(x) + mx.eval(out) + # With gamma=2, each component is 2 * sqrt(4) * x/||x|| = 2 * 2 * 1/2 = 2 + np.testing.assert_allclose(np.array(out).flatten(), 2.0, atol=1e-4) + + +class TestDupUp3D: + """Tests for vae22.DupUp3D spatial/temporal upsampling.""" + + def test_spatial_only(self): + from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(8, 4, factor_t=1, factor_s=2) + x = mx.random.normal((1, 3, 4, 4, 8)) + out = up(x) + mx.eval(out) + assert out.shape == (1, 3, 8, 8, 4) + + def test_temporal_and_spatial(self): + from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(16, 8, factor_t=2, factor_s=2) + x = mx.random.normal((1, 3, 4, 4, 16)) + out = up(x) + mx.eval(out) + assert out.shape == (1, 6, 8, 8, 8) + + def test_first_chunk_trims(self): + from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(8, 4, factor_t=2, factor_s=2) + x = mx.random.normal((1, 3, 4, 4, 8)) + out_normal = up(x, first_chunk=False) + out_trimmed = up(x, first_chunk=True) + mx.eval(out_normal, out_trimmed) + # first_chunk removes factor_t-1=1 temporal frame + assert out_normal.shape[1] == 6 + assert out_trimmed.shape[1] == 5 + + def test_no_temporal_first_chunk_noop(self): + from mlx_video.models.wan.vae22 import DupUp3D + up = DupUp3D(8, 4, factor_t=1, factor_s=2) + x = mx.random.normal((1, 3, 4, 4, 8)) + out_normal = up(x, first_chunk=False) + out_trimmed = up(x, first_chunk=True) + mx.eval(out_normal, out_trimmed) + # factor_t=1, so first_chunk removes 0 frames + assert out_normal.shape == out_trimmed.shape + + +class TestVAE22Resample: + """Tests for vae22.Resample (spatial/temporal upsampling).""" + + def test_upsample2d_shape(self): + from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample2d") + r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 + x = mx.random.normal((1, 2, 4, 4, 8)) + out = r(x) + mx.eval(out) + assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal + + def test_upsample3d_shape(self): + from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample3d") + r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 + x = mx.random.normal((1, 2, 4, 4, 8)) + out = r(x) + mx.eval(out) + assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal + + def test_upsample3d_first_chunk(self): + from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample3d") + r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 + x = mx.random.normal((1, 2, 4, 4, 8)) + out = r(x, first_chunk=True) + mx.eval(out) + # first_chunk: 1 (bypass) + 2*(T-1) (interleaved) = 2T-1 = 3 + assert out.shape == (1, 3, 8, 8, 8) + + def test_upsample3d_first_chunk_single_frame(self): + """Single-frame input with first_chunk: no temporal upsample.""" + from mlx_video.models.wan.vae22 import Resample + r = Resample(8, "upsample3d") + r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01 + x = mx.random.normal((1, 1, 4, 4, 8)) + out = r(x, first_chunk=True) + mx.eval(out) + # Single frame with first_chunk: falls through to non-first path + # time_conv on 1 frame → 2 interleaved + assert out.shape == (1, 2, 8, 8, 8) + + def test_upsample3d_first_frame_bypasses_time_conv(self): + """First frame of first_chunk should NOT go through time_conv. + + Official Wan2.2 skips time_conv for the very first frame entirely. + We verify this by checking that the first output frame depends only on + the first input frame (not on time_conv parameters). + """ + from mlx_video.models.wan.vae22 import Resample + C = 8 + r = Resample(C, "upsample3d") + # Set time_conv weights to large values so its effect is detectable + r.time_conv.weight = mx.ones(r.time_conv.weight.shape) * 10.0 + r.time_conv.bias = mx.zeros(r.time_conv.bias.shape) + # Set spatial conv to identity-like + r.resample_weight = mx.zeros(r.resample_weight.shape) + r.resample_bias = mx.zeros(r.resample_bias.shape) + + x = mx.random.normal((1, 3, 2, 2, C)) + out = r(x, first_chunk=True) + mx.eval(out) + # Output: 5 frames (1 bypass + 4 interleaved from 2 remaining) + assert out.shape[1] == 5 + + # First frame should be spatial upsample of x[:, 0:1] only. + # Run just the first frame through spatial upsample for reference + first_only = x[:, 0:1] + ref = r._upsample2x(first_only.reshape(1, 2, 2, C)) + ref = mx.pad(ref, [(0, 0), (1, 1), (1, 1), (0, 0)]) + ref = mx.conv_general(ref, r.resample_weight) + r.resample_bias + mx.eval(ref) + + # Compare first output frame to reference + first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C) + mx.eval(first_out) + assert mx.allclose(first_out, ref, atol=1e-5).item(), \ + "First frame should bypass time_conv and match spatial-only upsample" + + +class TestVAE22ResidualBlock: + """Tests for vae22.ResidualBlock.""" + + def test_same_dim(self): + from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 8) + x = mx.random.normal((1, 2, 4, 4, 8)) + out = block(x) + mx.eval(out) + assert out.shape == (1, 2, 4, 4, 8) + + def test_different_dim(self): + from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 16) + x = mx.random.normal((1, 2, 4, 4, 8)) + out = block(x) + mx.eval(out) + assert out.shape == (1, 2, 4, 4, 16) + + def test_shortcut_when_dims_differ(self): + from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 16) + assert block.shortcut is not None + + def test_no_shortcut_same_dim(self): + from mlx_video.models.wan.vae22 import ResidualBlock + block = ResidualBlock(8, 8) + assert block.shortcut is None + + +class TestResidualBlockLayers: + """Tests for vae22.ResidualBlockLayers naming convention.""" + + def test_layer_names_no_underscore_prefix(self): + """Layer names must NOT start with underscore (MLX ignores them).""" + from mlx_video.models.wan.vae22 import ResidualBlockLayers + block = ResidualBlockLayers(8, 8) + params = dict(block.parameters()) + # All param keys should use layer_N, not _layer_N + for key in params: + assert not key.startswith("_"), f"Parameter {key} starts with underscore" + + def test_has_expected_layers(self): + from mlx_video.models.wan.vae22 import ResidualBlockLayers + block = ResidualBlockLayers(8, 16) + assert hasattr(block, "layer_0") # first RMS_norm + assert hasattr(block, "layer_2") # first CausalConv3d + assert hasattr(block, "layer_3") # second RMS_norm + assert hasattr(block, "layer_6") # second CausalConv3d + + def test_forward_shape(self): + from mlx_video.models.wan.vae22 import ResidualBlockLayers + block = ResidualBlockLayers(8, 16) + x = mx.random.normal((1, 2, 4, 4, 8)) + out = block(x) + mx.eval(out) + assert out.shape == (1, 2, 4, 4, 16) + + +class TestVAE22AttentionBlock: + """Tests for vae22.AttentionBlock (per-frame 2D self-attention).""" + + def test_output_shape(self): + from mlx_video.models.wan.vae22 import AttentionBlock + block = AttentionBlock(16) + block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01 + block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01 + x = mx.random.normal((1, 2, 4, 4, 16)) + out = block(x) + mx.eval(out) + assert out.shape == (1, 2, 4, 4, 16) + + def test_residual_connection(self): + from mlx_video.models.wan.vae22 import AttentionBlock + block = AttentionBlock(8) + block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape) + block.proj_weight = mx.zeros(block.proj_weight.shape) + x = mx.ones((1, 1, 2, 2, 8)) + out = block(x) + mx.eval(out) + # With zero weights, attention output is 0 → residual is identity + np.testing.assert_allclose(np.array(out), np.array(x), atol=1e-5) + + +class TestHead22: + """Tests for vae22.Head22 output head.""" + + def test_output_shape(self): + from mlx_video.models.wan.vae22 import Head22 + head = Head22(16, out_channels=12) + x = mx.random.normal((1, 2, 4, 4, 16)) + out = head(x) + mx.eval(out) + assert out.shape == (1, 2, 4, 4, 12) + + def test_layer_names_no_underscore(self): + """Head layers must not use underscore prefix.""" + from mlx_video.models.wan.vae22 import Head22 + head = Head22(8) + assert hasattr(head, "layer_0") # RMS_norm + assert hasattr(head, "layer_2") # CausalConv3d + params = dict(head.parameters()) + for key in params: + assert not key.startswith("_"), f"Head param {key} starts with underscore" + + +class TestUnpatchify: + """Tests for vae22._unpatchify.""" + + def test_basic_shape(self): + from mlx_video.models.wan.vae22 import _unpatchify + x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2 + out = _unpatchify(x, patch_size=2) + mx.eval(out) + assert out.shape == (1, 2, 8, 8, 3) + + def test_patch_size_1_noop(self): + from mlx_video.models.wan.vae22 import _unpatchify + x = mx.random.normal((1, 2, 4, 4, 3)) + out = _unpatchify(x, patch_size=1) + mx.eval(out) + np.testing.assert_array_equal(np.array(out), np.array(x)) + + def test_preserves_content(self): + """Unpatchify should be a lossless rearrangement.""" + from mlx_video.models.wan.vae22 import _unpatchify + x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32) + out = _unpatchify(x, patch_size=2) + mx.eval(out) + # All elements should be preserved + assert np.array(out).size == 48 + assert set(np.array(out).flatten().tolist()) == set(range(48)) + + +class TestDenormalizeLatents: + """Tests for vae22.denormalize_latents.""" + + def test_output_shape(self): + from mlx_video.models.wan.vae22 import denormalize_latents + z = mx.random.normal((1, 2, 4, 4, 48)) + out = denormalize_latents(z) + mx.eval(out) + assert out.shape == (1, 2, 4, 4, 48) + + def test_custom_mean_std(self): + from mlx_video.models.wan.vae22 import denormalize_latents + z = mx.ones((1, 1, 1, 1, 4)) + mean = mx.array([1.0, 2.0, 3.0, 4.0]) + std = mx.array([0.5, 0.5, 0.5, 0.5]) + out = denormalize_latents(z, mean=mean, std=std) + mx.eval(out) + # z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5] + np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5) + + def test_uses_default_constants(self): + from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents + # Should not raise with default constants + z = mx.zeros((1, 1, 1, 1, 48)) + out = denormalize_latents(z) + mx.eval(out) + # z=0 → result = 0 * std + mean = mean + np.testing.assert_allclose( + np.array(out).flatten(), + np.array(VAE22_MEAN).flatten(), + atol=1e-5, + ) + + +class TestVAE22NormConstants: + """Tests for VAE22_MEAN and VAE22_STD constants.""" + + def test_dimensions(self): + from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD + mx.eval(VAE22_MEAN, VAE22_STD) + assert VAE22_MEAN.shape == (48,) + assert VAE22_STD.shape == (48,) + + def test_std_positive(self): + from mlx_video.models.wan.vae22 import VAE22_STD + mx.eval(VAE22_STD) + assert (np.array(VAE22_STD) > 0).all() + + +class TestWan22VAEDecoder: + """Tests for the full Wan22VAEDecoder (tiny configuration).""" + + def test_output_shape_small(self): + """Tiny decoder should produce correct spatial/temporal output.""" + from mlx_video.models.wan.vae22 import Wan22VAEDecoder + # Use very small dims to keep test fast + dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) + # Latent: [B=1, T=3, H=2, W=2, C=4] + # Expected: temporal 3→5→9→9→9 (two temporal upsamples), spatial 2→4→8→16 + z = mx.random.normal((1, 3, 2, 2, 4)) * 0.1 + out = dec(z) + mx.eval(out) + # Output should have 3 RGB channels and be clipped to [-1, 1] + assert out.shape[-1] == 3 + assert out.ndim == 5 + assert np.array(out).min() >= -1.0 + assert np.array(out).max() <= 1.0 + + def test_output_clipped(self): + from mlx_video.models.wan.vae22 import Wan22VAEDecoder + dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8) + z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values + out = dec(z) + mx.eval(out) + assert np.array(out).min() >= -1.0 - 1e-6 + assert np.array(out).max() <= 1.0 + 1e-6 + + +class TestSanitizeWan22VAEWeights: + """Tests for vae22.sanitize_wan22_vae_weights.""" + + def test_skip_encoder(self): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { + "encoder.layer.weight": mx.zeros((4,)), + "conv1.weight": mx.zeros((4,)), + "decoder.conv1.bias": mx.zeros((4,)), + } + out = sanitize_wan22_vae_weights(weights) + assert "encoder.layer.weight" not in out + assert "conv1.weight" not in out + assert "decoder.conv1.bias" in out + + def test_sequential_index_remapping(self): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { + "decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)), + "decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)), + "decoder.head.0.gamma": mx.ones((4,)), + "decoder.head.2.bias": mx.zeros((12,)), + } + out = sanitize_wan22_vae_weights(weights) + assert "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma" in out + assert "decoder.upsamples.0.upsamples.0.residual.layer_6.bias" in out + assert "decoder.head.layer_0.gamma" in out + assert "decoder.head.layer_2.bias" in out + + def test_resample_conv_remapping(self): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { + "decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)), + "decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)), + } + out = sanitize_wan22_vae_weights(weights) + assert "decoder.upsamples.1.upsamples.3.resample_weight" in out + assert "decoder.upsamples.1.upsamples.3.resample_bias" in out + + def test_attention_remapping(self): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + weights = { + "decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)), + "decoder.middle.1.to_qkv.bias": mx.zeros((24,)), + "decoder.middle.1.proj.weight": mx.zeros((8, 8, 1, 1)), + "decoder.middle.1.proj.bias": mx.zeros((8,)), + } + out = sanitize_wan22_vae_weights(weights) + assert "decoder.middle.1.to_qkv_weight" in out + assert "decoder.middle.1.to_qkv_bias" in out + assert "decoder.middle.1.proj_weight" in out + assert "decoder.middle.1.proj_bias" in out + + def test_conv3d_transpose(self): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + # Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I] + w = mx.zeros((16, 8, 3, 3, 3)) + weights = {"decoder.conv1.weight": w} + out = sanitize_wan22_vae_weights(weights) + assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8) + + def test_conv2d_transpose(self): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + # Conv2d weight: [O, I, H, W] → [O, H, W, I] + w = mx.zeros((8, 8, 3, 3)) + weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w} + out = sanitize_wan22_vae_weights(weights) + key = "decoder.upsamples.0.upsamples.2.resample_weight" + assert out[key].shape == (8, 3, 3, 8) + + def test_gamma_squeeze(self): + from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights + # gamma: (dim, 1, 1, 1) → (dim,) + w = mx.ones((16, 1, 1, 1)) + weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w} + out = sanitize_wan22_vae_weights(weights) + key = "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma" + assert out[key].shape == (16,) + + +class TestUpResidualBlock: + """Tests for vae22.Up_ResidualBlock.""" + + def test_no_upsample(self): + from mlx_video.models.wan.vae22 import Up_ResidualBlock + block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False) + x = mx.random.normal((1, 2, 4, 4, 8)) + out = block(x) + mx.eval(out) + # No upsample: same shape + assert out.shape == (1, 2, 4, 4, 8) + + def test_spatial_upsample(self): + from mlx_video.models.wan.vae22 import Up_ResidualBlock + block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True) + x = mx.random.normal((1, 2, 4, 4, 8)) + out = block(x) + mx.eval(out) + # 2x spatial upsample, no temporal + assert out.shape == (1, 2, 8, 8, 4) + + def test_spatial_temporal_upsample(self): + from mlx_video.models.wan.vae22 import Up_ResidualBlock + block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True) + x = mx.random.normal((1, 2, 4, 4, 8)) + out = block(x) + mx.eval(out) + # 2x spatial + 2x temporal + assert out.shape == (1, 4, 8, 8, 4) + + if __name__ == "__main__": pytest.main([__file__, "-v"])