Merge branch 'main' into pc/unify-apis

Prince Canuma
2026-03-18 17:14:17 +01:00
48 changed files with 14133 additions and 10 deletions


@@ -1,2 +1,3 @@
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig


@@ -0,0 +1,349 @@
## Wan2.1 / Wan2.2
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
They share the same model architecture — the difference is in the inference pipeline:
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|---|--------|--------|--------|--------|
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
| **Pipeline** | Single model | Dual model | Dual model | Single model |
| **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
| **Resolution** | 480P (1.3B), 720P (14B) | 720P | 720P | 720P |
| **Steps** | 50 | 40 | 40 | 40 |
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
### Step 1: Download Weights
Download the original PyTorch checkpoints from HuggingFace using the `huggingface-cli` tool (install with `pip install huggingface_hub`):
**Wan2.1**
```bash
# Text-to-Video 1.3B (fast, fits in ~4 GB)
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir ./Wan2.1-T2V-1.3B
# Text-to-Video 14B
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
```
**Wan2.2**
```bash
# Text-to-Video 14B
huggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B
# Image-to-Video 14B
huggingface-cli download Wan-AI/Wan2.2-I2V-A14B --local-dir ./Wan2.2-I2V-A14B
# Text+Image-to-Video 5B (uses a different VAE — z_dim=48)
huggingface-cli download Wan-AI/Wan2.2-TI2V-5B --local-dir ./Wan2.2-TI2V-5B
```
Each downloaded directory will have this structure:
```
Wan2.1-T2V-*/
├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder
├── Wan2.1_VAE.pth # 3D VAE
└── diffusion_pytorch_model*.safetensors # transformer (single)
Wan2.2-T2V-A14B/ or Wan2.2-I2V-A14B/
├── models_t5_umt5-xxl-enc-bf16.pth
├── Wan2.1_VAE.pth
├── low_noise_model/ # dual-model low-noise transformer
└── high_noise_model/ # dual-model high-noise transformer
Wan2.2-TI2V-5B/
├── models_t5_umt5-xxl-enc-bf16.pth
├── Wan2.2_VAE.pth # different VAE (z_dim=48)
└── diffusion_pytorch_model*.safetensors # transformer (single)
```
> **Wan2.2 I2V-14B** shares the same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`).
### Step 2: Convert to MLX Format
The conversion script auto-detects the model version from the directory structure (presence of `low_noise_model/` → Wan2.2 dual model) and the model type from `config.json` (I2V vs T2V).
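
The detection rule described above can be sketched roughly as follows. This is an illustration only, not the converter's actual code: the helper name `detect_wan_variant` is hypothetical, and so is the assumption that `config.json` sits next to the transformer weights (inside `low_noise_model/` for dual-model checkpoints, at the root otherwise).

```python
import json
from pathlib import Path


def detect_wan_variant(checkpoint_dir):
    """Guess pipeline layout and task from a checkpoint directory (hypothetical helper)."""
    root = Path(checkpoint_dir)
    # Dual-model checkpoints ship separate low/high-noise transformer folders.
    dual_model = (root / "low_noise_model").is_dir()

    # Assumption: config.json lives next to the transformer weights.
    config_dir = root / "low_noise_model" if dual_model else root
    config_path = config_dir / "config.json"

    model_type = "t2v"  # fall back to text-to-video when no config is found
    if config_path.is_file():
        config = json.loads(config_path.read_text())
        model_type = config.get("model_type", model_type)

    return {"dual_model": dual_model, "model_type": model_type}
```

If the heuristic guesses wrong for a given checkpoint, `--model-version` remains the explicit override.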
#### Wan2.1 T2V 1.3B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-1.3B \
--output-dir ./Wan2.1-T2V-1.3B-MLX
```
#### Wan2.1 T2V 14B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-14B \
--output-dir ./Wan2.1-T2V-14B-MLX
```
#### Wan2.2 T2V 14B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-T2V-A14B \
--output-dir ./Wan2.2-T2V-A14B-MLX
```
#### Wan2.2 I2V 14B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-I2V-A14B \
--output-dir ./Wan2.2-I2V-A14B-MLX
```
The I2V model is auto-detected from `config.json`; the output will include a `vae_encoder.safetensors` used to encode the conditioning image.
#### Wan2.2 TI2V 5B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-TI2V-5B \
--output-dir ./Wan2.2-TI2V-5B-MLX
```
The TI2V model uses a different VAE (`z_dim=48`, `vae_stride=(4,16,16)`) and is auto-detected during conversion.
---
You can also pass `--model-version 2.1` or `--model-version 2.2` to force the version instead of relying on auto-detection.
#### Conversion Options
| Option | Default | Description |
|--------|---------|-------------|
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
| `--quantize` | off | Quantize transformer weights for reduced memory |
| `--bits` | `4` | Quantization bits: `4` or `8` |
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
The converter produces:
```
wan_mlx_model/
├── config.json # Model configuration
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
├── vae.safetensors # 3D VAE decoder
├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only)
├── model.safetensors # (Wan2.1) Single transformer
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
```
### Step 3: Generate Video
#### Wan2.1 T2V 1.3B
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.1-T2V-1.3B-MLX \
--prompt "A cat playing piano in a cozy living room, cinematic lighting" \
--width 832 --height 480 --num-frames 81 \
--steps 50 --guide-scale 5.0 \
--seed 42 \
--output-path wan21_1b.mp4
```
#### Wan2.1 T2V 14B
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.1-T2V-14B-MLX \
--prompt "A woman walks through a misty forest at dawn, slow motion, cinematic" \
--width 1280 --height 704 --num-frames 81 \
--steps 50 --guide-scale 5.0 \
--seed 42 \
--output-path wan21_14b.mp4
```
> **Tip**: If the first few frames look washed out or have color artifacts, add `--trim-first-frames 1` to generate 4 extra frames at the start and discard them. With the `unipc` scheduler (default), **10 steps** often gives satisfying results — useful for quick iteration.
#### Wan2.2 T2V 14B
Wan2.2 uses a dual-model pipeline (separate high-noise and low-noise transformers) and takes guidance as a `high,low` pair:
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-T2V-A14B-MLX \
--prompt "Two astronauts playing chess on the surface of the moon, dramatic lighting, 8K" \
--negative-prompt "low quality, blurry, distorted" \
--width 1280 --height 704 --num-frames 81 \
--steps 40 --guide-scale "3.0,4.0" \
--seed 42 \
--output-path wan22_t2v.mp4
```
> **Tip**: With the `unipc` scheduler (default), **10 steps** often produces satisfying results for 14B models — a significant speed-up with minimal quality loss. Try `--steps 10` for quick iterations.
#### Wan2.2 I2V 14B
Image-to-video: animates a starting image guided by a text prompt. Pass the image with `--image`:
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-I2V-A14B-MLX \
--image ./my_photo.png \
--prompt "The person slowly turns their head and smiles, cinematic, natural lighting" \
--negative-prompt "low quality, blurry, distorted" \
--width 1280 --height 704 --num-frames 81 \
--steps 40 --guide-scale "3.5,3.5" \
--seed 42 \
--output-path wan22_i2v.mp4
```
> **Tip**: As with T2V, `--steps 10` with the `unipc` scheduler is often sufficient for fast prototyping.
#### Wan2.2 TI2V 5B
Text+image-to-video: a single-model variant with a larger VAE (`z_dim=48`). Resolution must be divisible by **32** (not 16 as with other models):
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-TI2V-5B-MLX \
--image ./my_photo.png \
--prompt "The subject waves hello, warm sunlight, film grain" \
--width 1280 --height 704 --num-frames 41 \
--steps 40 --guide-scale 5.0 \
--seed 42 \
--output-path wan22_ti2v.mp4
```
> **Note**: The 5B model is fast — 40 steps run quickly and are recommended for best quality.
> **Frame count**: `--num-frames` must satisfy `4n+1` for all models (e.g. 5, 9, 13, 21, 41, 81, 101 …).
> **Resolution**: Always use the model's native resolution. While generation will succeed at other sizes, mismatched resolutions or aspect ratios are likely to produce visual artifacts. Preferred resolutions are:
> - **480P** — 832×480 (landscape) or 480×832 (portrait) — for Wan2.1 1.3B
> - **720P** — 1280×704 (landscape) or 704×1280 (portrait) — for Wan2.1 14B, Wan2.2 T2V/I2V/TI2V
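
The frame-count and resolution constraints above are easy to check before launching a long run. The sketch below is a hypothetical helper (not part of `mlx_video`); it assumes the only hard constraints are the `4n+1` rule and divisibility by 16 (or 32 for TI2V-5B), as stated in this README.

```python
def check_generation_args(width, height, num_frames, spatial_multiple=16):
    """Return a list of problems with the requested size (empty list means OK)."""
    problems = []
    # Frame count must be 1, 5, 9, ... (one leading frame plus groups of four).
    if num_frames < 1 or (num_frames - 1) % 4 != 0:
        problems.append(f"num_frames={num_frames} is not of the form 4n+1")
    # Use spatial_multiple=32 for TI2V-5B, 16 for the other models.
    for name, value in (("width", width), ("height", height)):
        if value % spatial_multiple != 0:
            problems.append(f"{name}={value} is not divisible by {spatial_multiple}")
    return problems
```

For example, `check_generation_args(1280, 704, 81, spatial_multiple=32)` reports no problems, while a height of 720 would be rejected for TI2V-5B because 720 is not a multiple of 32.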
#### Generation Options
| Option | Default | Description |
|--------|---------|-------------|
| `--model-dir` | (required) | Path to converted MLX model directory |
| `--prompt` | (required) | Text prompt |
| `--image` | — | Input image path (I2V and TI2V modes) |
| `--negative-prompt` | config default | Negative guidance prompt |
| `--width` | `1280` | Output width in pixels |
| `--height` | `704` | Output height in pixels |
| `--num-frames` | `81` | Number of frames (must be `4n+1`) |
| `--steps` | config default | Diffusion steps |
| `--guide-scale` | config default | Guidance scale; use `"high,low"` pair for Wan2.2 dual models |
| `--shift` | config default | Noise schedule shift |
| `--seed` | `-1` (random) | Random seed for reproducibility |
| `--output-path` | `output.mp4` | Output video file path |
| `--scheduler` | `unipc` | Solver: `euler`, `dpm++`, or `unipc` |
| `--trim-first-frames` | `0` | Drop N leading frames (fixes first-frame artifacts on 14B models) |
| `--tiling` | `auto` | VAE tiling: `auto`, `none`, `spatial`, `temporal` |
### Quantization (Reduced Memory)
Quantize the transformer weights to reduce memory usage by ~3.4×. Quantization is supported for all model variants and is especially important for running 14B models on devices with limited unified memory:
```bash
# Convert with 4-bit quantization (works for any variant)
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-1.3B \
--output-dir ./Wan2.1-T2V-1.3B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-14B \
--output-dir ./Wan2.1-T2V-14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-T2V-A14B \
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-I2V-A14B \
--output-dir ./Wan2.2-I2V-A14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-TI2V-5B \
--output-dir ./Wan2.2-TI2V-5B-MLX-Q4 \
--quantize --bits 4 --group-size 64
```
You can also quantize an already-converted MLX model without re-converting from PyTorch:
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-T2V-A14B-MLX \
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--quantize-only --bits 4
```
Quantized models are used exactly the same way — the quantization is auto-detected from `config.json`:
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--prompt "A cat playing piano"
```
**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2) — 10 layers × N blocks = ~95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.
| Model | BF16 Size | 4-bit Size | Notes |
|-------|-----------|------------|-------|
| 1.3B | 2.7 GB | 799 MB | ~3.4× smaller |
| 14B | ~28 GB | ~8 GB | Enables running on 16 GB devices |
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
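
To get a feel for how `--bits` and `--group-size` trade size against precision, the per-weight storage cost can be estimated with a little arithmetic. The sketch below assumes group-wise affine quantization in which every group stores one 16-bit scale and one 16-bit bias alongside the packed weights; that overhead model is an assumption here, and the function names are illustrative.

```python
def effective_bits_per_weight(bits, group_size, scale_bits=16, bias_bits=16):
    """Average storage cost of one quantized weight, including per-group metadata."""
    return bits + (scale_bits + bias_bits) / group_size


def compression_vs_bf16(bits, group_size):
    """Size ratio of a bf16 tensor to its quantized form (quantized layers only)."""
    return 16 / effective_bits_per_weight(bits, group_size)
```

Under these assumptions, 4-bit weights with group size 64 cost 4.5 bits each, so the quantized layers alone shrink by a bit more than 3.5×. The whole-model ratio is lower because embeddings, norms, and the output head stay in bfloat16.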
### Wan Model Specifications
**Transformer (14B)**
- 40 layers, 40 attention heads, dim 5120, head dim 128
- 3-way factorized RoPE (temporal + spatial)
- 14.29B parameters
**Transformer (1.3B, Wan2.1 only)**
- 30 layers, 12 attention heads, dim 1536, head dim 128
- Same architecture, smaller scale
**Text Encoder** — UMT5-XXL (5.68B parameters)
- 24 layers, 64 heads, dim 4096, vocab 256K
**VAE** — 3D causal convolution decoder (72.6M parameters)
- Latent channels: 16
- Compression: 4× temporal, 8× spatial
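
The compression factors translate directly into latent shapes and transformer sequence lengths. The following sketch assumes the first frame maps to its own latent frame and every further group of four frames adds one more (consistent with the `4n+1` frame rule), and uses the `(1, 2, 2)` patch size from `WanModelConfig`. The function names are illustrative.

```python
def latent_shape(num_frames, height, width, vae_stride=(4, 8, 8), z_dim=16):
    """Latent tensor shape [C, T, H, W] for a video of the given pixel size."""
    st, sh, sw = vae_stride
    # One latent frame for the first frame, plus one per group of `st` frames.
    t_lat = (num_frames - 1) // st + 1
    return (z_dim, t_lat, height // sh, width // sw)


def num_tokens(shape, patch_size=(1, 2, 2)):
    """Transformer sequence length after patchifying a latent of the given shape."""
    _, t, h, w = shape
    pt, ph, pw = patch_size
    return (t // pt) * (h // ph) * (w // pw)
```

For an 81-frame 1280×704 clip this gives a latent of shape `(16, 21, 88, 160)` and 73,920 tokens per sample, which is why attention cost dominates at 720P.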
---
## LoRA Support
LoRAs can be used with the `--lora-high` and `--lora-low` command-line switches.
For example, to use the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, run the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
```bash
python -m mlx_video.generate_wan \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \
--height 704 \
--num-frames 41 \
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
--steps 4 \
--guide-scale 1 \
--trim-first-frames 1 \
--seed 2391784614 \
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
```
## Enjoy
![Poodles](../../../examples/poodles-wan.gif)


@@ -0,0 +1,2 @@
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel


@@ -0,0 +1,221 @@
import mlx.core as mx
import mlx.nn as nn

from .rope import rope_apply


def _linear_dtype(layer) -> mx.Dtype:
    """Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
    # Unwrap LoRA wrapper to get the underlying linear layer
    inner = getattr(layer, "linear", layer)
    if isinstance(inner, nn.QuantizedLinear):
        return inner.scales.dtype
    return inner.weight.dtype


class WanRMSNorm(nn.Module):
    """RMS normalization with learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        return mx.fast.rms_norm(x, self.weight, self.eps)


class WanLayerNorm(nn.Module):
    """LayerNorm computed in float32, with optional affine."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = mx.ones((dim,))
            self.bias = mx.zeros((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        if self.elementwise_affine:
            return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
        else:
            return mx.fast.layer_norm(x, None, None, self.eps)


class WanSelfAttention(nn.Module):
    """Self-attention with QK normalization and 3-way factorized RoPE."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def __call__(
        self,
        x: mx.array,
        seq_lens: list,
        grid_sizes: list,
        freqs: mx.array,
        rope_cos_sin: tuple | None = None,
        attn_mask: mx.array | None = None,
    ) -> mx.array:
        b, s, _ = x.shape
        n, d = self.num_heads, self.head_dim

        # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
        w_dtype = _linear_dtype(self.q)
        x_w = x.astype(w_dtype)
        q = self.q(x_w)
        k = self.k(x_w)
        if self.norm_q is not None:
            q = self.norm_q(q)
        if self.norm_k is not None:
            k = self.norm_k(k)
        q = q.reshape(b, s, n, d)
        k = k.reshape(b, s, n, d)
        v = self.v(x_w).reshape(b, s, n, d)

        # RoPE in float32 for precision (official uses float64)
        q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
        k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)

        # Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
        q = q.astype(w_dtype).transpose(0, 2, 1, 3)
        k = k.astype(w_dtype).transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # Use precomputed mask or build from seq_lens
        mask = attn_mask
        if mask is None and any(sl < s for sl in seq_lens):
            mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
            for i, sl in enumerate(seq_lens):
                mask[i, :, :, sl:] = -1e9

        # Use memory-efficient scaled dot-product attention
        # mx.fast.scaled_dot_product_attention expects [B, N, L, D]
        if mask is not None:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale, mask=mask
            )
        else:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale
            )

        out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
        return self.o(out)


class WanCrossAttention(nn.Module):
    """Cross-attention: Q from hidden states, K/V from text context."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def prepare_kv(self, context: mx.array) -> tuple:
        """Pre-compute K and V projections for caching.

        Args:
            context: [B, L_ctx, dim]

        Returns:
            (k, v) each [B, N, L_ctx, D] ready for attention
        """
        b = context.shape[0]
        n, d = self.num_heads, self.head_dim
        # Cast to compute dtype for efficient matmul
        w_dtype = _linear_dtype(self.k)
        ctx = context.astype(w_dtype)
        k = self.k(ctx)
        if self.norm_k is not None:
            k = self.norm_k(k)
        k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        return k, v

    def __call__(
        self,
        x: mx.array,
        context: mx.array,
        context_lens: list | None = None,
        kv_cache: tuple | None = None,
    ) -> mx.array:
        b = x.shape[0]
        n, d = self.num_heads, self.head_dim

        # Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
        w_dtype = _linear_dtype(self.q)
        q = self.q(x.astype(w_dtype))
        if self.norm_q is not None:
            q = self.norm_q(q)
        q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)

        if kv_cache is not None:
            k, v = kv_cache
        else:
            ctx = context.astype(w_dtype)
            k = self.k(ctx)
            if self.norm_k is not None:
                k = self.norm_k(k)
            k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
            v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)

        # Optional context masking
        mask = None
        if context_lens is not None:
            ctx_len = k.shape[2]
            mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
            for i, cl in enumerate(context_lens):
                mask[i, :, :, cl:] = -1e9

        if mask is not None:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale, mask=mask
            )
        else:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale
            )

        out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
        return self.o(out)


@@ -0,0 +1,129 @@
from dataclasses import dataclass
from typing import Tuple, Union

from mlx_video.models.ltx.config import BaseModelConfig


@dataclass
class WanModelConfig(BaseModelConfig):
    """Configuration for Wan T2V models (supports both 2.1 and 2.2)."""

    model_type: str = "t2v"
    model_version: str = "2.2"
    patch_size: Tuple[int, int, int] = (1, 2, 2)
    text_len: int = 512
    in_dim: int = 16
    dim: int = 5120
    ffn_dim: int = 13824
    freq_dim: int = 256
    text_dim: int = 4096
    out_dim: int = 16
    num_heads: int = 40
    num_layers: int = 40
    window_size: Tuple[int, int] = (-1, -1)
    qk_norm: bool = True
    cross_attn_norm: bool = True
    eps: float = 1e-6

    # VAE
    vae_stride: Tuple[int, int, int] = (4, 8, 8)
    vae_z_dim: int = 16

    # Inference
    dual_model: bool = True
    boundary: float = 0.875
    sample_shift: float = 12.0
    sample_steps: int = 40
    sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
    num_train_timesteps: int = 1000
    sample_fps: int = 16
    frame_num: int = 81
    sample_neg_prompt: str = (
        "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
        "最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部"
        "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
        "杂乱的背景,三条腿,背景人很多,倒着走"
    )

    # Resolution constraints
    max_area: int = 0  # 0 = no limit; e.g. 704*1280 for TI2V-5B

    t5_vocab_size: int = 256384
    t5_dim: int = 4096
    t5_dim_attn: int = 4096
    t5_dim_ffn: int = 10240
    t5_num_heads: int = 64
    t5_num_layers: int = 24
    t5_num_buckets: int = 32

    @property
    def head_dim(self) -> int:
        return self.dim // self.num_heads

    @classmethod
    def wan21_t2v_14b(cls) -> "WanModelConfig":
        """Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
        return cls(
            model_version="2.1",
            dual_model=False,
            boundary=0.0,
            sample_shift=5.0,
            sample_steps=50,
            sample_guide_scale=5.0,
        )

    @classmethod
    def wan21_t2v_1_3b(cls) -> "WanModelConfig":
        """Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
        return cls(
            model_version="2.1",
            dim=1536,
            ffn_dim=8960,
            num_heads=12,
            num_layers=30,
            dual_model=False,
            boundary=0.0,
            sample_shift=5.0,
            sample_steps=50,
            sample_guide_scale=5.0,
        )

    @classmethod
    def wan22_t2v_14b(cls) -> "WanModelConfig":
        """Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
        return cls()

    @classmethod
    def wan22_i2v_14b(cls) -> "WanModelConfig":
        """Wan2.2 I2V 14B: dual model, image-to-video, 40 layers, dim=5120."""
        return cls(
            model_type="i2v",
            in_dim=36,
            out_dim=16,
            dual_model=True,
            boundary=0.900,
            sample_shift=5.0,
            sample_guide_scale=(3.5, 3.5),
            max_area=704 * 1280,
        )

    @classmethod
    def wan22_ti2v_5b(cls) -> "WanModelConfig":
        """Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
        return cls(
            model_type="ti2v",
            dim=3072,
            ffn_dim=14336,
            in_dim=48,
            out_dim=48,
            num_heads=24,
            num_layers=30,
            vae_z_dim=48,
            vae_stride=(4, 16, 16),
            dual_model=False,
            boundary=0.0,
            sample_shift=5.0,
            sample_steps=40,
            sample_guide_scale=5.0,
            sample_fps=24,
            max_area=704 * 1280,
        )


@@ -0,0 +1,394 @@
# Wan2.2 I2V-14B Diagnostic Report
This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix.
## Table of Contents
- [Overview](#overview)
- [Architecture Summary](#architecture-summary)
- [Diagnostic Methodology](#diagnostic-methodology)
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
- [Bug 3: RoPE Frequency Computation (original)](#bug-3-rope-frequency-computation-original)
- [Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)](#bug-6-rope-frequency-distribution-bug-3-fix-was-wrong)
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
- [Verified Correct Components](#verified-correct-components)
- [Performance Optimizations](#performance-optimizations)
- [Resolved: CFG Effectiveness](#resolved-cfg-effectiveness-was-open-investigation)
- [Reference Implementation](#reference-implementation)
- [Useful Diagnostic Commands](#useful-diagnostic-commands)
---
## Overview
The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output: the first frame showed the image, but subsequent frames degraded to noise, checkerboard artifacts, or flat grey.
Through a systematic component-by-component comparison against the reference PyTorch implementation, **six bugs** were found and fixed, one of which (Bug 6) was introduced by the initial fix for Bug 3. The approach was to verify each component numerically in isolation, then narrow failures down to the subsystem level.
### Timeline of Symptoms
| Stage | Symptom | Root Cause |
|-------|---------|------------|
| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) |
| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) |
| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) |
| After VAE fix | Image in frames 0-3, grey frames 4+ | RoPE frequency distribution (Bug 6); initially tracked as an open CFG effectiveness investigation |
---
## Architecture Summary
```
I2V-14B Pipeline:
Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat]
Mask Construction → [4, T_lat, H_lat, W_lat]
y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat]
Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat]
Dual DiT (40 layers, 5120 dim) × 40 denoising steps
Denoised Latent [16, T_lat, H_lat, W_lat]
VAE Decoder → Video [3, F, H, W]
```
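
The channel bookkeeping in the diagram can be written out as a small shape calculation. This is only a sketch of the arithmetic (16 noise + 4 mask + 16 image-latent channels); the function name and the returned dictionary layout are illustrative, not taken from the pipeline code.

```python
def i2v_input_shapes(t_lat, h_lat, w_lat, z_dim=16, mask_channels=4):
    """Shapes of the tensors that are concatenated to form the I2V model input."""
    grid = (t_lat, h_lat, w_lat)
    image_latent = (z_dim, *grid)          # VAE-encoded conditioning video
    mask = (mask_channels, *grid)          # marks which frames are conditioned
    y = (mask[0] + image_latent[0], *grid)  # concat(mask, encoded_video)
    noise = (z_dim, *grid)
    model_input = (noise[0] + y[0], *grid)  # channel-wise concat of noise and y
    return {"mask": mask, "y": y, "noise": noise, "model_input": model_input}
```

With the default arguments the model input always has 36 channels, matching `in_dim=36`, while the output keeps the 16 latent channels.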
**Key parameters:**
- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16`
- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900)
- 40 steps, shift=5.0, guide_scale=(3.5, 3.5)
- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8)
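
The high/low split can be expressed as a one-line rule on the timestep. The sketch below assumes the threshold is `boundary * num_train_timesteps` (0.900 × 1000 = 900 for I2V-14B, 0.875 × 1000 = 875 for T2V-14B, per `WanModelConfig`); the function name is hypothetical.

```python
def select_expert(timestep, boundary=0.900, num_train_timesteps=1000):
    """Return which transformer handles this denoising step."""
    threshold = boundary * num_train_timesteps
    # Early, noisy steps go to the high-noise expert; later steps to the low-noise one.
    return "high_noise" if timestep >= threshold else "low_noise"
```

Because the two experts have separate weights, anything cached per model (such as text embeddings or cross-attention K/V) has to be selected with the same rule.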
---
## Diagnostic Methodology
### 1. Component-Level Numerical Verification
Each component was tested in isolation against the reference PyTorch implementation:
1. **Load identical inputs** (same random seed, same image, same prompt)
2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy`
3. **Run through MLX** with the same inputs
4. **Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics
Components tested this way:
- RoPE frequency parameters and rotation output
- Time embedding (sinusoidal → MLP → projection)
- Patchify (reshape+Linear vs Conv3d)
- Unpatchify (transpose-based vs einsum)
- Scheduler (UniPC) timesteps and step formulas
- VAE encoder output (frame-by-frame comparison)
- Text embeddings (per-model MLP output)
- Cross-attention K/V cache shapes
- Mask construction values
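
The comparison in step 4 boils down to two numbers per tensor. Below is a dependency-free sketch over flattened values; with NumPy the same metrics are one-liners. The exact definition of "relative difference" used during the investigation is not recorded here, so the mean-absolute-error ratio below is an assumption.

```python
def compare_tensors(ours, ref, eps=1e-12):
    """Return (max absolute difference, relative mean difference) for two flat sequences."""
    if len(ours) != len(ref):
        raise ValueError("length mismatch")
    abs_diffs = [abs(a - b) for a, b in zip(ours, ref)]
    max_abs = max(abs_diffs)
    # Normalize by the reference magnitude so the metric is scale-independent.
    mean_abs_ref = sum(abs(b) for b in ref) / len(ref)
    rel_mean = (sum(abs_diffs) / len(abs_diffs)) / (mean_abs_ref + eps)
    return max_abs, rel_mean
```

A max absolute difference near bfloat16 rounding error together with a small relative difference suggests the component matches; a large relative difference with a small absolute one usually points at a scaling or dtype issue instead.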
### 2. Artifact Analysis
When visual artifacts appeared, quantitative metrics were used to characterize them:
- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard.
- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts.
- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation.
- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content.
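
The per-frame statistics and frame-difference metrics are simple to reproduce. The sketch below treats each frame as a flat list of pixel values and uses only the standard library; the helper names are illustrative and not taken from the diagnostic scripts.

```python
from statistics import fmean, pstdev


def frame_stats(frame):
    """Mean, std, min and max of one frame given as a flat list of pixel values."""
    return {"mean": fmean(frame), "std": pstdev(frame), "min": min(frame), "max": max(frame)}


def frame_differences(frames):
    """mean(|frame[i] - frame[i-1]|) for every consecutive pair of frames."""
    return [
        fmean(abs(a - b) for a, b in zip(cur, prev))
        for prev, cur in zip(frames, frames[1:])
    ]
```

A run of near-zero frame differences combined with a collapsing per-frame std is the signature of the "flat grey" failure described in this report.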
### 3. Isolation Testing
- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source.
- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness.
- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering.
- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence.
- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs.
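
The single-step CFG check compares the conditional and unconditional predictions before they are blended. The sketch below uses the standard classifier-free guidance formula on flat lists of values; the helper names are illustrative.

```python
def cfg_combine(cond, uncond, guide_scale):
    """Classifier-free guidance: push the unconditional prediction toward the conditional one."""
    return [u + guide_scale * (c - u) for c, u in zip(cond, uncond)]


def cond_uncond_gap(cond, uncond):
    """Mean absolute difference between the two predictions."""
    return sum(abs(c - u) for c, u in zip(cond, uncond)) / len(cond)
```

If `cond_uncond_gap` is close to zero, the guidance term vanishes whatever the scale, so the prompt has no effect on the output. At a guide scale of 1 the blend reduces to the conditional prediction alone.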
### 4. Weight Comparison
Direct comparison of converted MLX weights against original PyTorch weights:
```python
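import mlx.core as mx
import numpy as np
from safetensors.torch import load_file

# Load both weight sets (.safetensors files need load_file, not torch.load)
pt_weights = load_file("model.safetensors")  # original PyTorch checkpoint
mlx_weights = mx.load("model.safetensors")  # converted MLX checkpoint
# Compare each key; upcast to float32 first because NumPy has no bfloat16
for key in pt_weights:
    ref = pt_weights[key].float().numpy()
    ours = np.array(mlx_weights[key].astype(mx.float32))
    diff = np.abs(ref - ours).max()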
```
Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys.
---
## Bug 1: Text Embedding Cross-Contamination
**Symptom:** Model ignores text prompt, generated frames lack semantic content.
**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights — 42% relative mean difference in output.
**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices.
**Fix:** Compute separate text embeddings per model:
```python
# Before (broken):
context_emb = low_noise_model.embed_text([context, context_null])
cross_kv = low_noise_model.prepare_cross_kv(context_emb) # used for BOTH models
# After (correct):
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high)
```
**File:** `mlx_video/generate_wan.py` (lines 333-349)
**Commit:** `a85b1c21`
---
## Bug 2: VAE Encoder Weights Excluded from Conversion
**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion).
**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, missing encoder weights were silently ignored, resulting in random initialization.
**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning).
**Fix:**
```python
# Before:
include_encoder = config.model_type == "ti2v"
# After:
include_encoder = config.model_type in ("ti2v", "i2v")
```
**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions.
**File:** `mlx_video/convert_wan.py` (line 424)
---
## Bug 3: RoPE Frequency Computation (original)
**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.
**Root Cause (original):** The original code called `rope_params` three times but applied the results incorrectly: the frequencies were built per axis in the model init, but `rope_apply` did not split them per axis. This was initially "fixed" by switching to a single `rope_params(1024, head_dim=128)` call, which reduced the checkerboard but introduced Bug 6 (see below).
**File:** `mlx_video/models/wan/model.py`
**Commit:** `3da4a637`
---
## Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)
**Symptom:** I2V generates the input image in frames 0-3, a colorful checkerboard on frame 4, then grey frames. CFG cond/uncond predictions are nearly identical. The model cannot produce coherent motion.
**Root Cause:** The Bug 3 "fix" replaced three separate `rope_params` calls with a single `rope_params(1024, 128)`. But the reference (`wan/modules/model.py` lines 400-405) actually uses **three separate calls with different dimension normalizations**, concatenated:
```python
# Reference (CORRECT):
d = dim // num_heads  # 128
self.freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),  # rope_params(1024, 44)
    rope_params(1024, 2 * (d // 6)),      # rope_params(1024, 42)
    rope_params(1024, 2 * (d // 6))       # rope_params(1024, 42)
], dim=1)
```
Each axis gets its own full frequency range [θ^0, θ^(-~0.95)]. The single-call approach gave:
- Temporal: low frequencies only [1.0 → 0.049]
- Height: medium frequencies only [0.042 → 0.002] (should start at 1.0!)
- Width: high frequencies only [0.002 → 0.0001] (should start at 1.0!)
The height/width position encoding was essentially destroyed — nearby spatial positions were indistinguishable (max diff 0.958 for height, 0.998 for width vs reference).
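
The frequency ranges quoted above can be reproduced with a few lines of plain Python. This sketch assumes `rope_params` uses the standard RoPE inverse-frequency formula `theta^(-2i/dim)` with `theta = 10000`; `rope_freqs` is a stand-in written for this illustration, not the project's function.

```python
def rope_freqs(dim, theta=10000.0):
    """Inverse frequencies theta^(-2i/dim) for i in [0, dim/2), as in standard RoPE."""
    return [theta ** (-(2 * i) / dim) for i in range(dim // 2)]


d = 128
# Number of frequency pairs per axis: temporal, height, width (22, 21, 21).
sizes = [(d - 4 * (d // 6)) // 2, d // 6, d // 6]

# Single call: one 64-entry ramp sliced into temporal / height / width.
single = rope_freqs(d)
t_end, h_end = sizes[0], sizes[0] + sizes[1]
single_split = [single[:t_end], single[t_end:h_end], single[h_end:]]

# Three calls: every axis gets its own ramp starting at 1.0.
three = [rope_freqs(2 * n) for n in sizes]

for name, s, r in zip(("temporal", "height", "width"), single_split, three):
    print(f"{name:8s} single: {s[0]:.4f} -> {s[-1]:.4f}   three-call: {r[0]:.4f} -> {r[-1]:.4f}")
```

Only in the three-call variant does every axis start at frequency 1.0, which is what lets neighbouring rows and columns receive clearly different rotations.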
**How Found:** Direct line-by-line comparison of `WanModel.__init__` freq construction between reference `wan/modules/model.py` and our `models/wan/model.py`. Numerical verification confirmed the three-call approach gives each axis a full [0, ~1) exponent range, while the single-call monotonically assigns low→high across axes.
**Fix:**
```python
d = dim // config.num_heads
self.freqs = mx.concatenate([
    rope_params(1024, d - 4 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
], axis=1)
```
**Verification:** Max diff vs reference cos/sin: 0.00000000 (exact float32 match).
**Impact:** Affects ALL Wan models (T2V, I2V, TI2V). Resolves the "Open Investigation: CFG Effectiveness" issue — the model could not produce meaningful cond/uncond differences because it couldn't encode spatial positions.
**File:** `mlx_video/models/wan/model.py` (line 155)
---
## Bug 4: VAE Encoder Temporal Downsample Order
**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 1-4 showed decreasing std (0.37→1.19) while reference showed stable std (0.95→1.34).
**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters:
```
Reference: [False, True, True] → stage 0: 2d, stage 1: 3d, stage 2: 3d
Ours: [True, True, False] → stage 0: 3d, stage 1: 3d, stage 2: 2d ← WRONG
```
This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation.
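A small trace shows why this kind of ordering bug can slip past shape checks. The sketch below is illustrative only: `trace_encoder_shapes` is a hypothetical helper, and it assumes every stage halves H and W while a temporal stage maps T to `(T + 1) // 2` (a causal stride-2 downsample that keeps the first frame).

```python
def trace_encoder_shapes(t, h, w, temporal_downsample):
    """Track (T, H, W) after each downsampling stage.

    Assumption: all stages halve H and W; stages flagged True also
    reduce T to (T + 1) // 2.
    """
    shapes = [(t, h, w)]
    for use_3d in temporal_downsample:
        h, w = h // 2, w // 2
        if use_3d:
            t = (t + 1) // 2
        shapes.append((t, h, w))
    return shapes


reference = trace_encoder_shapes(17, 384, 640, [False, True, True])
buggy = trace_encoder_shapes(17, 384, 640, [True, True, False])

for ref_shape, bug_shape in zip(reference, buggy):
    print(f"reference {ref_shape}   buggy {bug_shape}")
```

Under these assumptions both orderings end at the same latent shape, so only the intermediate stages differ. That would be consistent with the bug surfacing through numerical comparison rather than through a shape mismatch.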
**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output:
- Frame 0 matched exactly (diff=0.0000) — spatial-only processing was correct
- Frames 1-4 had massive differences, which proved temporal processing was broken
Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]`, while our `Encoder3d` class used the wrong default `[True, True, False]`.
**Fix:**
```python
# In Encoder3d.__init__, change default:
temporal_downsample = [False, True, True] # was [True, True, False]
```
**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 60-80 to 0.1-7.7.
**File:** `mlx_video/models/wan/vae.py` (line 370)
**Commit:** `3da4a637`
---
## Bug 5: Non-Chunked VAE Encoding
**Symptom:** First 4-5 frames grey, then a blurred version of the image appears.
**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`):
1. Encode first frame alone (1 frame)
2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks
3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input
Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because:
- Chunked: real features from previous chunks serve as causal context
- Non-chunked: zeros serve as causal context for the start
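The chunk boundaries described above can be sketched as follows. `chunk_ranges` is a hypothetical helper for illustration; it only computes which frames go into each encoder call and does not model the `feat_cache` contents.

```python
def chunk_ranges(num_frames: int, chunk: int = 4) -> list[tuple[int, int]]:
    """Frame ranges [start, end) for causal chunked encoding.

    The first frame is encoded alone; the remaining frames follow in
    groups of `chunk`. Expects num_frames = chunk * n + 1.
    """
    if num_frames < 1 or (num_frames - 1) % chunk != 0:
        raise ValueError(f"num_frames must be {chunk}n + 1, got {num_frames}")
    ranges = [(0, 1)]
    for start in range(1, num_frames, chunk):
        ranges.append((start, start + chunk))
    return ranges


for start, end in chunk_ranges(17):
    print(f"encode frames [{start}, {end})")
```

Each call after the first would receive the cached trailing frames of the previous call as its causal context, instead of the zeros that a single all-frames pass uses at the start.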
**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values.
**Fix:** Implemented full chunked encoding with temporal caching:
- Added `cache_x` parameter to `CausalConv3d.__call__`
- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d`
- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks)
- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head)
**File:** `mlx_video/models/wan/vae.py` (multiple methods)
**Commit:** `b6a94c4c`
---
## Verified Correct Components
These components were numerically verified against the reference and are **not** sources of bugs:
| Component | Method | Max Diff | Notes |
|-----------|--------|----------|-------|
| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only |
| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent |
| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative |
| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative |
| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation |
| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match |
| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 |
| CFG formula | Code audit | — | `uncond + gs * (cond - uncond)` correct order |
| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output |
| Cross-attention K/V | Shape and value audit | — | Batch dimension preserved correctly |
---
## Performance Optimizations
Applied alongside bug fixes to improve inference speed:
### Pre-Computation (Before Diffusion Loop)
- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once
- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat
- **Attention mask precomputation**: Build padding mask once, pass via kwargs
- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing
- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync
### Per-Step Optimizations
- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection
- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
---
## Resolved: CFG Effectiveness (was Open Investigation)
**Symptom:** Generated video shows the input image in frames 0-3 (latent frame 0), then grey/flat frames for the rest. Cond and uncond predictions were nearly identical.
**Resolution:** This was caused by Bug 6 (incorrect RoPE frequency distribution). The single `rope_params(1024, 128)` call gave height frequencies starting at 0.042 and width at 0.002 (instead of 1.0 for both), making the model unable to encode spatial positions. This caused the transformer to produce nearly identical outputs regardless of text conditioning, explaining the tiny cond/uncond differences.
---
## Reference Implementation
The reference PyTorch implementation is at `/Users/daniel/Projects/Wan2.2/`:
| File | Contents |
|------|----------|
| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) |
| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) |
| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding |
| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler |
| `wan/configs/wan_i2v_A14B.py` | Model configuration |
Key structural differences between reference and our implementation:
- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2
- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype
- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear`
- Reference casts timesteps to `int64`; we keep as float (diff < 1.0)
---
## Useful Diagnostic Commands
### Run I2V-14B generation
```bash
python -m mlx_video.generate_wan \
    --prompt "A woman smiles at camera" \
    --image start.png \
    --model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \
    --num-frames 17 --steps 40 \
    --height 384 --width 640 \
    --output output_i2v.mp4
```
### Check VAE encoder output
```python
import mlx.core as mx
import numpy as np
from mlx_video.models.wan.vae import WanVAE

# Assumes `vae` is an already-loaded WanVAE with its encoder enabled and
# `video_tensor` is the preprocessed input video; neither is created here.
latents = vae.encode(video_tensor)  # [1, 16, T_lat, H_lat, W_lat]
for t in range(latents.shape[2]):
    frame = np.array(latents[0, :, t])
    print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}")
```
### Analyze video frame quality
```python
import cv2
import numpy as np

cap = cv2.VideoCapture("output.mp4")
while True:
    ret, frame = cap.read()
    if not ret:
        break
    # Checkerboard metric: high values indicate patch-boundary artifacts
    checker = np.abs(frame[::2, ::2].astype(float) - frame[1::2, 1::2].astype(float)).mean()
    print(f"std={frame.std():.1f} checker={checker:.1f}")
cap.release()
```
### Compare weights between PyTorch and MLX
```python
import torch
import mlx.core as mx
import numpy as np

pt = torch.load("model.pt", map_location="cpu")
mlx_w = mx.load("model.safetensors")
for key in sorted(pt.keys()):
    if key in mlx_w:
        # Upcast first: NumPy has no bfloat16, so bf16 MLX arrays need float32
        mlx_arr = np.array(mlx_w[key].astype(mx.float32))
        diff = np.abs(pt[key].float().numpy() - mlx_arr).max()
        if diff > 0.01:
            print(f"LARGE DIFF {key}: {diff:.6f}")
```

View File

@@ -0,0 +1,285 @@
# Wan2.2 MLX Implementation Notes
> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.
## Architecture Overview
Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early reports, the T2V/TI2V models do **not** use Mixture-of-Experts — they are dense DiT models with a dual-model architecture for the 14B variant (separate high-noise and low-noise denoisers with a boundary timestep).
### Key Parameters
| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
|-------|-----|-------|--------|----------|-----------|------------|--------|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |
### Codebase Structure (~3900 lines of Wan2.2 code)
```
mlx_video/
├── generate_wan.py          # 483L - Generation pipeline (T2V + I2V)
├── convert_wan.py           # 564L - Weight conversion from HuggingFace
└── models/wan/
    ├── config.py            # 113L - Model configs (dataclass presets)
    ├── model.py             # 320L - DiT model (time embed, patchify, unpatchify)
    ├── transformer.py       #  91L - Attention block + FFN
    ├── attention.py         # 211L - Self-attention + cross-attention
    ├── rope.py              # 100L - 3D Rotary Position Embeddings
    ├── text_encoder.py      # 240L - T5 encoder (UMT5-XXL)
    ├── scheduler.py         # 428L - Euler, DPM++ 2M, UniPC schedulers
    ├── vae.py               # 315L - Wan2.1 VAE decoder (4×8×8)
    ├── vae22.py             # 836L - Wan2.2 VAE encoder + decoder (4×16×16)
    ├── loading.py           # 154L - Model loading utilities
    └── i2v_utils.py         #  58L - I2V mask/preprocessing
```
---
## Critical Bugs & Fixes
### 1. MLX Underscore Attribute Gotcha
**Problem**: MLX's `nn.Module` silently ignores underscore-prefixed attributes (`_layer_0`, `_layer_1`, etc.) in `parameters()` and `load_weights()`. The Wan2.2 VAE had layers named `_layer_N`, causing **87 out of 110 weights to be silently dropped** during loading.
**Fix**: Rename all `_layer_N` attributes to `layer_N`. MLX treats underscore-prefixed attributes as "private" and excludes them from the parameter tree.
**Lesson**: Never use underscore-prefixed names for `nn.Module` sub-modules in MLX.
### 2. Patchify Channel Ordering
**Problem**: The patchify/unpatchify operations transposed channels incorrectly — producing `[C fastest]` layout instead of `[C slowest]`, causing completely garbled video output.
**Fix**: Changed the reshape to produce the correct `[B, T', H', W', C*pt*ph*pw]` ordering (C slowest), matching PyTorch's contiguous memory layout for the Conv3d weight.
**Lesson**: When porting PyTorch reshape/view operations to MLX, pay close attention to memory layout — PyTorch is row-major by default, and reshape semantics differ when dimensions are reordered.
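The effect of the flatten order can be checked in isolation with NumPy. This is a sketch on toy shapes, not the repository's code: the stride-equals-kernel convolution is written as an einsum, and the two candidate flatten orders are compared against it. All array names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
C, F, H, W = 3, 2, 4, 6          # toy latent shape (channels-first)
pt, ph, pw = 1, 2, 2             # patch size
out_dim = 5

x = rng.standard_normal((C, F, H, W))
conv_w = rng.standard_normal((out_dim, C, pt, ph, pw))  # PyTorch Conv3d layout

fo, ho, wo = F // pt, H // ph, W // pw
blocks = x.reshape(C, fo, pt, ho, ph, wo, pw)

# Reference result: a convolution with stride == kernel, written as an einsum.
conv_out = np.einsum("ocijk,cfihjwk->fhwo", conv_w, blocks).reshape(-1, out_dim)

# C slowest: each patch flattened as [C, pt, ph, pw], matching conv_w.reshape(out, -1).
c_slowest = blocks.transpose(1, 3, 5, 0, 2, 4, 6).reshape(fo * ho * wo, -1)
# C fastest: each patch flattened as [pt, ph, pw, C]; columns no longer match the weights.
c_fastest = blocks.transpose(1, 3, 5, 2, 4, 6, 0).reshape(fo * ho * wo, -1)

linear_w = conv_w.reshape(out_dim, -1)
matches_slowest = np.allclose(c_slowest @ linear_w.T, conv_out)
matches_fastest = np.allclose(c_fastest @ linear_w.T, conv_out)
print("C slowest matches conv:", matches_slowest)
print("C fastest matches conv:", matches_fastest)
```

Both flatten orders produce tensors of the same shape, so only a value comparison against the convolution reveals which one is correct.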
### 3. VAE AttentionBlock Reshape
**Problem**: Attention block merged batch (B) with channels (C) instead of batch with temporal (T), producing a green checker pattern in output.
**Fix**: Correct reshape from `[B*C, T, H, W]` to `[B*T, C, H, W]` for spatial attention.
### 4. RMS Norm vs L2 Norm
**Problem**: The Wan2.2 VAE uses a class named `RMS_norm` in PyTorch, but it actually computes **L2 normalization** (divide by L2 norm), not RMS normalization (divide by RMS). Using actual RMS norm caused exponential value explosion.
**Fix**: Implement as `x / ||x||₂` instead of `x / sqrt(mean(x²))`.
**Lesson**: Don't trust class names in reference code — read the actual computation.
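The two formulas in the fix differ only by a constant factor of `sqrt(n)`, where `n` is the number of normalized elements. The sketch below demonstrates that relation on a plain Python list; `l2_normalize` and `rms_normalize` are illustrative helpers, and any learned scale or gain applied elsewhere in the layer is deliberately left out.

```python
import math


def l2_normalize(x, eps=1e-12):
    """x / ||x||_2"""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


def rms_normalize(x, eps=1e-12):
    """x / sqrt(mean(x^2))"""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / max(rms, eps) for v in x]


x = [0.5, -1.0, 2.0, 0.25]
l2 = l2_normalize(x)
rms = rms_normalize(x)

# Element-wise ratio between the two results: sqrt(len(x)) up to rounding.
ratio = [r / l for r, l in zip(rms, l2)]
print(ratio, math.sqrt(len(x)))
```

With hundreds of channels that factor is large, so using the wrong formula (or applying the `sqrt(n)` factor twice) rescales activations at every normalization layer.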
### 5. Video Codec Green Output
**Problem**: OpenCV's `mp4v` codec on macOS produces green-tinted video.
**Fix**: Switch to `imageio` with `libx264` codec. Fallback chain: imageio → cv2 (avc1) → PNG frames.
---
## Precision & Dtype Flow
### The bfloat16 Autocast Pattern
The official PyTorch implementation uses `torch.autocast("cuda", dtype=torch.bfloat16)` which automatically casts matmul inputs. In MLX, we replicate this manually:
| Operation | Official (PyTorch) | MLX Implementation |
|---|---|---|
| Modulation/gates | float32 (explicit `autocast(enabled=False)`) | `x.astype(mx.float32)` before modulation |
| QKV projections | bfloat16 (outer autocast) | Cast input to `self.q.weight.dtype` |
| RoPE computation | float64 → float32 | float32 (MLX lacks float64 on GPU) |
| Q/K after RoPE | bfloat16 (`q.to(v.dtype)`) | Cast back to weight dtype after RoPE |
| FFN matmuls | bfloat16 (outer autocast) | Cast input to `self.fc1.weight.dtype` |
| Residual stream | float32 | float32 (no cast) |
**Result**: ~16% speedup (47s vs 56s for 20 steps at 480p) with no quality regression.
**Key insight**: Modulation parameters (scale, shift, gate) must stay in float32. They are small values (~0.01-0.1) that lose significant precision in bfloat16. The official code explicitly disables autocast for these computations.
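To get a feel for the size of the rounding involved, bfloat16 can be emulated with the standard library by keeping only the top 16 bits of a float32. This is a sketch for illustration (`to_bfloat16` is a hypothetical helper using round-to-nearest-even), not how MLX performs the cast internally.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a value to the nearest bfloat16 and return it as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the 16 bits that bfloat16 drops.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


for value in (0.01, 0.05, 0.1):
    rounded = to_bfloat16(value)
    rel_err = abs(rounded - value) / value
    print(f"{value:.2f} -> {rounded:.10f}  relative error {rel_err:.2e}")
```

bfloat16 keeps 8 significant bits, so a single rounding can be off by up to 2^-8 (about 0.4%) in relative terms. The modulation terms are applied in every block, which is one plausible reason such errors matter more there than in a single matmul input.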
### T5 Encoder Precision
The T5 text encoder must run in float32. Bfloat16 weights cause the attention softmax to produce degenerate distributions, which corrupts text conditioning and manifests as blurry patches in generated video. Since T5 only runs once per generation, the performance cost is negligible.
### VAE Decoder Precision
VAE weights must be float32. Bfloat16 VAE decode introduces visible quality loss in the decoded video frames.
---
## Scheduler Implementation Details
### Three Schedulers: Euler, DPM++ 2M, UniPC
All operate in the flow-matching formulation where `sigma` represents the noise level (1.0 = pure noise, 0.0 = clean).
**Euler**: Simple first-order ODE solver. Most stable, recommended for debugging.
**DPM++ 2M**: Second-order multistep solver. Uses previous step's model output for higher-order correction. Requires special handling at boundaries (return `±inf` from `_lambda()` when sigma is 0 or 1).
**UniPC** (default, matches official): Second-order predictor-corrector. The "C" (corrector) part is critical — it refines each step using the already-computed model output at **zero additional model evaluation cost**.
### UniPC Corrector: Must Be Enabled
**Discovery**: Our implementation had `use_corrector=False` by default, but the official Wan2.2 code **always** enables it (there's no flag — the corrector runs whenever `step_index > 0`).
**Impact**: Without the corrector, UniPC degrades to a simple predictor, losing its second-order accuracy advantage.
### UniPC Corrector Coefficients
The corrector coefficients (`rhos_c`) must be computed by solving a linear system, not hardcoded. For order ≥ 2, hardcoding `rhos_c[-1] = 0.5` introduces ~6-13% error in the correction term across 47+ steps. The fix uses `np.linalg.solve()` to compute exact coefficients.
### Sigma Schedule
```python
# Flow-matching sigma schedule with shift
sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
```
Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
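As a worked example of the schedule formula above, the following pure-Python sketch reproduces the same two steps (linear sigmas, then the shift warp) and prints a few values. `shifted_sigmas` is a hypothetical name, and the step count and shift values are chosen only for illustration.

```python
def shifted_sigmas(num_steps: int, shift: float) -> list[float]:
    """Linear sigmas from 1.0 down to 1/num_steps, followed by the shift warp."""
    low = 1.0 / num_steps
    step = (1.0 - low) / (num_steps - 1) if num_steps > 1 else 0.0
    base = [1.0 - i * step for i in range(num_steps)]
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in base]


for shift in (1.0, 3.0, 5.0):
    sigmas = shifted_sigmas(4, shift)
    print(f"shift={shift}: " + ", ".join(f"{s:.4f}" for s in sigmas))
```

A shift of 1.0 leaves the linear schedule unchanged. Larger shifts keep the first sigma at 1.0 but push the remaining values upward, so more of the steps are spent at high noise levels.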
---
## Image-to-Video (I2V) Pipelines
Wan2.2 supports two distinct I2V approaches:
### TI2V-5B: Per-Token Timestep Masking
I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:
```python
# mask_tokens: [1, L] — 0 for first-frame patches, 1 for rest
t_tokens = mask_tokens * current_timestep # first-frame → t=0
```
The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
#### Mask Re-application
After each scheduler step, the first-frame latent is re-injected to prevent drift:
```python
latents = (1.0 - mask) * z_img + mask * latents
```
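The per-token timesteps and the mask re-application can be exercised together on toy arrays. The sketch below uses NumPy in place of MLX, and the shapes, the timestep value, and the names `stepped` and `z_img` are illustrative assumptions.

```python
import numpy as np

C, T, H, W = 2, 3, 2, 2
pt, ph, pw = 1, 2, 2                       # patch size

# mask: 0 on the first latent frame (kept clean), 1 elsewhere (denoised).
mask = np.ones((C, T, H, W), dtype=np.float32)
mask[:, 0] = 0.0

# One timestep per token: subsample the mask to the patch grid.
mask_tokens = mask[0, ::pt, ::ph, ::pw].reshape(1, -1)   # [1, L]
current_timestep = 875.0                                  # illustrative value
t_tokens = mask_tokens * current_timestep                 # first-frame tokens get t = 0

rng = np.random.default_rng(0)
z_img = rng.standard_normal((C, T, H, W)).astype(np.float32)    # encoded reference image
stepped = rng.standard_normal((C, T, H, W)).astype(np.float32)  # latents after a scheduler step

# Re-inject the clean first frame; all other frames keep the scheduler output.
latents = (1.0 - mask) * z_img + mask * stepped
print(t_tokens)
```

Only the first latent frame is overwritten, so any drift the scheduler step introduces there is discarded at each step.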
#### VAE Encoder Temporal Downsample Order
The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
- Stage 0: Spatial-only downsampling
- Stages 1-2: Spatial + temporal downsampling
This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
### I2V-14B: Channel Concatenation
The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor:
1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])` → `[20, T_lat, H_lat, W_lat]`
4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
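The channel bookkeeping of the four steps above can be sketched with NumPy. The VAE output and the noisy latent are replaced by zero-filled stand-ins, and the repeat-and-fold rearrangement that turns the per-frame mask into 4 channels is an assumption about how the `[4, T_lat, H_lat, W_lat]` layout is reached.

```python
import numpy as np

num_frames, lat_h, lat_w = 17, 6, 10       # pixel frames and latent spatial size
t_lat = (num_frames - 1) // 4 + 1          # 5 latent frames

# Per-pixel-frame mask: 1 for the conditioning frame, 0 for the rest.
msk = np.zeros((num_frames, lat_h, lat_w), dtype=np.float32)
msk[0] = 1.0
# Repeat the first frame 4x so the frame count becomes 4 * t_lat, then fold
# each group of 4 frames into a channel axis (assumed rearrangement).
msk = np.concatenate([np.repeat(msk[:1], 4, axis=0), msk[1:]], axis=0)
mask_4ch = msk.reshape(t_lat, 4, lat_h, lat_w).transpose(1, 0, 2, 3)  # [4, T_lat, H, W]

encoded_16ch = np.zeros((16, t_lat, lat_h, lat_w), dtype=np.float32)  # stand-in for VAE output
noise_16ch = np.zeros((16, t_lat, lat_h, lat_w), dtype=np.float32)    # stand-in for noisy latent

y = np.concatenate([mask_4ch, encoded_16ch], axis=0)   # [20, T_lat, H, W]
x = np.concatenate([noise_16ch, y], axis=0)            # [36, T_lat, H, W]
print(mask_4ch.shape, y.shape, x.shape)
```

The final channel count of 36 is what the patch embedding expects for `in_dim=36`; the mask is 1 only at the first latent frame.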
Key differences from TI2V-5B:
- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
- Requires the **VAE encoder** (for encoding the reference image)
- Uses **scalar timesteps** (same as T2V) — no per-token masking
- **Dual model** pipeline with boundary=0.900
- Both conditional and unconditional predictions receive the same `y` tensor
---
## Dimension Constraints
### Patchify Alignment
Video dimensions must be divisible by `patch_size × vae_stride`:
- **TI2V-5B**: patch=(1,2,2), stride=(4,16,16) → alignment = **32** pixels
- **T2V-14B**: patch=(1,2,2), stride=(4,8,8) → alignment = **16** pixels
Example: 720p (1280×720) → 720 % 32 ≠ 0, auto-aligns to **704**.
### Frame Count
Frames must satisfy `num_frames = 4n + 1` (e.g., 5, 9, 13, ..., 81) due to temporal VAE stride of 4.
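Both constraints reduce to integer arithmetic. The helpers below are hypothetical (`align_down` and `snap_num_frames` are not names from the codebase), and rounding down is an assumption that happens to agree with the 720 → 704 example above.

```python
def align_down(value: int, patch: int, stride: int) -> int:
    """Largest multiple of patch * stride that is <= value."""
    multiple = patch * stride
    return (value // multiple) * multiple


def snap_num_frames(num_frames: int, temporal_stride: int = 4) -> int:
    """Largest value of the form temporal_stride * n + 1 that is <= num_frames."""
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    return ((num_frames - 1) // temporal_stride) * temporal_stride + 1


print("TI2V-5B height:", align_down(720, patch=2, stride=16))
print("TI2V-5B width: ", align_down(1280, patch=2, stride=16))
print("T2V-14B height:", align_down(720, patch=2, stride=8))
print("frames for 80: ", snap_num_frames(80))
```

With a 16-pixel alignment, 720 is already valid, which is why only the 5B model (32-pixel alignment) needs the height adjusted at 720p.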
---
## Performance Optimizations
### Batched CFG
Instead of two separate forward passes for conditional and unconditional predictions, batch them into a single B=2 forward pass:
```python
preds = model([latents, latents], t=t_batch, context=context_cfg, ...)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
```
**Result**: ~40% speedup by amortizing attention overhead.
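The equivalence between two B=1 passes and one B=2 pass can be sketched with a stand-in denoiser. `toy_model` is purely illustrative (any function that treats batch elements independently would do), and `cfg_combine` uses the `uncond + gs * (cond - uncond)` form with an arbitrary guidance scale.

```python
import numpy as np


def toy_model(latents, context):
    """Stand-in denoiser that processes each batch element independently."""
    bias = context.mean(axis=(1, 2)).reshape(-1, 1, 1, 1, 1)
    return 0.5 * latents + bias


def cfg_combine(cond, uncond, guidance_scale):
    return uncond + guidance_scale * (cond - uncond)


rng = np.random.default_rng(0)
latents = rng.standard_normal((4, 2, 3, 3))     # [C, T, H, W]
ctx_cond = rng.standard_normal((6, 8))          # [L_text, dim]
ctx_uncond = rng.standard_normal((6, 8))

# Two separate B=1 passes.
cond_a = toy_model(latents[None], ctx_cond[None])[0]
uncond_a = toy_model(latents[None], ctx_uncond[None])[0]

# One B=2 pass: index 0 is conditional, index 1 is unconditional.
preds = toy_model(np.stack([latents, latents]), np.stack([ctx_cond, ctx_uncond]))
cond_b, uncond_b = preds[0], preds[1]

out_separate = cfg_combine(cond_a, uncond_a, 5.0)
out_batched = cfg_combine(cond_b, uncond_b, 5.0)
print(np.allclose(out_separate, out_batched))
```

The batch order matters: swapping the two context entries without swapping the unpacking would invert the guidance direction.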
### Precomputed Text Embeddings & Cross-Attention KV Cache
Text embeddings and cross-attention K/V projections are constant across all diffusion steps. Computing them once and passing as caches eliminates redundant computation.
### Memory Management in Diffusion Loop
```python
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
mx.eval(latents)
```
MLX's lazy evaluation means `mx.eval()` triggers the full computation graph. Deleting intermediate arrays before eval allows MLX to reuse their memory during execution.
---
## Weight Conversion
### Key Mapping Patterns
The PyTorch → MLX conversion (`convert_wan.py`) handles several systematic transforms:
1. **Conv3d weight transposition**: PyTorch `(out, in, D, H, W)` → MLX `(out, D, H, W, in)`
2. **Linear weights**: no transposition needed. PyTorch `(out, in)` → MLX `(out, in)` (same convention for `nn.Linear`)
3. **Nested module paths**: `blocks.0.self_attn.q.weight` → same paths, MLX loads by dotted key
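The Conv3d transform in item 1 is a single axis permutation. The NumPy sketch below shows it on a small array; `conv3d_weight_to_mlx` is a hypothetical helper, not the function used in `convert_wan.py`.

```python
import numpy as np


def conv3d_weight_to_mlx(w: np.ndarray) -> np.ndarray:
    """Permute (out, in, D, H, W) to (out, D, H, W, in)."""
    return np.transpose(w, (0, 2, 3, 4, 1))


w_pt = np.arange(2 * 3 * 1 * 2 * 2, dtype=np.float32).reshape(2, 3, 1, 2, 2)
w_mlx = conv3d_weight_to_mlx(w_pt)
print(w_pt.shape, "->", w_mlx.shape)

# The same element addressed in both layouts: [o, i, d, h, w] vs [o, d, h, w, i].
same_element = w_mlx[1, 0, 1, 0, 2] == w_pt[1, 2, 0, 1, 0]
print(same_element)
```

A plain `reshape` to the target shape would also succeed without error but would scramble the values, which is why the permutation is spelled out explicitly.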
### Dual-Model Splitting
The T2V-14B uses dual models (high-noise and low-noise). The conversion script splits a single checkpoint into separate files or handles pre-split checkpoints from HuggingFace.
---
## Testing Strategy
332 tests across 10 files, all running in ~5 seconds:
| File | Focus |
|------|-------|
| test_wan_config.py | Config presets, field validation |
| test_wan_attention.py | Self/cross attention, RMSNorm, bf16 autocast |
| test_wan_transformer.py | FFN, attention block, float32 modulation |
| test_wan_model.py | Full DiT forward pass, per-token timesteps |
| test_wan_t5.py | T5 encoder layers and full encoding |
| test_wan_vae.py | VAE 2.1 decoder, VAE 2.2 encoder + decoder |
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
| test_wan_convert.py | Weight sanitization and conversion |
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |
Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.
---
## Known Issues
### I2V Quality Degradation
Frames 2-13 gradually degrade, and frame 14 often has a "flash" artifact. All implementation details have been verified against the official PyTorch code with no discrepancies found. Possible causes:
- Subtle numerical differences from float32 vs float64 RoPE (MLX lacks float64 on GPU)
- MLX-specific attention precision behavior
Mitigation: better prompts and 720p resolution (the model's native resolution) help reduce artifacts.
### Chinese Negative Prompt
The official Wan2.2 uses a Chinese negative prompt that prevents oversaturation and comic-style artifacts. Correct tokenization requires `ftfy.fix_text()` to normalize fullwidth characters and double HTML unescaping. Without proper text cleaning, the negative prompt tokens don't match the training distribution, causing blurry patches.

View File

@@ -0,0 +1,58 @@
"""Image-to-Video utility functions for Wan2.2."""

import mlx.core as mx
import numpy as np


def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
    """Load, resize, center-crop, and normalize an image for I2V.

    Args:
        image_path: Path to input image
        width: Target width
        height: Target height

    Returns:
        Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
    """
    from PIL import Image

    img = Image.open(image_path).convert("RGB")

    # Resize so that the image covers the target size (LANCZOS)
    scale = max(width / img.width, height / img.height)
    img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)

    # Center crop
    x1 = (img.width - width) // 2
    y1 = (img.height - height) // 2
    img = img.crop((x1, y1, x1 + width, y1 + height))

    # To tensor: [H, W, 3] float32 in [-1, 1]
    arr = np.array(img, dtype=np.float32) / 255.0
    arr = arr * 2.0 - 1.0  # [0,1] → [-1,1]
    return mx.array(arr[None, None])  # [1, 1, H, W, 3]


def build_i2v_mask(z_shape, patch_size):
    """Build temporal mask for I2V: first frame = 0, rest = 1.

    Args:
        z_shape: Latent shape (C, T, H, W) in channels-first
        patch_size: (pt, ph, pw) patch size

    Returns:
        mask: (C, T, H, W) float32, 0 for first frame, 1 for rest
        mask_tokens: (1, L) float32, 0 for first-frame tokens, 1 for rest
    """
    C, T, H, W = z_shape
    mask = mx.ones(z_shape)
    # Zero out the first temporal position
    mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)

    # Token-level mask for per-token timesteps: subsample to patch grid
    # mask shape [C, T, H, W] → take first channel, subsample by patch_size
    pt, ph, pw = patch_size
    mask_tokens = mask[0, ::pt, ::ph, ::pw]  # [T', H', W']
    mask_tokens = mask_tokens.reshape(1, -1)  # [1, L]
    return mask, mask_tokens

View File

@@ -0,0 +1,183 @@
"""Wan model loading utilities."""

from pathlib import Path

import mlx.core as mx
import mlx.nn as nn


def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
    """Load and initialize WanModel, with optional quantization and LoRA support.

    Args:
        model_path: Path to model safetensors file
        config: WanModelConfig
        quantization: Optional dict with 'bits' and 'group_size' keys.
            If provided, creates QuantizedLinear stubs before loading.
        loras: Optional list of (lora_path, strength) tuples to apply.
    """
    from mlx_video.models.wan.model import WanModel

    model = WanModel(config)

    if quantization:
        from mlx_video.convert_wan import _quantize_predicate

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=lambda path, m: _quantize_predicate(path, m),
        )

    weights = mx.load(str(model_path))

    # Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
    if loras:
        if quantization:
            # Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
            # Non-LoRA layers stay 4-bit. Zero per-step overhead.
            from mlx_video.convert_wan import _load_lora_configs
            from mlx_video.lora import apply_loras_to_model

            model.load_weights(list(weights.items()), strict=False)
            mx.eval(model.parameters())
            module_to_loras = _load_lora_configs(loras)
            apply_loras_to_model(model, module_to_loras)
            mx.eval(model.parameters())
            return model
        else:
            # Weight merging: fold LoRA into bf16 weights before loading
            from mlx_video.convert_wan import load_and_apply_loras

            weights = load_and_apply_loras(dict(weights), loras)

    model.load_weights(list(weights.items()), strict=False)
    mx.eval(model.parameters())
    return model


def load_t5_encoder(model_path: Path, config):
    """Load T5 text encoder.

    Weights are upcast to float32 for maximum precision. The T5 encoder
    only runs once per generation, so the performance impact is negligible.
    This matches the official implementation, which computes softmax in
    float32 explicitly.
    """
    from mlx_video.models.wan.text_encoder import T5Encoder

    encoder = T5Encoder(
        vocab_size=config.t5_vocab_size,
        dim=config.t5_dim,
        dim_attn=config.t5_dim_attn,
        dim_ffn=config.t5_dim_ffn,
        num_heads=config.t5_num_heads,
        num_layers=config.t5_num_layers,
        num_buckets=config.t5_num_buckets,
        shared_pos=False,
    )

    weights = mx.load(str(model_path))
    weights = {k: v.astype(mx.float32) for k, v in weights.items()}
    encoder.load_weights(list(weights.items()))
    mx.eval(encoder.parameters())
    return encoder


def load_vae_decoder(model_path: Path, config=None):
    """Load VAE decoder (skips encoder weights with strict=False).

    For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
    For Wan2.1 (vae_z_dim=16), uses WanVAE.
    """
    is_wan22 = config is not None and config.vae_z_dim == 48

    if is_wan22:
        from mlx_video.models.wan.vae22 import Wan22VAEDecoder

        vae = Wan22VAEDecoder(z_dim=48)
    else:
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16)

    weights = mx.load(str(model_path))
    # Upcast VAE weights to float32 for quality; official Wan2.2 runs VAE in float32
    weights = {k: v.astype(mx.float32) for k, v in weights.items()}
    vae.load_weights(list(weights.items()), strict=False)
    mx.eval(vae.parameters())
    return vae


def load_vae_encoder(model_path: Path, config=None):
    """Load VAE encoder for I2V image encoding.

    For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder.
    For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
    """
    if config is not None and config.vae_z_dim == 16:
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16, encoder=True)
    else:
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder

        vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)

    weights = mx.load(str(model_path))
    weights = {k: v.astype(mx.float32) for k, v in weights.items()}
    vae.load_weights(list(weights.items()), strict=False)
    mx.eval(vae.parameters())
    return vae


def _clean_text(text: str) -> str:
    """Clean text matching official Wan2.2 tokenizer preprocessing.

    Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
    double HTML unescape, and whitespace normalization. Critical for
    correct tokenization of the Chinese negative prompt.
    """
    import html
    import re

    try:
        import ftfy

        text = ftfy.fix_text(text)
    except ImportError:
        pass
    text = html.unescape(html.unescape(text))
    text = re.sub(r"\s+", " ", text).strip()
    return text


def encode_text(
    encoder,
    tokenizer,
    prompt: str,
    text_len: int = 512,
) -> mx.array:
    """Encode text prompt using T5 encoder.

    Args:
        encoder: T5Encoder model
        tokenizer: HuggingFace tokenizer
        prompt: Text prompt
        text_len: Maximum text length

    Returns:
        Text embeddings [L, dim]
    """
    prompt = _clean_text(prompt)
    tokens = tokenizer(
        prompt,
        max_length=text_len,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    ids = mx.array(tokens["input_ids"])
    mask = mx.array(tokens["attention_mask"])

    embeddings = encoder(ids, mask=mask)

    # Return only non-padding tokens
    seq_len = int(mask.sum().item())
    return embeddings[0, :seq_len]

View File

@@ -0,0 +1,377 @@
import math

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .attention import WanLayerNorm, _linear_dtype
from .config import WanModelConfig
from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock


def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
    """Compute sinusoidal positional embeddings.

    Args:
        dim: Embedding dimension (must be even).
        position: Tensor of positions, 1D [L] or 2D [B, L].

    Returns:
        Embeddings of shape [L, dim] or [B, L, dim].
    """
    assert dim % 2 == 0
    half = dim // 2
    pos = position.astype(mx.float32)
    inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
    sinusoid = pos[..., None] * inv_freq  # [..., half]
    return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)


class Head(nn.Module):
    """Output projection head with learned modulation."""

    def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
        super().__init__()
        self.out_dim = out_dim
        self.patch_size = patch_size
        proj_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, proj_dim)
        self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)

    def __call__(self, x: mx.array, e: mx.array) -> mx.array:
        """
        Args:
            x: [B, L, dim]
            e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
        """
        if e.ndim == 2:
            e = e[:, None, :]  # [B, 1, dim]
        # Compute modulation in float32 (matching reference's autocast(float32))
        mod = self.modulation[:, None, :, :] + e[:, :, None, :]  # float32
        e0 = mod[:, :, 0, :]  # [B, L_e, dim] shift
        e1 = mod[:, :, 1, :]  # [B, L_e, dim] scale
        x_norm = self.norm(x)
        x_mod = x_norm * (1 + e1) + e0
        return self.head(x_mod)


class WanModel(nn.Module):
    """Wan2.2 diffusion backbone for text-to-video generation."""

    def __init__(self, config: WanModelConfig):
        super().__init__()
        self.config = config
        dim = config.dim
        self.dim = dim
        self.num_heads = config.num_heads
        self.out_dim = config.out_dim
        self.patch_size = config.patch_size
        self.text_len = config.text_len
        self.freq_dim = config.freq_dim

        # Patch embedding: Conv3d implemented as a reshaped linear
        # For kernel (1,2,2) and stride (1,2,2): reshape input then linear
        patch_dim = config.in_dim * math.prod(config.patch_size)
        self.patch_embedding_proj = nn.Linear(patch_dim, dim)
        self._patch_size = config.patch_size

        # Text embedding MLP
        self.text_embedding_0 = nn.Linear(config.text_dim, dim)
        self.text_embedding_act = nn.GELU(approx="tanh")
        self.text_embedding_1 = nn.Linear(dim, dim)

        # Time embedding MLP
        self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
        self.time_embedding_act = nn.SiLU()
        self.time_embedding_1 = nn.Linear(dim, dim)

        # Time projection for modulation (6x dim)
        self.time_projection_act = nn.SiLU()
        self.time_projection = nn.Linear(dim, dim * 6)

        # Transformer blocks
        self.blocks = [
            WanAttentionBlock(
                dim=dim,
                ffn_dim=config.ffn_dim,
                num_heads=config.num_heads,
                window_size=config.window_size,
                qk_norm=config.qk_norm,
                cross_attn_norm=config.cross_attn_norm,
                eps=config.eps,
            )
            for _ in range(config.num_layers)
        ]

        # Output head
        self.head = Head(dim, config.out_dim, config.patch_size, config.eps)

        # Precompute RoPE frequencies: three separate tables concatenated.
        # Reference computes three rope_params with different dim normalizations
        # so each axis (temporal/height/width) gets its own full frequency range.
        d = dim // config.num_heads
        self.freqs = mx.concatenate([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
        ], axis=1)

        # Precompute sinusoidal inv_freq for time embedding.
        half = config.freq_dim // 2
        self._inv_freq = mx.array(
            np.power(10000.0, -np.arange(half, dtype=np.float64) / half
            ).astype(np.float32)
        )

    def _patchify(self, x: mx.array) -> tuple:
        """Convert video tensor to patch embeddings.

        Args:
            x: Video latent [C, F, H, W]

        Returns:
            (patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
        """
        c, f, h, w = x.shape
        pt, ph, pw = self._patch_size
        f_out = f // pt
        h_out = h // ph
        w_out = w // pw

        # Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
        # Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
        x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
        x = x.transpose(1, 3, 5, 0, 2, 4, 6)  # [F', H', W', C, pt, ph, pw]
        x = x.reshape(f_out * h_out * w_out, -1)  # [L, C*pt*ph*pw]

        # Project and cast to model dtype to prevent float32 cascade from input latents
        patches = self.patch_embedding_proj(x)  # [L, dim]
        patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
        patches = patches[None, :, :]  # [1, L, dim]
        return patches, (f_out, h_out, w_out)

    def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
        """Reconstruct video from patch embeddings.

        Args:
            x: [B, L, out_dim * prod(patch_size)]
            grid_sizes: List of (F', H', W') per batch element

        Returns:
            List of tensors [C, F, H, W]
        """
        c = self.out_dim
        pt, ph, pw = self.patch_size
        out = []
        for i, (f, h, w) in enumerate(grid_sizes):
            seq_len = f * h * w
            u = x[i, :seq_len]  # [L, out_dim * pt * ph * pw]
            u = u.reshape(f, h, w, pt, ph, pw, c)
            # Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
            u = u.transpose(6, 0, 3, 1, 4, 2, 5)  # [C, F', pt, H', ph, W', pw]
            u = u.reshape(c, f * pt, h * ph, w * pw)
            out.append(u)
        return out

    def embed_text(self, context: list) -> mx.array:
        """Precompute text embeddings (call once, reuse across steps).

        Args:
            context: List of text embeddings [L_text, text_dim]

        Returns:
            Embedded context [B, text_len, dim] in model dtype
        """
        model_dtype = _linear_dtype(self.patch_embedding_proj)
        context_padded = []
        for ctx in context:
            pad_len = self.text_len - ctx.shape[0]
            if pad_len > 0:
                ctx = mx.concatenate(
                    [ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
                    axis=0,
                )
            context_padded.append(ctx)
        context_batch = mx.stack(context_padded)  # [B, text_len, text_dim]
        context_batch = self.text_embedding_1(
            self.text_embedding_act(self.text_embedding_0(context_batch))
        )
        return context_batch.astype(model_dtype)

    def prepare_cross_kv(self, context: mx.array) -> list:
        """Pre-compute cross-attention K/V for all blocks.

        Call once before the diffusion loop to cache K/V projections,
eliminating redundant computation at each denoising step.
Args:
context: Pre-embedded text [B, text_len, dim]
Returns:
List of (k, v) tuples, one per block
"""
kv_caches = []
for block in self.blocks:
kv_caches.append(block.cross_attn.prepare_kv(context))
return kv_caches
def prepare_rope(self, grid_sizes: list) -> tuple:
"""Pre-compute RoPE cos/sin for constant grid sizes.
Call once before the diffusion loop when grid sizes don't change
across steps. Eliminates per-step broadcast/concat overhead.
Args:
grid_sizes: List of (F, H, W) tuples per batch element
Returns:
(cos_f, sin_f) precomputed frequency tensors
"""
w_dtype = _linear_dtype(self.patch_embedding_proj)
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__(
self,
x_list: list,
t: mx.array,
context: list | mx.array,
seq_len: int,
cross_kv_caches: list | None = None,
y: list | None = None,
rope_cos_sin: tuple | None = None,
) -> list:
"""Forward pass.
Args:
x_list: List of video latent tensors [C, F, H, W]
t: Timestep tensor [B]
context: List of raw text embeddings, OR pre-embedded tensor
from embed_text() [B, text_len, dim]
seq_len: Maximum sequence length for padding
cross_kv_caches: Optional list of (k, v) tuples from
prepare_cross_kv(), one per block.
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
Channel-concatenated with x before patchify.
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
Returns:
List of denoised tensors [C, F, H, W]
"""
# Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
# Check BEFORE I2V concat since concat creates new array objects.
batch_size = len(x_list)
all_same = batch_size > 1 and all(
x_list[i] is x_list[0] for i in range(1, batch_size)
)
if all_same and y is not None:
all_same = all(y[i] is y[0] for i in range(1, len(y)))
# I2V: channel-concatenate conditioning y with noise x
if y is not None:
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
if all_same:
# Patchify once and broadcast — saves a Linear projection per step
p, gs = self._patchify(x_list[0]) # [1, L, dim]
grid_sizes = [gs] * batch_size
seq_lens_list = [p.shape[1]] * batch_size
# Pad and broadcast
if p.shape[1] < seq_len:
p = mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
else:
patches = []
grid_sizes = []
seq_lens_list = []
for vid in x_list:
p, gs = self._patchify(vid) # [1, L, dim]
patches.append(p)
grid_sizes.append(gs)
seq_lens_list.append(p.shape[1])
x = mx.concatenate(
[
mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
if p.shape[1] < seq_len
else p
for p in patches
],
axis=0,
) # [B, seq_len, dim]
# Time embedding: sinusoidal from precomputed inv_freq.
# inv_freq was computed in float64 for precision, stored as float32.
# With integer timesteps (matching reference), float32 sin/cos is fine.
if t.ndim == 0:
t = t[None]
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate(
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim)
else:
# I2V: per-token timesteps [B, L]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
e0 = e0.reshape(batch_size, -1, 6, self.dim)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):
# Pre-embedded: expand to batch size if needed
context_batch = context
if context_batch.shape[0] == 1 and batch_size > 1:
context_batch = mx.broadcast_to(
context_batch, (batch_size,) + context_batch.shape[1:]
)
else:
context_batch = self.embed_text(context)
# Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None
w_dtype = _linear_dtype(self.patch_embedding_proj)
if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list):
attn_mask[i, :, :, sl:] = -1e9
kwargs = dict(
e=e0,
seq_lens=seq_lens_list,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context_batch,
context_lens=None,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
)
# Run transformer blocks
for i, block in enumerate(self.blocks):
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
x = block(x, cross_kv_cache=kv, **kwargs)
# Output head
x = self.head(x, e)
# Unpatchify
outputs = self.unpatchify(x, grid_sizes)
return [u.astype(mx.float32) for u in outputs]
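
The timestep embedding in `__call__` multiplies the timestep by the precomputed `inv_freq` table and concatenates cosine terms followed by sine terms. The following is a minimal pure-Python sketch of that one step, not the module's implementation: the helper name `sinusoidal_embedding` is hypothetical, and `freq_dim=8` is an assumed toy value chosen only to keep the output short.

```python
import math


def sinusoidal_embedding(t: float, freq_dim: int) -> list[float]:
    """Return [cos(t*w_0), ..., sin(t*w_0), ...] with w_i = 10000^(-i/half)."""
    half = freq_dim // 2
    # Same frequency table as the model's precomputed inv_freq (hypothetical re-derivation).
    inv_freq = [10000.0 ** (-i / half) for i in range(half)]
    angles = [t * w for w in inv_freq]
    # Cosine terms come first, then sine terms, matching the forward pass.
    return [math.cos(a) for a in angles] + [math.sin(a) for a in angles]


emb = sinusoidal_embedding(0.0, freq_dim=8)
print(emb)
```

At `t = 0` every angle is zero, so the cosine half is all ones and the sine half is all zeros, which makes the cos-first ordering easy to verify by eye.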

@@ -0,0 +1,35 @@
import numpy as np
from pathlib import Path


def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
    """Save video frames to MP4, falling back to PNG frames if no encoder works.

    Tries imageio first, then OpenCV, then writes individual PNGs with PIL.

    Args:
        frames: Video frames [T, H, W, 3] uint8
        output_path: Output file path
        fps: Frames per second
    """
    try:
        import imageio

        writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
        for frame in frames:
            writer.append_data(frame)
        writer.close()
    except ImportError:
        try:
            import cv2

            h, w = frames.shape[1], frames.shape[2]
            fourcc = cv2.VideoWriter_fourcc(*"avc1")
            writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
            for frame in frames:
                # OpenCV expects BGR channel order
                writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            writer.release()
        except Exception:  # Exception already covers ImportError
            # Last resort: save as individual PNGs
            from PIL import Image

            out_dir = Path(output_path).parent / Path(output_path).stem
            out_dir.mkdir(parents=True, exist_ok=True)
            for i, frame in enumerate(frames):
                Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
            print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
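
When neither encoder is usable, `save_video` writes PNG frames into a directory named after the output file's stem. The sketch below only reproduces that path arithmetic so the fallback location is easy to predict. The helper `fallback_frame_paths` and the example path `outputs/clip.mp4` are hypothetical and not part of the module.

```python
from pathlib import Path


def fallback_frame_paths(output_path: str, num_frames: int) -> list[Path]:
    """List the PNG paths the last-resort branch would write to."""
    # Directory is <parent>/<stem>, e.g. outputs/clip for outputs/clip.mp4
    out_dir = Path(output_path).parent / Path(output_path).stem
    return [out_dir / f"frame_{i:04d}.png" for i in range(num_frames)]


paths = fallback_frame_paths("outputs/clip.mp4", 3)
for p in paths:
    print(p)
```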

@@ -0,0 +1,178 @@
import math
import mlx.core as mx
import numpy as np
def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
    """Precompute RoPE rotation angles as (cos, sin) pairs.

    Angles are computed in float64 for precision and stored as float32.

    Returns:
        Tensor of shape [max_seq_len, dim // 2, 2]; the last axis holds
        (cos, sin) of each position-frequency angle.
    """
    assert dim % 2 == 0
    freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
        1.0
        / np.power(
            theta,
            np.arange(0, dim, 2, dtype=np.float64) / dim,
        )
    )[None, :]
    # Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
    cos_freqs = np.cos(freqs).astype(np.float32)
    sin_freqs = np.sin(freqs).astype(np.float32)
    return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))
def rope_apply(
x: mx.array,
grid_sizes: list,
freqs: mx.array,
precomputed_cos_sin: tuple | None = None,
) -> mx.array:
"""Apply 3-way factorized RoPE to Q or K tensor.
Args:
x: Shape [B, L, num_heads, head_dim]
grid_sizes: List of (F, H, W) tuples per batch element
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
"""
b, s, n, d = x.shape
half_d = d // 2
if precomputed_cos_sin is not None:
cos_f, sin_f = precomputed_cos_sin
# Check if all batch elements have the same grid (common for CFG B=2)
f0, h0, w0 = grid_sizes[0]
seq_len = f0 * h0 * w0
all_same_grid = all(
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
) if b > 1 else True
if all_same_grid:
# Vectorized path: apply RoPE to all batch elements at once
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
x_real = x_seq[..., 0]
x_imag = x_seq[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
return x_rotated
else:
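# Per-element path for mixed grid sizes.
# NOTE: cos_f/sin_f come from rope_precompute_cos_sin(), which builds them
# from grid_sizes[0] only, so they do not match elements whose grid differs.
# Callers with mixed grids should pass precomputed_cos_sin=None instead.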
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
sl = f * h * w
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
if sl < s:
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)
# Cast freqs to input dtype to prevent float32 promotion cascade
if freqs.dtype != x.dtype:
freqs = freqs.astype(x.dtype)
# Split frequency dimensions: temporal gets more capacity
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
# Split freqs along dim axis
freqs_t = freqs[:, :d_t] # [1024, d_t, 2]
freqs_h = freqs[:, d_t : d_t + d_h] # [1024, d_h, 2]
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] # [1024, d_w, 2]
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
seq_len = f * h * w
# Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)
# Build per-position frequencies by expanding along grid dims
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
ft = mx.broadcast_to(
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
)
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
fh = mx.broadcast_to(
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
)
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
fw = mx.broadcast_to(
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
)
# Concatenate: [f*h*w, half_d, 2]
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
# Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
cos_f = freqs_i[..., 0] # [seq_len, 1, half_d]
sin_f = freqs_i[..., 1] # [seq_len, 1, half_d]
x_real = x_i[..., 0] # [seq_len, n, half_d]
x_imag = x_i[..., 1] # [seq_len, n, half_d]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
# Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)
# Handle padding: keep non-rotated tokens after seq_len
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)
def rope_precompute_cos_sin(
    grid_sizes: list, freqs: mx.array, dtype: mx.Dtype = mx.float32
) -> tuple:
    """Precompute cos/sin frequency tensors for constant grid sizes.

    Call once before the diffusion loop. Pass the result as
    precomputed_cos_sin to rope_apply to skip per-step broadcast/concat.
    Only grid_sizes[0] is read, so every batch element must share that grid.

    Args:
        grid_sizes: List of (F, H, W) tuples (must be same for all batch elements)
        freqs: Precomputed frequencies [1024, d//2, 2]
        dtype: Target dtype for the output tensors

    Returns:
        (cos_f, sin_f) each [seq_len, 1, half_d]
    """
    if freqs.dtype != dtype:
        freqs = freqs.astype(dtype)
    f, h, w = grid_sizes[0]
    seq_len = f * h * w
    half_d = freqs.shape[1]
    # Temporal axis takes the remainder, so it gets the most capacity
    d_t = half_d - 2 * (half_d // 3)
    d_h = half_d // 3
    d_w = half_d // 3
    freqs_t = freqs[:, :d_t]
    freqs_h = freqs[:, d_t : d_t + d_h]
    freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]
    ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
    fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
    fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
    freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
    return freqs_i[..., 0], freqs_i[..., 1]
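
The model constructor builds three `rope_params` tables of widths `d - 4*(d//6)`, `2*(d//6)` and `2*(d//6)`, while `rope_apply` splits the concatenated table at `half_d - 2*(half_d//3)`, `half_d//3`, `half_d//3`. The sketch below checks that these two splits agree and shows the pair rotation in scalar form. It uses only the standard library. The helper names and the example `head_dim = 128` are assumptions for illustration, not values taken from a model config.

```python
import math


def axis_table_dims(head_dim: int) -> tuple[int, int, int]:
    """Table widths passed to rope_params for (temporal, height, width)."""
    sixth = head_dim // 6
    return (head_dim - 4 * sixth, 2 * sixth, 2 * sixth)


def axis_pair_counts(head_dim: int) -> tuple[int, int, int]:
    """Number of rotated (real, imag) pairs per axis, as split in rope_apply."""
    half_d = head_dim // 2
    third = half_d // 3
    return (half_d - 2 * third, third, third)


def rotate_pair(a: float, b: float, angle: float) -> tuple[float, float]:
    """(a + bi) * (cos + i*sin) written out in real arithmetic."""
    c, s = math.cos(angle), math.sin(angle)
    return (a * c - b * s, a * s + b * c)


head_dim = 128  # assumed example value
print(axis_table_dims(head_dim), axis_pair_counts(head_dim))
print(rotate_pair(1.0, 0.0, math.pi / 2))
```

Each table of width `n` yields `n // 2` rotation pairs, so the first tuple halved should equal the second. The temporal axis absorbs the remainder when `half_d` is not divisible by 3.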

@@ -0,0 +1,452 @@
"""Flow matching schedulers for Wan2.2 inference.
Provides Euler, DPM++2M, and UniPC solvers for flow matching diffusion.
Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps
for the same quality as Euler.
"""
import math
import numpy as np
import mlx.core as mx
def _compute_sigmas(
    num_steps: int, shift: float = 1.0, num_train_timesteps: int = 1000
) -> np.ndarray:
    """Compute shifted sigma schedule matching official Wan2.2 scheduler.

    The reference creates FlowUniPCMultistepScheduler with shift=1 (identity)
    in the constructor, deriving sigma_max/sigma_min from the unshifted
    training schedule. Then set_timesteps() builds a linspace between those
    unshifted bounds and applies the actual shift once.

    Returns num_steps+1 values (the last being 0.0 for the terminal state).
    """
    # sigma bounds from unshifted training schedule (constructor uses shift=1)
    alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
        ::-1
    ]
    sigmas_unshifted = 1.0 - alphas
    sigma_max = float(sigmas_unshifted[0])  # (N-1)/N
    sigma_min = float(sigmas_unshifted[-1])  # 0.0
    # Interpolate, then apply shift once (matching set_timesteps)
    sigmas = np.linspace(sigma_max, sigma_min, num_steps + 1)[:-1]
    sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
    return np.append(sigmas, 0.0).astype(np.float32)
class FlowMatchEulerScheduler:
    """1st-order Euler scheduler for flow matching diffusion."""

    def __init__(self, num_train_timesteps: int = 1000):
        self.num_train_timesteps = num_train_timesteps
        self.timesteps = None
        self.sigmas = None

    def set_timesteps(self, num_steps: int, shift: float = 1.0):
        sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
        self.sigmas = mx.array(sigmas)
        # Integer timesteps to match reference (model trained with int timesteps)
        self.timesteps = mx.array(
            (sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
        )
        # Store as Python floats to avoid .item() sync in step()
        self._sigmas_float = sigmas.tolist()
        self._step_index = 0

    def step(
        self,
        model_output: mx.array,
        timestep,
        sample: mx.array,
    ) -> mx.array:
        """Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
        dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
        x_next = sample + dt * model_output
        self._step_index += 1
        return x_next

    def reset(self):
        self._step_index = 0
class FlowDPMPP2MScheduler:
"""DPM-Solver++(2M) for flow matching diffusion.
2nd-order multistep solver that reuses the previous step's model output
for a correction term. Falls back to 1st order on the first and
(optionally) last step. Reference: Wan2.2 fm_solvers.py.
"""
def __init__(
self,
num_train_timesteps: int = 1000,
lower_order_final: bool = True,
):
self.num_train_timesteps = num_train_timesteps
self.lower_order_final = lower_order_final
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
# Store sigmas as Python floats for scalar math
self._sigmas_float = sigmas.tolist()
self._step_index = 0
self._num_steps = num_steps
self._prev_x0 = None # previous x0 prediction for 2nd-order correction
@staticmethod
def _lambda(sigma: float) -> float:
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
matching torch.log behavior in the official code.
"""
if sigma >= 1.0:
return -math.inf
if sigma <= 0.0:
return math.inf
return math.log((1.0 - sigma) / sigma)
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""DPM++(2M) step for flow matching.
Converts velocity prediction to x0, then applies 1st or 2nd order
update depending on available history.
"""
i = self._step_index
s = self._sigmas_float
sigma_cur = s[i]
sigma_next = s[i + 1]
# Convert velocity -> x0 prediction: x0 = sample - sigma * v
x0 = sample - sigma_cur * model_output
# Decide order: 1st for first step, last step (if lower_order_final
# and few steps), otherwise 2nd
use_first_order = (
self._prev_x0 is None
or (
self.lower_order_final
and i == self._num_steps - 1
and self._num_steps < 15
)
)
if use_first_order or sigma_next == 0.0:
# 1st order DPM++ (equivalent to DDIM):
# x_next = (σ_next/σ_cur)*x - (α_next*(exp(-h)-1))*x0
if sigma_next == 0.0:
x_next = x0
else:
lambda_cur = self._lambda(sigma_cur)
lambda_next = self._lambda(sigma_next)
h = lambda_next - lambda_cur
alpha_next = 1.0 - sigma_next
coeff_x = sigma_next / sigma_cur
coeff_x0 = alpha_next * math.expm1(-h)
x_next = coeff_x * sample - coeff_x0 * x0
else:
# 2nd order DPM++(2M) with midpoint correction
sigma_prev = s[i - 1]
lambda_prev = self._lambda(sigma_prev)
lambda_cur = self._lambda(sigma_cur)
lambda_next = self._lambda(sigma_next)
h = lambda_next - lambda_cur
h_0 = lambda_cur - lambda_prev
r0 = h_0 / h
# D0 = current x0, D1 = correction from previous x0
D0 = x0
D1 = (1.0 / r0) * (x0 - self._prev_x0)
alpha_next = 1.0 - sigma_next
exp_neg_h_m1 = math.expm1(-h) # exp(-h) - 1
x_next = (
(sigma_next / sigma_cur) * sample
- (alpha_next * exp_neg_h_m1) * D0
- 0.5 * (alpha_next * exp_neg_h_m1) * D1
)
self._prev_x0 = x0
self._step_index += 1
return x_next
def reset(self):
self._step_index = 0
self._prev_x0 = None
class FlowUniPCScheduler:
"""UniPC (Unified Predictor-Corrector) for flow matching diffusion.
Multi-step predictor-corrector solver with configurable order.
The corrector refines each step using the model output that was already
computed, costing no extra model evaluations. Official Wan2.2 default.
Reference: Wan2.2 fm_solvers_unipc.py.
"""
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
lower_order_final: bool = True,
disable_corrector: list | None = None,
use_corrector: bool = True,
):
self.num_train_timesteps = num_train_timesteps
self.solver_order = solver_order
self.lower_order_final = lower_order_final
self._use_corrector = use_corrector
self.disable_corrector = set(disable_corrector or [])
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
self._sigmas_float = sigmas.tolist()
self._step_index = 0
self._num_steps = num_steps
self._lower_order_nums = 0
# Model output (x0) history for multi-step, stored newest-last
self._model_outputs = [None] * self.solver_order
self._last_sample = None # sample before prediction (for corrector)
self._this_order = 1
@staticmethod
def _lambda(sigma: float) -> float:
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
matching torch.log behavior in the official code.
"""
if sigma >= 1.0:
return -math.inf
if sigma <= 0.0:
return math.inf
return math.log((1.0 - sigma) / sigma)
def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array:
"""Convert velocity prediction to x0: x0 = sample - sigma * v."""
sigma = self._sigmas_float[self._step_index]
return sample - sigma * velocity
def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array:
"""UniP predictor with B(h)=expm1(-h) basis (bh2 variant).
Matches official multistep_uni_p_bh_update: computes rhos_p via
linalg.solve for order >= 3; order <= 2 uses analytic rhos_p=[0.5].
"""
i = self._step_index
s = self._sigmas_float
sigma_s0 = s[i]
sigma_t = s[i + 1]
if sigma_t == 0.0:
return x0
lambda_s0 = self._lambda(sigma_s0)
lambda_t = self._lambda(sigma_t)
h = lambda_t - lambda_s0
hh = -h # negated for predict_x0
alpha_t = 1.0 - sigma_t
h_phi_1 = math.expm1(hh)
B_h = h_phi_1
m0 = self._model_outputs[-1]
# Base prediction
x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0
if order >= 2 and m0 is not None:
rks = []
D1s = []
for k in range(1, order):
si_idx = i - k
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
break
mk = self._model_outputs[-(k + 1)]
sigma_sk = s[si_idx]
lambda_sk = self._lambda(sigma_sk)
rk = (lambda_sk - lambda_s0) / h
if math.isinf(rk):
break
rks.append(rk)
D1s.append((mk - m0) / rk)
if D1s:
effective_order = len(D1s) + 1
if effective_order <= 2:
# Analytic solution for order 2
rhos_p = [0.5]
else:
rks_arr = np.array(rks, dtype=np.float64)
h_phi_k = h_phi_1 / hh - 1.0
factorial_i = 1
R_rows = []
b_vals = []
for j in range(1, effective_order):
R_rows.append(rks_arr ** (j - 1))
b_vals.append(float(h_phi_k * factorial_i / B_h))
factorial_i *= j + 1
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
R = np.stack(R_rows)
b = np.array(b_vals)
rhos_p = np.linalg.solve(R, b).tolist()
pred_res = sum(r * d for r, d in zip(rhos_p, D1s))
x_t = x_t - (alpha_t * B_h) * pred_res
return x_t
def _uni_c_bh2(
self,
model_x0: mx.array,
last_sample: mx.array,
this_sample: mx.array,
order: int,
) -> mx.array:
"""UniC corrector with B(h)=expm1(-h) basis (bh2 variant).
Matches official multistep_uni_c_bh_update: computes rhos_c via
linalg.solve for order >= 2 (not hardcoded 0.5).
"""
i = self._step_index
s = self._sigmas_float
sigma_s0 = s[i - 1]
sigma_t = s[i]
if sigma_t == 0.0:
return this_sample
lambda_s0 = self._lambda(sigma_s0)
lambda_t = self._lambda(sigma_t)
h = lambda_t - lambda_s0
hh = -h # negated for predict_x0
alpha_t = 1.0 - sigma_t
h_phi_1 = math.expm1(hh)
B_h = h_phi_1
m0 = self._model_outputs[-1]
# Re-derive base from last_sample
x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0
D1_t = model_x0 - m0
# Gather rks and D1s from history
rks = []
D1s = []
for k in range(1, order):
si_idx = i - (k + 1)
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
break
mk = self._model_outputs[-(k + 1)]
sigma_sk = s[si_idx]
lambda_sk = self._lambda(sigma_sk)
rk = (lambda_sk - lambda_s0) / h
if math.isinf(rk):
break # History references sigma=1.0 boundary; reduce order
rks.append(rk)
D1s.append((mk - m0) / rk)
rks.append(1.0)
effective_order = len(rks) # = len(D1s) + 1
# Compute rhos_c coefficients
if effective_order == 1:
rhos_c = [0.5]
else:
rks_arr = np.array(rks, dtype=np.float64)
h_phi_k = h_phi_1 / hh - 1.0
factorial_i = 1
R_rows = []
b_vals = []
for j in range(1, effective_order + 1):
R_rows.append(rks_arr ** (j - 1))
b_vals.append(float(h_phi_k * factorial_i / B_h))
factorial_i *= j + 1
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
R = np.stack(R_rows)
b = np.array(b_vals)
rhos_c = np.linalg.solve(R, b).tolist()
# Apply correction
corr_res = mx.zeros_like(D1_t)
for k_idx, d1 in enumerate(D1s):
corr_res = corr_res + rhos_c[k_idx] * d1
x_t = x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t)
return x_t
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""UniPC step: correct current, then predict next."""
i = self._step_index
# Convert velocity -> x0
x0 = self._convert_output(model_output, sample)
# 1. Corrector: refine current sample if we have history
use_corrector = (
self._use_corrector
and i > 0
and (i - 1) not in self.disable_corrector
and self._last_sample is not None
)
if use_corrector:
sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order)
# 2. Shift model output history
for k in range(self.solver_order - 1):
self._model_outputs[k] = self._model_outputs[k + 1]
self._model_outputs[-1] = x0
# 3. Determine prediction order
if self.lower_order_final:
this_order = min(self.solver_order, self._num_steps - i)
else:
this_order = self.solver_order
self._this_order = min(this_order, self._lower_order_nums + 1)
# 4. Predict next sample
self._last_sample = sample
x_next = self._uni_p_bh2(x0, sample, self._this_order)
if self._lower_order_nums < self.solver_order:
self._lower_order_nums += 1
self._step_index += 1
return x_next
def reset(self):
self._step_index = 0
self._lower_order_nums = 0
self._model_outputs = [None] * self.solver_order
self._last_sample = None
self._this_order = 1
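
All three schedulers share the shifted sigma schedule and the convention `x0 = sample - sigma * v`. The sketch below re-derives the schedule in pure Python and runs the Euler update on scalars to show that an exact, constant velocity integrates back to the clean sample. It is an illustration under stated assumptions: it uses float64 rather than the float32 arrays used above, and `compute_sigmas`, `euler_sample`, and the toy values `x0 = 2.0`, `noise = -1.0` are hypothetical.

```python
def compute_sigmas(
    num_steps: int, shift: float = 1.0, num_train_timesteps: int = 1000
) -> list[float]:
    """Shifted sigma schedule with a trailing 0.0 (num_steps + 1 values)."""
    sigma_max = 1.0 - 1.0 / num_train_timesteps  # (N-1)/N
    # Equivalent to linspace(sigma_max, 0, num_steps + 1) without the endpoint
    raw = [sigma_max * (1.0 - k / num_steps) for k in range(num_steps)]
    shifted = [shift * s / (1.0 + (shift - 1.0) * s) for s in raw]
    return shifted + [0.0]


def euler_sample(x: float, velocity: float, sigmas: list[float]) -> float:
    """Apply x_next = x + (sigma_next - sigma_cur) * v over the whole schedule."""
    for cur, nxt in zip(sigmas[:-1], sigmas[1:]):
        x = x + (nxt - cur) * velocity
    return x


sigmas = compute_sigmas(num_steps=4, shift=5.0)
timesteps = [int(s * 1000) for s in sigmas[:-1]]  # integer timesteps fed to the model

x0, noise = 2.0, -1.0                      # toy scalar "clean sample" and "noise"
x_start = x0 + sigmas[0] * (noise - x0)    # point on the straight path at sigma_0
x_end = euler_sample(x_start, noise - x0, sigmas)
print(sigmas, timesteps, x_end)
```

Because the sigma differences telescope to `-sigmas[0]`, `x_end` recovers `x0` up to rounding. With a real model the velocity changes at every step, which is where the higher-order DPM++ and UniPC updates help.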

@@ -0,0 +1,240 @@
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
import math
import mlx.core as mx
import mlx.nn as nn
class T5LayerNorm(nn.Module):
    """RMS-based layer normalization (T5 style)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        return mx.fast.rms_norm(x, self.weight, self.eps)
class T5RelativeEmbedding(nn.Module):
"""T5-style relative position bias with bucketing."""
def __init__(
self,
num_buckets: int,
num_heads: int,
bidirectional: bool = True,
max_dist: int = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
self.embedding = nn.Embedding(num_buckets, num_heads)
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array:
positions_k = mx.arange(lk)[None, :] # [1, lk]
positions_q = mx.arange(lq)[:, None] # [lq, 1]
rel_pos = positions_k - positions_q # [lq, lk]
buckets = self._relative_position_bucket(rel_pos)
embeds = self.embedding(buckets) # [lq, lk, num_heads]
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
return embeds
class T5Attention(nn.Module):
"""T5-style multi-head attention (no scaling)."""
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert dim_attn % num_heads == 0
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
def __call__(
self,
x: mx.array,
context: mx.array | None = None,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
context = x if context is None else context
b, n, c = x.shape[0], self.num_heads, self.head_dim
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
v = self.v(context).reshape(b, -1, n, c)
# T5 uses no scaling — compute attention manually with float32 softmax
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
# Using SDPA with bfloat16 inputs causes precision loss in softmax
# since unscaled logits can be very large (no 1/sqrt(d) division).
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# QK^T (no scaling) — compute in float32 for precision
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
# Add position bias
if pos_bias is not None:
attn = attn + pos_bias.astype(mx.float32)
# Apply attention mask (use dtype min like official, not -1e9)
if mask is not None:
if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
elif mask.ndim == 3:
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
attn = attn + additive_mask
# Softmax in float32 (matches official), then cast back
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
# Attention @ V
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
return self.o(out)
class T5FeedForward(nn.Module):
    """Gated feed-forward: fc2(fc1(x) * GELU(gate_proj(x)))."""

    def __init__(self, dim: int, dim_ffn: int):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn
        self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
        self.gate_act = nn.GELU(approx="tanh")
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
class T5SelfAttentionBlock(nn.Module):
"""T5 encoder block: self-attention + FFN."""
def __init__(
self,
dim: int,
dim_attn: int,
dim_ffn: int,
num_heads: int,
num_buckets: int,
shared_pos: bool = True,
):
super().__init__()
self.shared_pos = shared_pos
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = x + self.ffn(self.norm2(x))
return x
class T5Encoder(nn.Module):
    """T5 Encoder (UMT5-XXL configuration)."""

    def __init__(
        self,
        vocab_size: int = 256384,
        dim: int = 4096,
        dim_attn: int = 4096,
        dim_ffn: int = 10240,
        num_heads: int = 64,
        num_layers: int = 24,
        num_buckets: int = 32,
        shared_pos: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.token_embedding = nn.Embedding(vocab_size, dim)
        # With shared_pos=False (the default) each block owns its position bias
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
            if shared_pos
            else None
        )
        self.blocks = [
            T5SelfAttentionBlock(
                dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
            )
            for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
        """Encode token IDs into hidden states.

        Args:
            ids: Token IDs [B, L]
            mask: Attention mask [B, L]

        Returns:
            Hidden states [B, L, dim]
        """
        x = self.token_embedding(ids)
        # Compare against None explicitly rather than relying on module truthiness
        e = (
            self.pos_embedding(x.shape[1], x.shape[1])
            if self.pos_embedding is not None
            else None
        )
        for block in self.blocks:
            x = block(x, mask=mask, pos_bias=e)
        x = self.norm(x)
        return x
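
`T5RelativeEmbedding._relative_position_bucket` maps a signed key-minus-query offset to a bucket index: small offsets get their own bucket, larger ones share logarithmically spaced buckets, and in bidirectional mode positive offsets use the upper half of the table. The scalar sketch below mirrors that logic for one offset at a time. The function name is hypothetical, and the defaults `num_buckets=32`, `max_dist=128` are taken from the encoder defaults above.

```python
import math


def relative_position_bucket(
    rel_pos: int, num_buckets: int = 32, max_dist: int = 128, bidirectional: bool = True
) -> int:
    """Scalar version of the T5 relative-position bucketing rule."""
    if bidirectional:
        num_buckets //= 2
        # Positive offsets (key after query) live in the upper half of the table
        bucket = num_buckets if rel_pos > 0 else 0
        rel_pos = abs(rel_pos)
    else:
        bucket = 0
        rel_pos = max(-rel_pos, 0)
    max_exact = num_buckets // 2
    if rel_pos < max_exact:
        # Small offsets: one bucket per distance
        return bucket + rel_pos
    # Large offsets: logarithmic spacing, clamped to the last bucket
    large = max_exact + int(
        math.log(rel_pos / max_exact)
        / math.log(max_dist / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(large, num_buckets - 1)


for offset in (-200, -20, -3, 0, 3, 20, 200):
    print(offset, relative_position_bucket(offset))
```

Unlike the vectorized method, the scalar version branches before taking the logarithm, so it never evaluates `log(0)`. The array code computes both branches and selects with `mx.where`.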

@@ -0,0 +1,281 @@
"""Wan-specific tiled VAE decoding.
Re-exports all tiling utilities from the LTX VAE tiling module and provides
a Wan-specific ``decode_with_tiling`` that adds a ``causal_temporal`` flag, so
that non-causal temporal decoders are supported (e.g. Wan2.1, where T latent
frames map to T*scale output frames rather than LTX's 1+(T-1)*scale).
"""
# TODO: Consolidate this function with
# mlx_video.models.ltx.video_vae.tiling.decode_with_tiling once the
# causal_temporal generalisation is accepted upstream.
from typing import Callable, Optional
import mlx.core as mx
from mlx_video.models.ltx.video_vae.tiling import (
SpatialTilingConfig,
TemporalTilingConfig,
TilingConfig,
map_spatial_slice,
map_temporal_slice,
split_in_spatial,
split_in_temporal,
)
__all__ = [
"SpatialTilingConfig",
"TemporalTilingConfig",
"TilingConfig",
"decode_with_tiling",
"map_spatial_slice",
"map_temporal_slice",
"split_in_spatial",
"split_in_temporal",
]
def decode_with_tiling(
decoder_fn,
latents: mx.array,
tiling_config: TilingConfig,
spatial_scale: int = 32,
temporal_scale: int = 8,
causal: bool = False,
causal_temporal: bool = True,
timestep: Optional[mx.array] = None,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
Args:
decoder_fn: Decoder function to call for each tile.
latents: Input latents of shape (B, C, F, H, W).
tiling_config: Tiling configuration.
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify; 8 for the Wan VAE).
temporal_scale: Temporal scale factor (8 for LTX VAE; 4 for the Wan VAE).
causal: Whether to use causal convolutions.
causal_temporal: Whether the decoder uses causal temporal mapping where
T input frames produce 1+(T-1)*scale output frames. When False, uses
simple scaling where T frames produce T*scale output frames.
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
timestep: Optional timestep for conditioning.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
start_idx: Starting frame index in the full video.
Returns:
Decoded video.
"""
import gc
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
# Get tile size and overlap in latent space
if tiling_config.spatial_config is not None:
s_cfg = tiling_config.spatial_config
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
else:
spatial_tile_size = max(h_latent, w_latent)
spatial_overlap = 0
if tiling_config.temporal_config is not None:
t_cfg = tiling_config.temporal_config
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
else:
temporal_tile_size = f_latent
temporal_overlap = 0
# Compute intervals for each dimension
if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
num_t_tiles = len(temporal_intervals.starts)
num_h_tiles = len(height_intervals.starts)
num_w_tiles = len(width_intervals.starts)
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles # noqa: F841
# Initialize output and weight accumulator
# Use float32 for accumulation to avoid precision issues
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
mx.eval(output, weights)
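# Reset the streaming counter in case a previous call raised before cleanup
decode_with_tiling._emitted_frames = 0
tile_idx = 0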
for t_idx in range(num_t_tiles):
t_start = temporal_intervals.starts[t_idx]
t_end = temporal_intervals.ends[t_idx]
t_left = temporal_intervals.left_ramps[t_idx]
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
h_end = height_intervals.ends[h_idx]
h_left = height_intervals.left_ramps[h_idx]
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
w_end = width_intervals.ends[w_idx]
w_left = width_intervals.left_ramps[w_idx]
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
mx.eval(tile_output)
# Clear tile_latents reference
del tile_latents
# Get actual decoded dimensions
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
expected_t = out_t_slice.stop - out_t_slice.start
expected_h = out_h_slice.stop - out_h_slice.start
expected_w = out_w_slice.stop - out_w_slice.start
# Handle potential size mismatches (use minimum)
actual_t = min(decoded_t, expected_t)
actual_h = min(decoded_h, expected_h)
actual_w = min(decoded_w, expected_w)
# Build blend mask
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
# Clear full tile_output
del tile_output
# Compute output coordinates
t_out_start = out_t_slice.start
t_out_end = t_out_start + actual_t
h_out_start = out_h_slice.start
h_out_end = h_out_start + actual_h
w_out_start = out_w_slice.start
w_out_end = w_out_start + actual_w
# Weighted accumulation
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
)
# Force evaluation to free memory
mx.eval(output, weights)
# Clean up tile-specific arrays
del tile_output_slice, weighted_tile, blend_mask
del t_mask_slice, h_mask_slice, w_mask_slice
tile_idx += 1
# Periodic garbage collection and cache clearing
if tile_idx % 4 == 0:
gc.collect()
try:
mx.clear_cache()
except Exception:
pass # May not be available on all platforms
# After completing all spatial tiles for this temporal tile,
# check if any frames are now finalized (no future tiles will contribute)
if on_frames_ready is not None and num_t_tiles > 1:
# Determine the finalized frame boundary
# Frames before the start of the next tile's output region are finalized
if t_idx < num_t_tiles - 1:
# Next tile starts at temporal_intervals.starts[t_idx + 1]
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
else:
next_tile_start_out = next_tile_start_latent * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
if next_tile_start_out > emitted:
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
on_frames_ready(finalized_output, emitted)
decode_with_tiling._emitted_frames = next_tile_start_out
del finalized_output, finalized_weights
gc.collect()
# Normalize by weights
weights = mx.maximum(weights, 1e-8)
output = output / weights
mx.eval(output)
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
on_frames_ready(remaining_output, emitted)
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
del decode_with_tiling._emitted_frames
# Clean up weights
del weights
gc.collect()
# Convert back to original dtype if needed
return output.astype(latents.dtype)
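
The accumulate-then-normalize blending in `decode_with_tiling` can be illustrated in one dimension. This is a minimal sketch in plain Python, not the module's code: `ramp_mask` and `blend_tiles_1d` are hypothetical helpers, and the linear ramp shape is an assumption (the real masks come from `map_spatial_slice` / `map_temporal_slice`, whose ramp shape is not shown in this file).

```python
def ramp_mask(length: int, left_ramp: int, right_ramp: int) -> list[float]:
    """Trapezoidal weights: linear ramp up, flat top at 1.0, linear ramp down."""
    mask = [1.0] * length
    for i in range(left_ramp):
        mask[i] = (i + 1) / (left_ramp + 1)
    for i in range(right_ramp):
        mask[length - 1 - i] = (i + 1) / (right_ramp + 1)
    return mask


def blend_tiles_1d(tiles, total_length: int) -> list[float]:
    """Blend overlapping tiles; each tile is (start, values, left_ramp, right_ramp)."""
    output = [0.0] * total_length
    weights = [0.0] * total_length
    for start, values, left, right in tiles:
        mask = ramp_mask(len(values), left, right)
        for i, (value, weight) in enumerate(zip(values, mask)):
            output[start + i] += value * weight   # weighted accumulation
            weights[start + i] += weight          # track total weight per sample
    # Normalize by accumulated weight, guarding against zero as the code above does
    return [o / max(w, 1e-8) for o, w in zip(output, weights)]


# Two tiles that overlap on indices 4 and 5 and decode the same constant signal.
tiles = [
    (0, [2.0] * 6, 0, 2),  # covers 0..5, ramps down over its last two samples
    (4, [2.0] * 6, 2, 0),  # covers 4..9, ramps up over its first two samples
]
blended = blend_tiles_1d(tiles, 10)
```

Because the two ramps sum to 1.0 inside the overlap, a signal that both tiles agree on is reproduced without a seam; where tiles disagree, the result is a cross-fade between them.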


@@ -0,0 +1,97 @@
import mlx.core as mx
import mlx.nn as nn
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
class WanAttentionBlock(nn.Module):
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
cross_attn_norm: bool = False,
eps: float = 1e-6,
):
super().__init__()
# Self-attention
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
# Cross-attention (with optional norm on context)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else None
)
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
# Feed-forward
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
def __call__(
self,
x: mx.array,
e: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
context: mx.array,
context_lens: list | None = None,
cross_kv_cache: tuple | None = None,
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array:
# Modulation: compute in float32 for precision, matching the reference
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
# By keeping modulation in float32, type promotion ensures the residual
# stream stays float32 through every layer (gate * output + x → float32).
mod = self.modulation + e # float32
e0, e1, e2, e3, e4, e5 = (
mod[:, :, 0, :], # shift for self-attn
mod[:, :, 1, :], # scale for self-attn
mod[:, :, 2, :], # gate for self-attn
mod[:, :, 3, :], # shift for ffn
mod[:, :, 4, :], # scale for ffn
mod[:, :, 5, :], # gate for ffn
)
# Self-attention with modulation (hidden state stays in w_dtype)
x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
x = x + y * e2
# Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation
x_mod = self.norm2(x) * (1 + e4) + e3
y = self.ffn(x_mod)
x = x + y * e5
return x
class WanFFN(nn.Module):
"""Two-layer feed-forward network (Linear → GELU(tanh) → Linear), with no gating branch."""
def __init__(self, dim: int, ffn_dim: int):
super().__init__()
self.fc1 = nn.Linear(dim, ffn_dim)
self.act = nn.GELU(approx="tanh")
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(_linear_dtype(self.fc1))
return self.fc2(self.act(self.fc1(x_w)))
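
The shift/scale/gate pattern that `WanAttentionBlock` applies around self-attention and the FFN can be sketched on plain Python lists. This is an illustration under assumptions, not the block itself: `modulated_residual` is a hypothetical helper, the normalization is taken as already applied, and the sub-layer is passed in as an arbitrary callable.

```python
def modulated_residual(x, normed, shift, scale, gate, sublayer):
    """Return x + gate * sublayer(normed * (1 + scale) + shift), element-wise."""
    modulated = [n * (1.0 + s) + b for n, s, b in zip(normed, scale, shift)]
    y = sublayer(modulated)
    return [xi + yi * g for xi, yi, g in zip(x, y, gate)]


def identity(values):
    """Stand-in for self-attention or the FFN."""
    return list(values)


x = [1.0, 2.0]
normed = [0.5, -0.5]  # pretend output of the norm layer

# A zero gate switches the sub-layer off and leaves the residual stream untouched.
closed = modulated_residual(x, normed, [0.1, 0.1], [1.0, 1.0], [0.0, 0.0], identity)

# A unit gate adds the full modulated sub-layer output.
opened = modulated_residual(x, normed, [0.1, 0.1], [1.0, 1.0], [1.0, 1.0], identity)
```

In the block above, `e0, e1, e2` play the roles of `shift`, `scale`, and `gate` for self-attention, and `e3, e4, e5` do the same for the FFN.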

589
mlx_video/models/wan/vae.py Normal file

@@ -0,0 +1,589 @@
"""3D VAE (encoder and decoder) for Wan2.1/2.2 (compression 4×8×8).
Module structure mirrors original PyTorch checkpoint key hierarchy
so weights load directly without key sanitization.
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization statistics for z_dim=16
VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
]
VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
]
class CausalConv3d(nn.Module):
"""3D convolution with causal temporal padding."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple,
stride: int | tuple = 1,
padding: int | tuple = 0,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
self.kernel_size = kernel_size
self.stride = stride
# Causal padding: match reference formula dilation*(k-1) + (1-stride)
# With dilation=1: k-stride (pads left only, no future context)
self._causal_pad_t = kernel_size[0] - stride[0]
self._pad_h = padding[1]
self._pad_w = padding[2]
# MLX Conv3d: weight shape [O, D, H, W, I]
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array, cache_x: mx.array | None = None) -> mx.array:
"""x: [B, C, T, H, W] (channel-first)"""
b, c, t, h, w = x.shape
causal_pad = self._causal_pad_t
if cache_x is not None and causal_pad > 0:
x = mx.concatenate([cache_x, x], axis=2)
causal_pad = max(0, causal_pad - cache_x.shape[2])
if causal_pad > 0:
pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype)
x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
out = self._conv3d(x)
return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W']
def _conv3d(self, x: mx.array) -> mx.array:
"""3D conv via sliding window + 2D conv per time step.
x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
"""
b, t, h, w, c_in = x.shape
kt, kh, kw = self.kernel_size
st, sh, sw = self.stride
t_out = (t - kt) // st + 1
# Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
self.weight.shape[0], kh, kw, kt * c_in
)
outputs = []
for t_i in range(t_out):
t_start = t_i * st
window = x[:, t_start : t_start + kt]
window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
outputs.append(out_2d)
return mx.stack(outputs, axis=1)
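
The left-only temporal padding of `kernel_size - stride` used by `CausalConv3d` can be checked with a one-dimensional analogue. This is a minimal sketch in plain Python: `causal_conv1d` is a hypothetical helper, it ignores channels, bias, and the cache path, and it assumes `len(kernel) >= stride`.

```python
def causal_conv1d(x, kernel, stride=1):
    """1D convolution over time with zero padding on the left only."""
    k = len(kernel)
    pad = k - stride                      # same rule as _causal_pad_t above
    padded = [0.0] * pad + list(x)
    t_out = (len(padded) - k) // stride + 1
    return [
        sum(kernel[j] * padded[i * stride + j] for j in range(k))
        for i in range(t_out)
    ]


signal = [1.0, 2.0, 3.0, 4.0, 5.0]
box = [1.0, 1.0, 1.0]

same_length = causal_conv1d(signal, box)            # stride 1 keeps T frames
halved = causal_conv1d(signal, box, stride=2)       # stride 2 downsamples in time

# Causality: changing the last input frame only changes the last output frame.
perturbed = causal_conv1d(signal[:-1] + [100.0], box)
```

With stride 1 the output has as many time steps as the input, and output step `t` only sees inputs at steps `t-2..t`, so no future frame leaks into the result.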
class RMS_norm(nn.Module):
"""Channel-first L2 normalization matching original Wan VAE.
Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
super().__init__()
self.channel_first = channel_first
self.scale = dim**0.5
if channel_first:
broadcastable = (1, 1) if images else (1, 1, 1)
self.gamma = mx.ones((dim, *broadcastable))
else:
self.gamma = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
norm_dim = 1 if self.channel_first else -1
# L2 normalize along channel dim (matches F.normalize)
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
return (x / norm) * self.scale * self.gamma
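
The `RMS_norm` docstring states that L2 normalization scaled by `sqrt(dim)` is equivalent to RMS normalization. A short numeric check in plain Python, with `gamma` left at 1 and both function names chosen for illustration only:

```python
import math


def l2_normalize_scaled(x):
    """x / ||x||_2 * sqrt(dim), the form used by RMS_norm above."""
    norm = math.sqrt(max(sum(v * v for v in x), 1e-12))
    return [v / norm * math.sqrt(len(x)) for v in x]


def rms_normalize(x):
    """x / sqrt(mean(x^2)), the textbook RMS norm without epsilon."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]


channels = [3.0, 4.0, 0.0, -1.0]
via_l2 = l2_normalize_scaled(channels)
via_rms = rms_normalize(channels)
```

The two agree because `||x||_2 = sqrt(dim) * rms(x)`, so dividing by the L2 norm and multiplying by `sqrt(dim)` is the same as dividing by the RMS. They differ only in how a near-zero input is clamped.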
class ResidualBlock(nn.Module):
"""Residual block with causal 3D convolutions.
Uses `residual` list with None gaps to match original PyTorch
nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
[4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
"""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.residual = [
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
h = x if self.shortcut is None else self.shortcut(x)
if feat_cache is not None:
# First conv: norm -> silu -> [cache] -> conv
x = nn.silu(self.residual[0](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.residual[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
# Second conv: norm -> silu -> [cache] -> conv
x = nn.silu(self.residual[3](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.residual[6](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = nn.silu(self.residual[0](x))
x = self.residual[2](x)
x = nn.silu(self.residual[3](x))
x = self.residual[6](x)
return x + h
class AttentionBlock(nn.Module):
"""Single-head spatial self-attention."""
def __init__(self, dim: int):
super().__init__()
self.norm = RMS_norm(dim, images=True)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W]"""
identity = x
b, c, t, h, w = x.shape
# [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.norm(x)
x = x.transpose(0, 2, 3, 1) # [BT, H, W, C]
qkv = self.to_qkv(x) # [BT, H, W, 3C]
qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q[:, None, :, :] # [BT, 1, HW, C]
k = k[:, None, :, :]
v = v[:, None, :, :]
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C]
out = self.proj(out) # [BT, H, W, C]
out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W]
return out + identity
class Resample(nn.Module):
"""Resample block matching original Wan VAE structure.
Supports both upsampling (decoder) and downsampling (encoder).
Uses list-based param storage to match original nn.Sequential key hierarchy.
"""
def __init__(self, dim: int, mode: str):
super().__init__()
assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d")
self.mode = mode
self.dim = dim
if mode.startswith("upsample"):
# resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
else:
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
if mode == "downsample3d":
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, C, T, H, W]"""
b, c, t, h, w = x.shape
if self.mode == "upsample3d":
# Temporal upsample via learned conv
x_t = self.time_conv(x) # [B, 2C, T, H, W]
x_t = x_t.reshape(b, 2, c, t, h, w)
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
t = t * 2
if self.mode.startswith("upsample"):
# Per-frame spatial upsample: nearest 2x + Conv2d
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.repeat(x, 2, axis=1)
x = mx.repeat(x, 2, axis=2)
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
c_out = x.shape[-1]
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
else:
# Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2)
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1)
x = self.resample[1](x) # Conv2d stride=2
c_out = x.shape[-1]
h_out, w_out = x.shape[1], x.shape[2]
x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
# First chunk: save x, skip time_conv
feat_cache[idx] = x
feat_idx[0] += 1
else:
# Subsequent chunks: use cached frame as temporal context
cache_x = x[:, :, -1:]
x = self.time_conv(
x, cache_x=feat_cache[idx][:, :, -1:])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.time_conv(x)
return x
class Decoder3d(nn.Module):
"""3D VAE Decoder matching Wan2.1 architecture.
Uses flat `middle` and `upsamples` lists to match original
PyTorch nn.Sequential weight key hierarchy.
"""
def __init__(
self,
dim: int = 96,
z_dim: int = 16,
dim_mult: list | None = None,
num_res_blocks: int = 2,
temporal_upsample: list | None = None,
):
super().__init__()
if dim_mult is None:
dim_mult = [1, 2, 4, 4]
if temporal_upsample is None:
temporal_upsample = [True, True, False]
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# Middle: [ResBlock, AttentionBlock, ResBlock]
self.middle = [
ResidualBlock(dims[0], dims[0]),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0]),
]
# Flat upsample list matching original nn.Sequential indexing
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
if i in (1, 2, 3):
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = upsamples
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
]
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
x = self.conv1(x)
for layer in self.middle:
x = layer(x)
for layer in self.upsamples:
x = layer(x)
x = nn.silu(self.head[0](x))
x = self.head[2](x)
return x
class Encoder3d(nn.Module):
"""3D VAE Encoder matching Wan2.1 architecture.
Mirror of Decoder3d with downsampling instead of upsampling.
Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy.
"""
def __init__(
self,
dim: int = 96,
z_dim: int = 16,
dim_mult: list | None = None,
num_res_blocks: int = 2,
temporal_downsample: list | None = None,
):
super().__init__()
if dim_mult is None:
dim_mult = [1, 2, 4, 4]
if temporal_downsample is None:
temporal_downsample = [False, True, True]
dims = [dim * u for u in [1] + dim_mult]
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# Flat downsample list matching original nn.Sequential indexing
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = downsamples
# Middle: [ResBlock, AttentionBlock, ResBlock]
self.middle = [
ResidualBlock(dims[-1], dims[-1]),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1]),
]
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False),
None, # SiLU
CausalConv3d(dims[-1], z_dim, 3, padding=1),
]
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]"""
if feat_cache is not None:
# conv1 with caching
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.downsamples:
if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
for layer in self.middle:
if feat_cache is not None and isinstance(layer, ResidualBlock):
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
if feat_cache is not None:
# Head: norm -> silu -> [cache] -> conv
x = nn.silu(self.head[0](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.head[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = nn.silu(self.head[0](x))
x = self.head[2](x)
return x
class WanVAE(nn.Module):
"""Wan2.1 VAE wrapper with per-channel normalization.
Supports both encode (for I2V) and decode (for all models).
"""
def __init__(self, z_dim: int = 16, encoder: bool = False):
super().__init__()
self.z_dim = z_dim
self.mean = mx.array(VAE_MEAN)
self.std = mx.array(VAE_STD)
self.inv_std = 1.0 / self.std
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
if encoder:
self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
def encode(self, x: mx.array) -> mx.array:
"""Encode video to normalized latent using chunked encoding.
Uses chunked encoding with temporal caching to match reference behavior.
First frame encoded alone, then 4-frame chunks with cached context.
Args:
x: Video [B, 3, T, H, W] in [-1, 1]
Returns:
Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
"""
# Count cacheable CausalConv3d slots in encoder
num_slots = self._count_encoder_cache_slots()
feat_cache = [None] * num_slots
t = x.shape[2]
num_chunks = 1 + (t - 1) // 4
out = None
for i in range(num_chunks):
feat_idx = [0]
if i == 0:
chunk = x[:, :, :1]
else:
chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
else:
out = mx.concatenate([out, chunk_out], axis=2)
mu, _ = mx.split(self.conv1(out), 2, axis=1)
# Normalize: (mu - mean) * inv_std
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
return (mu - mean) * inv_std
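
The chunk schedule in `encode` (first frame alone, then 4-frame chunks) can be written out as frame ranges. This is a sketch in plain Python that mirrors the slicing arithmetic above; `encode_chunk_ranges` is a hypothetical helper and assumes at least one input frame.

```python
def encode_chunk_ranges(num_frames: int) -> list[tuple[int, int]]:
    """Return the [start, stop) frame range of each chunk passed to the encoder."""
    num_chunks = 1 + (num_frames - 1) // 4
    ranges = [(0, 1)]                                  # first frame on its own
    for i in range(1, num_chunks):
        ranges.append((1 + 4 * (i - 1), 1 + 4 * i))    # then 4 frames at a time
    return ranges


nine = encode_chunk_ranges(9)      # 1 + 4 + 4 frames
eleven = encode_chunk_ranges(11)   # frames 9 and 10 fall outside every chunk
```

Frame counts of the form `4k + 1` are covered exactly. With this arithmetic any other count leaves its trailing frames outside the last chunk, so callers may want to trim or pad the input to `4k + 1` frames first.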
def _count_encoder_cache_slots(self) -> int:
"""Count CausalConv3d that participate in chunked encoding cache."""
count = 1 # encoder.conv1
for layer in self.encoder.downsamples:
if isinstance(layer, ResidualBlock):
count += 2 # two convs in residual path
elif isinstance(layer, Resample) and layer.mode == "downsample3d":
count += 1 # time_conv
for layer in self.encoder.middle:
if isinstance(layer, ResidualBlock):
count += 2
count += 1 # encoder.head CausalConv3d
return count
def decode(self, z: mx.array) -> mx.array:
"""Decode latent to video.
Args:
z: Normalized latent [B, z_dim, T, H, W]
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
z = z / inv_std + mean
x = self.conv2(z)
out = self.decoder(x)
return mx.clip(out, -1, 1)
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
"""Decode latent to video using tiling to reduce memory usage.
Splits the latent tensor into overlapping spatial/temporal tiles,
decodes each tile independently, and blends them with trapezoidal
masks. Reuses the LTX-2 tiling infrastructure.
Args:
z: Normalized latent [B, z_dim, T, H, W]
tiling_config: Optional TilingConfig. If None, uses default.
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
_, _, f, h, w = z.shape
needs_tiling = False
if tiling_config.spatial_config is not None:
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
if h > s_tile or w > s_tile:
needs_tiling = True
if tiling_config.temporal_config is not None:
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
if f > t_tile:
needs_tiling = True
if not needs_tiling:
return self.decode(z)
# Denormalize once (small tensor), then tile the denormalized latents
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
z_denorm = z / inv_std + mean
def tile_decode(tile_latents, **kwargs):
x = self.conv2(tile_latents)
out = self.decoder(x)
return mx.clip(out, -1, 1)
return decode_with_tiling(
decoder_fn=tile_decode,
latents=z_denorm,
tiling_config=tiling_config,
spatial_scale=8, # three 2× spatial upsamples = 8×
temporal_scale=4, # two 2× temporal upsamples = 4×
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
)
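
The output size that `decode_with_tiling` allocates follows directly from the scale factors and the `causal_temporal` flag. The sketch below restates that arithmetic in plain Python; `decoded_shape` is a hypothetical helper, and the latent shape in the example is illustrative rather than taken from the source.

```python
def decoded_shape(latent_shape, spatial_scale=8, temporal_scale=4, causal_temporal=False):
    """Map a latent shape (B, C, F, H, W) to the decoded video shape (B, 3, F', H', W')."""
    b, _, f, h, w = latent_shape
    if causal_temporal:
        out_f = 1 + (f - 1) * temporal_scale   # LTX-style mapping
    else:
        out_f = f * temporal_scale             # mapping used by decode_tiled above
    return (b, 3, out_f, h * spatial_scale, w * spatial_scale)


latents = (1, 16, 21, 60, 104)
non_causal = decoded_shape(latents)
causal = decoded_shape(latents, causal_temporal=True)
```

The two mappings differ by `temporal_scale - 1` frames, which is why the tiling code selects different interval and slice helpers depending on the flag.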

File diff suppressed because it is too large