Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
349
mlx_video/models/wan_2/README.md
Normal file
349
mlx_video/models/wan_2/README.md
Normal file
@@ -0,0 +1,349 @@
|
||||
|
||||
## Wan2.1 / Wan2.2
|
||||
|
||||
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
|
||||
|
||||
They share the same model architecture — the difference is in the inference pipeline:
|
||||
|
||||
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|
||||
|---|--------|--------|--------|--------|
|
||||
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
|
||||
| **Pipeline** | Single model | Dual model | Dual model | Single model |
|
||||
| **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
|
||||
| **Resolution** | 480P (1.3B), 720P (14B) | 720P | 720P | 720P |
|
||||
| **Steps** | 50 | 40 | 40 | 40 |
|
||||
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
|
||||
| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
|
||||
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
|
||||
|
||||
### Step 1: Download Weights
|
||||
|
||||
Download the original PyTorch checkpoints from HuggingFace using the `huggingface-cli` tool (install with `pip install huggingface_hub`):
|
||||
|
||||
**Wan2.1**
|
||||
```bash
|
||||
# Text-to-Video 1.3B (fast, fits in ~4 GB)
|
||||
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir ./Wan2.1-T2V-1.3B
|
||||
|
||||
# Text-to-Video 14B
|
||||
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
|
||||
```
|
||||
|
||||
**Wan2.2**
|
||||
```bash
|
||||
# Text-to-Video 14B
|
||||
huggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B
|
||||
|
||||
# Image-to-Video 14B
|
||||
huggingface-cli download Wan-AI/Wan2.2-I2V-A14B --local-dir ./Wan2.2-I2V-A14B
|
||||
|
||||
# Text+Image-to-Video 5B (uses a different VAE — z_dim=48)
|
||||
huggingface-cli download Wan-AI/Wan2.2-TI2V-5B --local-dir ./Wan2.2-TI2V-5B
|
||||
```
|
||||
|
||||
Each downloaded directory will have this structure:
|
||||
|
||||
```
|
||||
Wan2.1-T2V-*/
|
||||
├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder
|
||||
├── Wan2.1_VAE.pth # 3D VAE
|
||||
└── diffusion_pytorch_model*.safetensors # transformer (single)
|
||||
|
||||
Wan2.2-T2V-A14B/ or Wan2.2-I2V-A14B/
|
||||
├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
├── Wan2.1_VAE.pth
|
||||
├── low_noise_model/ # dual-model low-noise transformer
|
||||
└── high_noise_model/ # dual-model high-noise transformer
|
||||
|
||||
Wan2.2-TI2V-5B/
|
||||
├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
├── Wan2.2_VAE.pth # different VAE (z_dim=48)
|
||||
└── diffusion_pytorch_model*.safetensors # transformer (single)
|
||||
```
|
||||
|
||||
> **Wan2.2 I2V-14B** shares the same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`).
|
||||
|
||||
### Step 2: Convert to MLX Format
|
||||
|
||||
The conversion script auto-detects the model version from the directory structure (presence of `low_noise_model/` → Wan2.2 dual model) and the model type from `config.json` (I2V vs T2V).
|
||||
|
||||
#### Wan2.1 T2V 1.3B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-1.3B \
|
||||
--output-dir ./Wan2.1-T2V-1.3B-MLX
|
||||
```
|
||||
|
||||
#### Wan2.1 T2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-14B \
|
||||
--output-dir ./Wan2.1-T2V-14B-MLX
|
||||
```
|
||||
|
||||
#### Wan2.2 T2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-T2V-A14B \
|
||||
--output-dir ./Wan2.2-T2V-A14B-MLX
|
||||
```
|
||||
|
||||
#### Wan2.2 I2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-I2V-A14B \
|
||||
--output-dir ./Wan2.2-I2V-A14B-MLX
|
||||
```
|
||||
|
||||
The I2V model is auto-detected from `config.json`; the output will include a `vae_encoder.safetensors` used to encode the conditioning image.
|
||||
|
||||
#### Wan2.2 TI2V 5B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-TI2V-5B \
|
||||
--output-dir ./Wan2.2-TI2V-5B-MLX
|
||||
```
|
||||
|
||||
The TI2V model uses a different VAE (`z_dim=48`, `vae_stride=(4,16,16)`) and is auto-detected during conversion.
|
||||
|
||||
---
|
||||
|
||||
You can also pass `--model-version 2.1` or `--model-version 2.2` to force the version instead of relying on auto-detection.
|
||||
|
||||
#### Conversion Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
|
||||
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
|
||||
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
|
||||
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
|
||||
| `--quantize` | off | Quantize transformer weights for reduced memory |
|
||||
| `--bits` | `4` | Quantization bits: `4` or `8` |
|
||||
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
|
||||
|
||||
The converter produces:
|
||||
```
|
||||
wan_mlx/
|
||||
├── config.json # Model configuration
|
||||
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
|
||||
├── vae.safetensors # 3D VAE decoder
|
||||
├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only)
|
||||
├── model.safetensors # (Wan2.1) Single transformer
|
||||
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
|
||||
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
|
||||
```
|
||||
|
||||
### Step 3: Generate Video
|
||||
|
||||
#### Wan2.1 T2V 1.3B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.gemer \
|
||||
--model-dir ./Wan2.1-T2V-1.3B-MLX \
|
||||
--prompt "A cat playing piano in a cozy living room, cinematic lighting" \
|
||||
--width 832 --height 480 --num-frames 81 \
|
||||
--steps 50 --guide-scale 5.0 \
|
||||
--seed 42 \
|
||||
--output-path wan21_1b.mp4
|
||||
```
|
||||
|
||||
#### Wan2.1 T2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.gemer \
|
||||
--model-dir ./Wan2.1-T2V-14B-MLX \
|
||||
--prompt "A woman walks through a misty forest at dawn, slow motion, cinematic" \
|
||||
--width 1280 --height 704 --num-frames 81 \
|
||||
--steps 50 --guide-scale 5.0 \
|
||||
--seed 42 \
|
||||
--output-path wan21_14b.mp4
|
||||
```
|
||||
|
||||
> **Tip**: If the first few frames look washed out or have color artifacts, add `--trim-first-frames 1` to generate 4 extra frames at the start and discard them. With the `unipc` scheduler (default), **10 steps** often gives satisfying results — useful for quick iteration.
|
||||
|
||||
#### Wan2.2 T2V 14B
|
||||
|
||||
Wan2.2 uses a dual-model pipeline (separate high-noise and low-noise transformers) and takes guidance as a `high,low` pair:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-T2V-A14B-MLX \
|
||||
--prompt "Two astronauts playing chess on the surface of the moon, dramatic lighting, 8K" \
|
||||
--negative-prompt "low quality, blurry, distorted" \
|
||||
--width 1280 --height 704 --num-frames 81 \
|
||||
--steps 40 --guide-scale "3.0,4.0" \
|
||||
--seed 42 \
|
||||
--output-path wan22_t2v.mp4
|
||||
```
|
||||
|
||||
> **Tip**: With the `unipc` scheduler (default), **10 steps** often produces satisfying results for 14B models — a significant speed-up with minimal quality loss. Try `--steps 10` for quick iterations.
|
||||
|
||||
#### Wan2.2 I2V 14B
|
||||
|
||||
Image-to-video: animates a starting image guided by a text prompt. Pass the image with `--image`:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-I2V-A14B-MLX \
|
||||
--image ./my_photo.png \
|
||||
--prompt "The person slowly turns their head and smiles, cinematic, natural lighting" \
|
||||
--negative-prompt "low quality, blurry, distorted" \
|
||||
--width 1280 --height 704 --num-frames 81 \
|
||||
--steps 40 --guide-scale "3.5,3.5" \
|
||||
--seed 42 \
|
||||
--output-path wan22_i2v.mp4
|
||||
```
|
||||
|
||||
> **Tip**: As with T2V, `--steps 10` with the `unipc` scheduler is often sufficient for fast prototyping.
|
||||
|
||||
#### Wan2.2 TI2V 5B
|
||||
|
||||
Text+image-to-video: a single-model variant with a larger VAE (`z_dim=48`). Resolution must be divisible by **32** (not 16 as with other models):
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-TI2V-5B-MLX \
|
||||
--image ./my_photo.png \
|
||||
--prompt "The subject waves hello, warm sunlight, film grain" \
|
||||
--width 1280 --height 704 --num-frames 41 \
|
||||
--steps 40 --guide-scale 5.0 \
|
||||
--seed 42 \
|
||||
--output-path wan22_ti2v.mp4
|
||||
```
|
||||
|
||||
> **Note**: The 5B model is fast — 40 steps run quickly and are recommended for best quality.
|
||||
|
||||
> **Frame count**: `--num-frames` must satisfy `4n+1` for all models (e.g. 5, 9, 13, 21, 41, 81, 101 …).
|
||||
|
||||
> **Resolution**: Always use the model's native resolution. While generation will succeed at other sizes, mismatched resolutions or aspect ratios are likely to produce visual artifacts. Preferred resolutions are:
|
||||
> - **480P** — 832×480 (landscape) or 480×832 (portrait) — for Wan2.1 1.3B
|
||||
> - **720P** — 1280×704 (landscape) or 704×1280 (portrait) — for Wan2.1 14B, Wan2.2 T2V/I2V/TI2V
|
||||
|
||||
#### Generation Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--model-dir` | (required) | Path to converted MLX model directory |
|
||||
| `--prompt` | (required) | Text prompt |
|
||||
| `--image` | — | Input image path (I2V and TI2V modes) |
|
||||
| `--negative-prompt` | config default | Negative guidance prompt |
|
||||
| `--width` | `1280` | Output width in pixels |
|
||||
| `--height` | `704` | Output height in pixels |
|
||||
| `--num-frames` | `81` | Number of frames (must be `4n+1`) |
|
||||
| `--steps` | config default | Diffusion steps |
|
||||
| `--guide-scale` | config default | Guidance scale; use `"high,low"` pair for Wan2.2 dual models |
|
||||
| `--shift` | config default | Noise schedule shift |
|
||||
| `--seed` | `-1` (random) | Random seed for reproducibility |
|
||||
| `--output-path` | `output.mp4` | Output video file path |
|
||||
| `--scheduler` | `unipc` | Solver: `euler`, `dpm++`, or `unipc` |
|
||||
| `--trim-first-frames` | `0` | Drop N leading frames (fixes first-frame artifacts on 14B models) |
|
||||
| `--tiling` | `auto` | VAE tiling: `auto`, `none`, `spatial`, `temporal` |
|
||||
|
||||
### Quantization (Reduced Memory)
|
||||
|
||||
Quantize the transformer weights to reduce memory usage by ~3.4×. Quantization is supported for all model variants and is especially important for running 14B models on devices with limited unified memory:
|
||||
|
||||
```bash
|
||||
# Convert with 4-bit quantization (works for any variant)
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-1.3B \
|
||||
--output-dir ./Wan2.1-T2V-1.3B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-14B \
|
||||
--output-dir ./Wan2.1-T2V-14B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-T2V-A14B \
|
||||
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-I2V-A14B \
|
||||
--output-dir ./Wan2.2-I2V-A14B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-TI2V-5B \
|
||||
--output-dir ./Wan2.2-TI2V-5B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
```
|
||||
|
||||
You can also quantize an already-converted MLX model without re-converting from PyTorch:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-T2V-A14B-MLX \
|
||||
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
|
||||
--quantize-only --bits 4
|
||||
```
|
||||
|
||||
Quantized models are used exactly the same way — the quantization is auto-detected from `config.json`:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
|
||||
--prompt "A cat playing piano"
|
||||
```
|
||||
|
||||
**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2) — 10 layers × N blocks = ~95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.
|
||||
|
||||
| Model | BF16 Size | 4-bit Size | Notes |
|
||||
|-------|-----------|------------|-------|
|
||||
| 1.3B | 2.7 GB | 799 MB | ~3.4x smaller |
|
||||
| 14B | ~28 GB | ~8 GB | Enables running on 16GB devices |
|
||||
|
||||
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
|
||||
|
||||
### Wan Model Specifications
|
||||
|
||||
**Transformer (14B)**
|
||||
- 40 layers, 40 attention heads, dim 5120, head dim 128
|
||||
- 3-way factorized RoPE (temporal + spatial)
|
||||
- 14.29B parameters
|
||||
|
||||
**Transformer (1.3B, Wan2.1 only)**
|
||||
- 30 layers, 12 attention heads, dim 1536, head dim 128
|
||||
- Same architecture, smaller scale
|
||||
|
||||
**Text Encoder** — UMT5-XXL (5.68B parameters)
|
||||
- 24 layers, 64 heads, dim 4096, vocab 256K
|
||||
|
||||
**VAE** — 3D causal convolution decoder (72.6M parameters)
|
||||
- Latent channels: 16
|
||||
- Compression: 4× temporal, 8× spatial
|
||||
|
||||
---
|
||||
|
||||
## LoRA Support
|
||||
|
||||
LoRA's can be used with the `--lora-high` and `--lora-low` command line switches.
|
||||
|
||||
For example, for using the the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, use the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
|
||||
--width 480 \
|
||||
--height 704 \
|
||||
--num-frames 41 \
|
||||
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
|
||||
--steps 4 \
|
||||
--guide-scale 1 \
|
||||
--trim-first-frames 1 \
|
||||
--seed 2391784614 \
|
||||
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
|
||||
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
|
||||
```
|
||||
|
||||
## Enjoy
|
||||
|
||||

|
||||
2
mlx_video/models/wan_2/__init__.py
Normal file
2
mlx_video/models/wan_2/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
221
mlx_video/models/wan_2/attention.py
Normal file
221
mlx_video/models/wan_2/attention.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .rope import rope_apply
|
||||
|
||||
|
||||
def _linear_dtype(layer) -> mx.Dtype:
|
||||
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
|
||||
# Unwrap LoRA wrapper to get the underlying linear layer
|
||||
inner = getattr(layer, "linear", layer)
|
||||
if isinstance(inner, nn.QuantizedLinear):
|
||||
return inner.scales.dtype
|
||||
return inner.weight.dtype
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
"""RMS normalization with learnable scale."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.Module):
|
||||
"""LayerNorm computed in float32, with optional affine."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if elementwise_affine:
|
||||
self.weight = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.elementwise_affine:
|
||||
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
|
||||
else:
|
||||
return mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
"""Self-attention with QK normalization and 3-way factorized RoPE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
x_w = x.astype(w_dtype)
|
||||
|
||||
q = self.q(x_w)
|
||||
k = self.k(x_w)
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
|
||||
q = q.reshape(b, s, n, d)
|
||||
k = k.reshape(b, s, n, d)
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(
|
||||
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
k = rope_apply(
|
||||
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Use precomputed mask or build from seq_lens
|
||||
mask = attn_mask
|
||||
if mask is None and any(sl < s for sl in seq_lens):
|
||||
mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
|
||||
for i, sl in enumerate(seq_lens):
|
||||
mask[i, :, :, sl:] = -1e9
|
||||
|
||||
# Use memory-efficient scaled dot-product attention
|
||||
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class WanCrossAttention(nn.Module):
|
||||
"""Cross-attention: Q from hidden states, K/V from text context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def prepare_kv(self, context: mx.array) -> tuple:
|
||||
"""Pre-compute K and V projections for caching.
|
||||
|
||||
Args:
|
||||
context: [B, L_ctx, dim]
|
||||
|
||||
Returns:
|
||||
(k, v) each [B, N, L_ctx, D] ready for attention
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
# Cast to compute dtype for efficient matmul
|
||||
w_dtype = _linear_dtype(self.k)
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
q = self.q(x.astype(w_dtype))
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
# Optional context masking
|
||||
mask = None
|
||||
if context_lens is not None:
|
||||
ctx_len = k.shape[2]
|
||||
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
|
||||
for i, cl in enumerate(context_lens):
|
||||
mask[i, :, :, cl:] = -1e9
|
||||
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
||||
return self.o(out)
|
||||
129
mlx_video/models/wan_2/config.py
Normal file
129
mlx_video/models/wan_2/config.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
from mlx_video.models.ltx_2.config import BaseModelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class WanModelConfig(BaseModelConfig):
|
||||
"""Configuration for Wan T2V models (supports both 2.1 and 2.2)."""
|
||||
|
||||
model_type: str = "t2v"
|
||||
model_version: str = "2.2"
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2)
|
||||
text_len: int = 512
|
||||
in_dim: int = 16
|
||||
dim: int = 5120
|
||||
ffn_dim: int = 13824
|
||||
freq_dim: int = 256
|
||||
text_dim: int = 4096
|
||||
out_dim: int = 16
|
||||
num_heads: int = 40
|
||||
num_layers: int = 40
|
||||
window_size: Tuple[int, int] = (-1, -1)
|
||||
qk_norm: bool = True
|
||||
cross_attn_norm: bool = True
|
||||
eps: float = 1e-6
|
||||
|
||||
# VAE
|
||||
vae_stride: Tuple[int, int, int] = (4, 8, 8)
|
||||
vae_z_dim: int = 16
|
||||
|
||||
# Inference
|
||||
dual_model: bool = True
|
||||
boundary: float = 0.875
|
||||
sample_shift: float = 12.0
|
||||
sample_steps: int = 40
|
||||
sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
|
||||
num_train_timesteps: int = 1000
|
||||
sample_fps: int = 16
|
||||
frame_num: int = 81
|
||||
sample_neg_prompt: str = (
|
||||
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
|
||||
"最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
|
||||
"画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
|
||||
"杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
)
|
||||
|
||||
# Resolution constraints
|
||||
max_area: int = 0 # 0 = no limit; e.g. 704*1280 for TI2V-5B
|
||||
t5_vocab_size: int = 256384
|
||||
t5_dim: int = 4096
|
||||
t5_dim_attn: int = 4096
|
||||
t5_dim_ffn: int = 10240
|
||||
t5_num_heads: int = 64
|
||||
t5_num_layers: int = 24
|
||||
t5_num_buckets: int = 32
|
||||
|
||||
@property
|
||||
def head_dim(self) -> int:
|
||||
return self.dim // self.num_heads
|
||||
|
||||
@classmethod
|
||||
def wan21_t2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
|
||||
return cls(
|
||||
model_version="2.1",
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan21_t2v_1_3b(cls) -> "WanModelConfig":
|
||||
"""Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
|
||||
return cls(
|
||||
model_version="2.1",
|
||||
dim=1536,
|
||||
ffn_dim=8960,
|
||||
num_heads=12,
|
||||
num_layers=30,
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan22_t2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def wan22_i2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 I2V 14B: dual model, image-to-video, 40 layers, dim=5120."""
|
||||
return cls(
|
||||
model_type="i2v",
|
||||
in_dim=36,
|
||||
out_dim=16,
|
||||
dual_model=True,
|
||||
boundary=0.900,
|
||||
sample_shift=5.0,
|
||||
sample_guide_scale=(3.5, 3.5),
|
||||
max_area=704 * 1280,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
|
||||
return cls(
|
||||
model_type="ti2v",
|
||||
dim=3072,
|
||||
ffn_dim=14336,
|
||||
in_dim=48,
|
||||
out_dim=48,
|
||||
num_heads=24,
|
||||
num_layers=30,
|
||||
vae_z_dim=48,
|
||||
vae_stride=(4, 16, 16),
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=40,
|
||||
sample_guide_scale=5.0,
|
||||
sample_fps=24,
|
||||
max_area=704 * 1280,
|
||||
)
|
||||
808
mlx_video/models/wan_2/convert.py
Normal file
808
mlx_video/models/wan_2/convert.py
Normal file
@@ -0,0 +1,808 @@
|
||||
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_torch_weights(path: str) -> Dict[str, mx.array]:
|
||||
"""Load PyTorch .pth weights and convert to MLX arrays.
|
||||
|
||||
Args:
|
||||
path: Path to .pth file
|
||||
|
||||
Returns:
|
||||
Dictionary of MLX arrays
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("PyTorch is required to load .pth weights: pip install torch")
|
||||
|
||||
logging.info(f"Loading weights from {path}")
|
||||
state_dict = torch.load(path, map_location="cpu", weights_only=True)
|
||||
|
||||
weights = {}
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
np_val = value.detach().float().numpy()
|
||||
weights[key] = mx.array(np_val)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
|
||||
"""Load safetensors weights as MLX arrays.
|
||||
|
||||
Args:
|
||||
path: Path to directory with safetensors files or single file
|
||||
|
||||
Returns:
|
||||
Dictionary of MLX arrays
|
||||
"""
|
||||
path = Path(path)
|
||||
weights = {}
|
||||
if path.is_file():
|
||||
weights = mx.load(str(path))
|
||||
elif path.is_dir():
|
||||
for sf in sorted(path.glob("*.safetensors")):
|
||||
weights.update(mx.load(str(sf)))
|
||||
return weights
|
||||
|
||||
|
||||
def sanitize_wan_transformer_weights(
|
||||
weights: Dict[str, mx.array]
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 transformer weight keys to MLX model structure.
|
||||
|
||||
Wan2.2 keys follow the pattern:
|
||||
patch_embedding.weight/bias
|
||||
text_embedding.{0,2}.weight/bias
|
||||
time_embedding.{0,2}.weight/bias
|
||||
time_projection.1.weight/bias
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
|
||||
blocks.{i}.self_attn.norm_q.weight
|
||||
blocks.{i}.self_attn.norm_k.weight
|
||||
blocks.{i}.norm3.weight/bias (if cross_attn_norm)
|
||||
blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
|
||||
blocks.{i}.cross_attn.norm_q.weight
|
||||
blocks.{i}.cross_attn.norm_k.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.{0,2}.weight/bias
|
||||
blocks.{i}.modulation
|
||||
head.norm.weight
|
||||
head.head.weight/bias
|
||||
head.modulation
|
||||
freqs (buffer)
|
||||
|
||||
MLX model uses:
|
||||
patch_embedding_proj.weight/bias (after patchify reshape)
|
||||
text_embedding_0.weight/bias, text_embedding_1.weight/bias
|
||||
time_embedding_0.weight/bias, time_embedding_1.weight/bias
|
||||
time_projection.weight/bias
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
|
||||
etc.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
|
||||
# MLX Linear expects [O, I*D*H*W] after we flatten in patchify
|
||||
if key == "patch_embedding.weight":
|
||||
# Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
|
||||
value = value.reshape(value.shape[0], -1)
|
||||
new_key = "patch_embedding_proj.weight"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key == "patch_embedding.bias":
|
||||
new_key = "patch_embedding_proj.bias"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
|
||||
if key.startswith("text_embedding.0."):
|
||||
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("text_embedding.2."):
|
||||
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
|
||||
if key.startswith("time_embedding.0."):
|
||||
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("time_embedding.2."):
|
||||
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time projection Sequential: 0=SiLU(no params), 1=Linear
|
||||
if key.startswith("time_projection.1."):
|
||||
new_key = key.replace("time_projection.1.", "time_projection.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
|
||||
new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
|
||||
new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")
|
||||
|
||||
# Skip the freqs buffer (we compute it in the model)
|
||||
if key == "freqs":
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.
|
||||
|
||||
Wan2.2 T5 keys:
|
||||
token_embedding.weight
|
||||
pos_embedding.embedding.weight (if shared_pos)
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.attn.{q,k,v,o}.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.gate.0.weight (gate linear)
|
||||
blocks.{i}.ffn.fc1.weight
|
||||
blocks.{i}.ffn.fc2.weight
|
||||
blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
|
||||
norm.weight
|
||||
|
||||
MLX T5Encoder structure:
|
||||
token_embedding.weight
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.attn.{q,k,v,o}.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
|
||||
blocks.{i}.ffn.fc1.weight
|
||||
blocks.{i}.ffn.fc2.weight
|
||||
blocks.{i}.pos_embedding.embedding.weight
|
||||
norm.weight
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
|
||||
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 VAE weight keys to MLX WanVAE structure.
|
||||
|
||||
Handles Conv3d and Conv2d weight transpositions for MLX format.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
|
||||
if "weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
|
||||
if "weight" in key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
# Map decoder keys to MLX decoder structure
|
||||
# Wan2.2 uses encoder/decoder with downsamples/upsamples
|
||||
# Need to adapt naming for our simplified structure
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _load_lora_configs(
|
||||
lora_configs: List[Tuple[str, float]],
|
||||
) -> Dict[str, list]:
|
||||
"""Load LoRA weights from config tuples, returning module_to_loras dict.
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.models.wan_2.generate import Colors
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
|
||||
configs = []
|
||||
for lora_path, strength in lora_configs:
|
||||
try:
|
||||
config = LoRAConfig(path=lora_path, strength=strength)
|
||||
configs.append(config)
|
||||
print(f" - {Path(lora_path).name} (strength: {strength})")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
|
||||
raise
|
||||
|
||||
module_to_loras = load_multiple_loras(configs)
|
||||
|
||||
if not module_to_loras:
|
||||
print(
|
||||
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
|
||||
)
|
||||
|
||||
return module_to_loras
|
||||
|
||||
|
||||
def load_and_apply_loras(
|
||||
model_weights: Dict[str, mx.array],
|
||||
lora_configs: Optional[List[Tuple[str, float]]] = None,
|
||||
verbose: bool = False,
|
||||
quantization_bits: int = 0,
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Load and apply LoRA weights to model weights by merging into weight dict.
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.models.wan_2.generate import Colors
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
|
||||
if not lora_configs:
|
||||
return model_weights
|
||||
|
||||
module_to_loras = _load_lora_configs(lora_configs)
|
||||
if not module_to_loras:
|
||||
return model_weights
|
||||
|
||||
print(
|
||||
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
|
||||
)
|
||||
if verbose:
|
||||
print(f" Model has {len(model_weights)} weight keys")
|
||||
|
||||
modified_weights = apply_loras_to_weights(
|
||||
model_weights,
|
||||
module_to_loras,
|
||||
verbose=verbose,
|
||||
quantization_bits=quantization_bits,
|
||||
)
|
||||
|
||||
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
|
||||
|
||||
return modified_weights
|
||||
|
||||
|
||||
def convert_wan_checkpoint(
|
||||
checkpoint_dir: str,
|
||||
output_dir: str,
|
||||
dtype: str = "bfloat16",
|
||||
model_version: str = "auto",
|
||||
quantize: bool = False,
|
||||
bits: int = 4,
|
||||
group_size: int = 64,
|
||||
):
|
||||
"""Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.
|
||||
|
||||
Wan2.2 expected structure:
|
||||
checkpoint_dir/
|
||||
models_t5_umt5-xxl-enc-bf16.pth
|
||||
Wan2.1_VAE.pth
|
||||
low_noise_model/ (safetensors)
|
||||
high_noise_model/ (safetensors)
|
||||
|
||||
Wan2.1 expected structure:
|
||||
checkpoint_dir/
|
||||
models_t5_umt5-xxl-enc-bf16.pth
|
||||
Wan2.1_VAE.pth
|
||||
diffusion_pytorch_model*.safetensors (single model)
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to Wan checkpoint directory
|
||||
output_dir: Path to output MLX model directory
|
||||
dtype: Target dtype
|
||||
model_version: "2.1", "2.2", or "auto" (detect from directory)
|
||||
quantize: Whether to quantize the transformer weights
|
||||
bits: Quantization bits (4 or 8)
|
||||
group_size: Quantization group size (32, 64, or 128)
|
||||
"""
|
||||
import json
|
||||
|
||||
checkpoint_dir = Path(checkpoint_dir)
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dtype_map = {
|
||||
"float16": mx.float16,
|
||||
"float32": mx.float32,
|
||||
"bfloat16": mx.bfloat16,
|
||||
}
|
||||
target_dtype = dtype_map.get(dtype, mx.bfloat16)
|
||||
|
||||
# Auto-detect version
|
||||
if model_version == "auto":
|
||||
if (checkpoint_dir / "low_noise_model").exists():
|
||||
model_version = "2.2"
|
||||
elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
|
||||
model_version = "2.2"
|
||||
else:
|
||||
model_version = "2.1"
|
||||
print(f"Auto-detected Wan{model_version} checkpoint")
|
||||
|
||||
is_dual = (checkpoint_dir / "low_noise_model").exists()
|
||||
|
||||
if is_dual:
|
||||
# Wan2.2: Convert dual transformer models
|
||||
low_noise_path = checkpoint_dir / "low_noise_model"
|
||||
if low_noise_path.exists():
|
||||
print("Converting low-noise transformer...")
|
||||
weights = load_safetensors_weights(str(low_noise_path))
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "low_noise_model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
|
||||
high_noise_path = checkpoint_dir / "high_noise_model"
|
||||
if high_noise_path.exists():
|
||||
print("Converting high-noise transformer...")
|
||||
weights = load_safetensors_weights(str(high_noise_path))
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "high_noise_model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
else:
|
||||
# Wan2.1: Convert single transformer model
|
||||
# Try safetensors in the checkpoint dir itself
|
||||
print("Converting transformer (single model)...")
|
||||
weights = load_safetensors_weights(str(checkpoint_dir))
|
||||
if not weights:
|
||||
# Fallback: look for .pth files
|
||||
for pth in sorted(checkpoint_dir.glob("*.pth")):
|
||||
if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
|
||||
print(f" Loading from {pth.name}...")
|
||||
weights = load_torch_weights(str(pth))
|
||||
break
|
||||
if weights:
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
else:
|
||||
print(" Warning: No transformer weights found!")
|
||||
|
||||
# Save config — detect model size from source config.json or transformer weights
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
def _detect_config():
|
||||
"""Detect config from source config.json or transformer weight shapes."""
|
||||
if is_dual:
|
||||
# Check source config.json for model_type (I2V vs T2V)
|
||||
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
|
||||
if src_cfg_path.exists():
|
||||
with open(src_cfg_path) as f:
|
||||
src_config = json.load(f)
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
|
||||
return WanModelConfig.wan22_i2v_14b()
|
||||
return WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
# Try reading source config.json first (most reliable)
|
||||
src_cfg_path = checkpoint_dir / "config.json"
|
||||
src_config = None
|
||||
if src_cfg_path.exists():
|
||||
with open(src_cfg_path) as f:
|
||||
src_config = json.load(f)
|
||||
|
||||
if src_config and "dim" in src_config:
|
||||
src_dim = src_config.get("dim", 5120)
|
||||
src_in_dim = src_config.get("in_dim", 16)
|
||||
src_out_dim = src_config.get("out_dim", 16)
|
||||
src_ffn_dim = src_config.get("ffn_dim", 13824)
|
||||
src_num_heads = src_config.get("num_heads", 40)
|
||||
src_num_layers = src_config.get("num_layers", 40)
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
src_text_len = src_config.get("text_len", 512)
|
||||
|
||||
print(
|
||||
f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}"
|
||||
)
|
||||
|
||||
# Use preset for known TI2V 5B configuration
|
||||
if src_model_type == "ti2v" and src_dim == 3072:
|
||||
return WanModelConfig.wan22_ti2v_5b()
|
||||
|
||||
is_22 = model_version == "2.2"
|
||||
|
||||
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
|
||||
vae_z = 48 if is_22 else 16
|
||||
vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
|
||||
fps = 24 if is_22 else 16
|
||||
|
||||
return WanModelConfig(
|
||||
model_type=src_model_type,
|
||||
model_version=model_version,
|
||||
dim=src_dim,
|
||||
ffn_dim=src_ffn_dim,
|
||||
in_dim=src_in_dim,
|
||||
out_dim=src_out_dim,
|
||||
num_heads=src_num_heads,
|
||||
num_layers=src_num_layers,
|
||||
text_len=src_text_len,
|
||||
vae_z_dim=vae_z,
|
||||
vae_stride=vae_s,
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
sample_fps=fps,
|
||||
)
|
||||
|
||||
# Fallback: detect from saved transformer weight shapes
|
||||
saved_model = output_dir / "model.safetensors"
|
||||
if saved_model.exists():
|
||||
det_weights = mx.load(str(saved_model))
|
||||
dim = None
|
||||
for k, v in det_weights.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
dim = v.shape[0]
|
||||
break
|
||||
del det_weights
|
||||
if dim is not None and dim <= 2048:
|
||||
print(f" Auto-detected 1.3B model (dim={dim})")
|
||||
return WanModelConfig.wan21_t2v_1_3b()
|
||||
|
||||
return WanModelConfig.wan21_t2v_14b()
|
||||
|
||||
config = _detect_config()
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
print(f" Saved config to {config_path}")
|
||||
|
||||
# Convert T5 encoder
|
||||
t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
|
||||
if t5_path.exists():
|
||||
print("Converting T5 encoder...")
|
||||
weights = load_torch_weights(str(t5_path))
|
||||
weights = sanitize_wan_t5_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "t5_encoder.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
|
||||
# Convert VAE (check both naming conventions)
|
||||
vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
|
||||
is_wan22_vae = False
|
||||
if not vae_path.exists():
|
||||
vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
|
||||
is_wan22_vae = True
|
||||
if vae_path.exists():
|
||||
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(
|
||||
weights, include_encoder=include_encoder
|
||||
)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
|
||||
# that cannot be recovered by upcasting at load time.
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
out_path = output_dir / "vae.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
|
||||
|
||||
# Quantize transformer weights if requested
|
||||
if quantize:
|
||||
print(
|
||||
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
|
||||
)
|
||||
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
|
||||
|
||||
print(f"\nConversion complete! Output: {output_dir}")
|
||||
|
||||
|
||||
def _quantize_predicate(path: str, module) -> bool:
|
||||
"""Return True for layers that should be quantized.
|
||||
|
||||
Targets heavyweight Linear layers in attention and FFN blocks.
|
||||
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
|
||||
"""
|
||||
if not hasattr(module, "to_quantized"):
|
||||
return False
|
||||
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
||||
quantize_patterns = (
|
||||
".self_attn.q",
|
||||
".self_attn.k",
|
||||
".self_attn.v",
|
||||
".self_attn.o",
|
||||
".cross_attn.q",
|
||||
".cross_attn.k",
|
||||
".cross_attn.v",
|
||||
".cross_attn.o",
|
||||
".ffn.fc1",
|
||||
".ffn.fc2",
|
||||
)
|
||||
return any(path.endswith(p) for p in quantize_patterns)
|
||||
|
||||
|
||||
def _quantize_saved_model(
|
||||
output_dir: Path,
|
||||
config,
|
||||
is_dual: bool,
|
||||
bits: int,
|
||||
group_size: int,
|
||||
source_dir: Path = None,
|
||||
):
|
||||
"""Load saved bf16 model, quantize, and re-save.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to write quantized weights to.
|
||||
config: WanModelConfig for creating the model.
|
||||
is_dual: Whether this is a dual-expert model.
|
||||
bits: Quantization bits.
|
||||
group_size: Quantization group size.
|
||||
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
|
||||
"""
|
||||
import json
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
if source_dir is None:
|
||||
source_dir = output_dir
|
||||
|
||||
model_names = []
|
||||
if is_dual:
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
if (source_dir / name).exists():
|
||||
model_names.append(name)
|
||||
else:
|
||||
if (source_dir / "model.safetensors").exists():
|
||||
model_names.append("model.safetensors")
|
||||
|
||||
for name in model_names:
|
||||
print(f" Quantizing {name}...")
|
||||
model = WanModel(config)
|
||||
weights = mx.load(str(source_dir / name))
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
del weights
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Apply quantization to targeted layers
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
# Save quantized weights
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
|
||||
# Validate: check for NaN/Inf in bias tensors (corruption canary)
|
||||
bad_keys = []
|
||||
for k, v in weights_dict.items():
|
||||
if k.endswith(".bias") and not k.endswith(".biases"):
|
||||
mx.eval(v)
|
||||
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
|
||||
bad_keys.append(k)
|
||||
if bad_keys:
|
||||
raise RuntimeError(
|
||||
f"Quantization produced corrupted weights in {model_path.name}: "
|
||||
f"{len(bad_keys)} bias tensors contain NaN/Inf "
|
||||
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
|
||||
)
|
||||
|
||||
mx.save_safetensors(str(output_dir / name), weights_dict)
|
||||
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
|
||||
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
|
||||
|
||||
# Free model before processing next file
|
||||
del model, weights_dict
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Update config.json with quantization metadata
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
cfg["quantization"] = {
|
||||
"group_size": group_size,
|
||||
"bits": bits,
|
||||
}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(cfg, f, indent=2)
|
||||
print(f" Updated config.json with quantization metadata")
|
||||
|
||||
|
||||
def quantize_mlx_model(
|
||||
mlx_model_dir: str,
|
||||
output_dir: str,
|
||||
bits: int = 4,
|
||||
group_size: int = 64,
|
||||
):
|
||||
"""Quantize an already-converted MLX model (skips PyTorch conversion).
|
||||
|
||||
Args:
|
||||
mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
|
||||
output_dir: Path to output quantized model directory.
|
||||
bits: Quantization bits (4 or 8).
|
||||
group_size: Quantization group size (32, 64, or 128).
|
||||
"""
|
||||
import json
|
||||
import shutil
|
||||
|
||||
src = Path(mlx_model_dir)
|
||||
dst = Path(output_dir)
|
||||
|
||||
config_path = src / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"No config.json found in {src}")
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
if cfg.get("quantization"):
|
||||
raise ValueError(
|
||||
f"Model at {src} is already quantized "
|
||||
f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source."
|
||||
)
|
||||
|
||||
# Detect dual vs single expert
|
||||
is_dual = (src / "low_noise_model.safetensors").exists() and (
|
||||
src / "high_noise_model.safetensors"
|
||||
).exists()
|
||||
|
||||
# Build model config
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config_dict = {
|
||||
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**config_dict)
|
||||
|
||||
# Copy non-transformer files to output dir (skip large model weights)
|
||||
transformer_files = {
|
||||
"low_noise_model.safetensors",
|
||||
"high_noise_model.safetensors",
|
||||
"model.safetensors",
|
||||
}
|
||||
if dst.resolve() != src.resolve():
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for f in src.iterdir():
|
||||
if f.is_file() and f.name not in transformer_files:
|
||||
shutil.copy2(f, dst / f.name)
|
||||
print(f"Copied non-transformer files from {src} to {dst}")
|
||||
|
||||
print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
|
||||
_quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
|
||||
|
||||
print(f"\nQuantization complete! Output: {dst}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
|
||||
parser.add_argument(
|
||||
"--checkpoint-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to Wan checkpoint directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="wan_mlx_model",
|
||||
help="Output path for MLX model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["float16", "float32", "bfloat16"],
|
||||
default="bfloat16",
|
||||
help="Target dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-version",
|
||||
type=str,
|
||||
choices=["2.1", "2.2", "auto"],
|
||||
default="auto",
|
||||
help="Wan model version (auto-detect by default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
action="store_true",
|
||||
help="Quantize transformer weights for faster inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize-only",
|
||||
action="store_true",
|
||||
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bits",
|
||||
type=int,
|
||||
choices=[4, 8],
|
||||
default=4,
|
||||
help="Quantization bits (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
choices=[32, 64, 128],
|
||||
default=64,
|
||||
help="Quantization group size (default: 64)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quantize_only:
|
||||
quantize_mlx_model(
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
else:
|
||||
convert_wan_checkpoint(
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
args.dtype,
|
||||
args.model_version,
|
||||
quantize=args.quantize,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
977
mlx_video/models/wan_2/generate.py
Normal file
977
mlx_video/models/wan_2/generate.py
Normal file
@@ -0,0 +1,977 @@
|
||||
"""Wan2.2 Text-to-Video generation pipeline for MLX."""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan_2.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan_2.utils import (
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
load_vae_encoder,
|
||||
load_wan_model,
|
||||
)
|
||||
from mlx_video.models.wan_2.postprocess import save_video
|
||||
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output."""
|
||||
|
||||
CYAN = "\033[96m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
MAGENTA = "\033[95m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
# Backward-compat alias (tests and external code may use the old name)
|
||||
_build_i2v_mask = build_i2v_mask
|
||||
|
||||
|
||||
def _best_output_size(w, h, dw, dh, max_area):
|
||||
"""Compute the best output resolution that fits within max_area while
|
||||
preserving the input aspect ratio and satisfying alignment constraints.
|
||||
Matches the reference implementation's best_output_size().
|
||||
"""
|
||||
ratio = w / h
|
||||
ow = (max_area * ratio) ** 0.5
|
||||
oh = max_area / ow
|
||||
|
||||
# Option 1: process width first
|
||||
ow1 = int(ow // dw * dw)
|
||||
oh1 = int(max_area / ow1 // dh * dh)
|
||||
ratio1 = ow1 / oh1
|
||||
|
||||
# Option 2: process height first
|
||||
oh2 = int(oh // dh * dh)
|
||||
ow2 = int(max_area / oh2 // dw * dw)
|
||||
ratio2 = ow2 / oh2
|
||||
|
||||
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
|
||||
return ow1, oh1
|
||||
return ow2, oh2
|
||||
|
||||
|
||||
def generate_video(
|
||||
model_dir: str,
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
image: str | None = None,
|
||||
width: int = 1280,
|
||||
height: int = 704,
|
||||
num_frames: int = 81,
|
||||
steps: int = None,
|
||||
guide_scale: str | float | tuple = None,
|
||||
shift: float = None,
|
||||
seed: int = -1,
|
||||
output_path: str = "output.mp4",
|
||||
scheduler: str = "unipc",
|
||||
loras: list | None = None,
|
||||
loras_high: list | None = None,
|
||||
loras_low: list | None = None,
|
||||
tiling: str = "auto",
|
||||
no_compile: bool = False,
|
||||
trim_first_frames: int = 0,
|
||||
debug_latents: bool = False,
|
||||
):
|
||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||
|
||||
Args:
|
||||
model_dir: Path to converted MLX model directory
|
||||
prompt: Text prompt
|
||||
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
||||
image: Path to input image for I2V (None = T2V mode)
|
||||
width: Video width
|
||||
height: Video height
|
||||
num_frames: Number of frames (must be 4n+1)
|
||||
steps: Number of diffusion steps (None = use config default)
|
||||
guide_scale: Guidance scale: float for single, (low,high) for dual (None = config default)
|
||||
shift: Noise schedule shift (None = use config default)
|
||||
seed: Random seed (-1 for random)
|
||||
output_path: Output video path
|
||||
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
||||
loras: Optional list of (path, strength) tuples applied to all models
|
||||
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
||||
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
||||
tiling: Tiling mode for VAE decoding. Options:
|
||||
- "auto": Automatically determine tiling based on video size (default)
|
||||
- "none": Disable tiling
|
||||
- "default", "aggressive", "conservative": Preset tiling configs
|
||||
- "spatial": Spatial tiling only
|
||||
- "temporal": Temporal tiling only
|
||||
no_compile: If True, skip mx.compile on models (useful for debugging)
|
||||
trim_first_frames: Number of temporal latent positions to generate extra
|
||||
and discard from the start. Each position = 4 pixel frames. Use 1
|
||||
to fix first-frame artifacts on 14B models (generates 4 extra frames,
|
||||
discards first 4). Use 2 for more aggressive trimming. Default: 0.
|
||||
debug_latents: If True, print per-temporal-position latent statistics
|
||||
after denoising for diagnosing first-frame artifacts.
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
model_dir = Path(model_dir)
|
||||
|
||||
# Load config from model dir if available, otherwise auto-detect
|
||||
config_path = model_dir / "config.json"
|
||||
quantization = None
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
# Extract quantization config (not a model config field)
|
||||
quantization = config_dict.pop("quantization", None)
|
||||
# Handle tuple fields stored as lists in JSON
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
k: v
|
||||
for k, v in config_dict.items()
|
||||
if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Auto-detect: dual model files → 2.2, single model → 2.1
|
||||
if (model_dir / "low_noise_model.safetensors").exists():
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
else:
|
||||
# Detect 1.3B vs 14B from weight shapes
|
||||
model_path = model_dir / "model.safetensors"
|
||||
if model_path.exists():
|
||||
probe = mx.load(str(model_path), return_metadata=False)
|
||||
for k, v in probe.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
dim = v.shape[0]
|
||||
if dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
break
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
del probe
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
|
||||
is_dual = config.dual_model
|
||||
is_i2v = image is not None
|
||||
|
||||
# Validate config against actual weights (handles mismatched config.json)
|
||||
if not is_dual:
|
||||
model_path = model_dir / "model.safetensors"
|
||||
if model_path.exists():
|
||||
probe = mx.load(str(model_path), return_metadata=False)
|
||||
for k, v in probe.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
actual_dim = v.shape[0]
|
||||
if actual_dim != config.dim:
|
||||
print(
|
||||
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
|
||||
)
|
||||
if actual_dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
break
|
||||
del probe
|
||||
|
||||
# Auto-correct Wan2.2 VAE params from stale configs
|
||||
if config.in_dim == 48 and config.vae_z_dim != 48:
|
||||
print(
|
||||
f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
|
||||
)
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
**{
|
||||
f.name: getattr(config, f.name)
|
||||
for f in config.__dataclass_fields__.values()
|
||||
},
|
||||
"vae_z_dim": 48,
|
||||
"vae_stride": (4, 16, 16),
|
||||
"sample_fps": 24,
|
||||
}
|
||||
)
|
||||
|
||||
# Apply defaults from config if not overridden
|
||||
if steps is None:
|
||||
steps = config.sample_steps
|
||||
if shift is None:
|
||||
shift = config.sample_shift
|
||||
if guide_scale is None:
|
||||
guide_scale = config.sample_guide_scale
|
||||
|
||||
# Normalize guide_scale
|
||||
if isinstance(guide_scale, (int, float)):
|
||||
guide_scale = float(guide_scale)
|
||||
elif isinstance(guide_scale, str):
|
||||
parts = [float(x) for x in guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
|
||||
if isinstance(guide_scale, tuple):
|
||||
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
|
||||
else:
|
||||
cfg_disabled = guide_scale <= 1.0
|
||||
|
||||
# Validate frame count
|
||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||
|
||||
gen_frames = num_frames
|
||||
if trim_first_frames > 0:
|
||||
gen_frames = num_frames + trim_first_frames * 4
|
||||
print(
|
||||
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
|
||||
)
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
|
||||
# Resolve negative prompt: explicit user value > config default
|
||||
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
||||
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
||||
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
|
||||
if negative_prompt is None:
|
||||
neg_prompt_resolved = config.sample_neg_prompt
|
||||
else:
|
||||
neg_prompt_resolved = negative_prompt
|
||||
print(f"{Colors.CYAN}{'='*60}")
|
||||
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
|
||||
print(f"{'='*60}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Prompt: {prompt}")
|
||||
if is_i2v:
|
||||
print(f" Image: {image}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = (
|
||||
neg_prompt_resolved[:60] + "..."
|
||||
if len(neg_prompt_resolved) > 60
|
||||
else neg_prompt_resolved
|
||||
)
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(
|
||||
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
|
||||
)
|
||||
if cfg_disabled:
|
||||
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||
print(f"{Colors.RESET}")
|
||||
|
||||
# Seed
|
||||
if seed < 0:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
mx.random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
|
||||
|
||||
# Align dimensions to patch_size * vae_stride (required for patchify)
|
||||
vae_stride = config.vae_stride
|
||||
patch_size = config.patch_size
|
||||
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
|
||||
align_w = patch_size[2] * vae_stride[2]
|
||||
if height % align_h != 0 or width % align_w != 0:
|
||||
old_h, old_w = height, width
|
||||
height = (height // align_h) * align_h
|
||||
width = (width // align_w) * align_w
|
||||
if height == 0:
|
||||
height = align_h
|
||||
if width == 0:
|
||||
width = align_w
|
||||
print(
|
||||
f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
|
||||
)
|
||||
|
||||
# Enforce max_area constraint (model-specific resolution limit)
|
||||
if config.max_area > 0 and height * width > config.max_area:
|
||||
old_h, old_w = height, width
|
||||
width, height = _best_output_size(
|
||||
width, height, align_w, align_h, config.max_area
|
||||
)
|
||||
print(
|
||||
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Compute target latent shape
|
||||
z_dim = config.vae_z_dim
|
||||
t_latent = (gen_frames - 1) // vae_stride[0] + 1
|
||||
h_latent = height // vae_stride[1]
|
||||
w_latent = width // vae_stride[2]
|
||||
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
||||
|
||||
# Sequence length for transformer
|
||||
seq_len = math.ceil(
|
||||
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
|
||||
)
|
||||
|
||||
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
||||
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
||||
|
||||
# Load T5 encoder
|
||||
t1 = time.time()
|
||||
print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
|
||||
t5_path = model_dir / "t5_encoder.safetensors"
|
||||
t5_encoder = load_t5_encoder(t5_path, config)
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
|
||||
# Encode prompts
|
||||
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
||||
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
||||
if cfg_disabled:
|
||||
context_null = None
|
||||
mx.eval(context)
|
||||
else:
|
||||
context_null = encode_text(
|
||||
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
|
||||
)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
del t5_encoder
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
# I2V: encode image to latent space
|
||||
z_img = None
|
||||
i2v_mask = None
|
||||
i2v_mask_tokens = None
|
||||
y_i2v = None
|
||||
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
|
||||
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
|
||||
if is_i2v:
|
||||
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
|
||||
t_img = time.time()
|
||||
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
|
||||
if is_i2v_channel_concat:
|
||||
# I2V-14B: encode full video (first frame = image, rest = zeros)
|
||||
# and construct y tensor with mask + encoded latents
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(image).convert("RGB")
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize(
|
||||
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||
)
|
||||
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
img_arr = mx.array(
|
||||
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
|
||||
) # [H, W, 3]
|
||||
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
||||
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
||||
video = mx.concatenate(
|
||||
[
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
|
||||
mx.eval(z_video)
|
||||
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
|
||||
|
||||
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
||||
msk = mx.concatenate(
|
||||
[
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
|
||||
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
||||
mx.eval(y_i2v)
|
||||
|
||||
del vae_enc, img_arr, img_chw, video, z_video, msk
|
||||
else:
|
||||
# TI2V-5B: encode single image, blend with noise via mask
|
||||
img_tensor = preprocess_image(image, width, height)
|
||||
mx.eval(img_tensor)
|
||||
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
mx.eval(z_img)
|
||||
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
|
||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||
|
||||
del vae_enc, img_tensor
|
||||
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||
|
||||
# Load transformer models
|
||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||
if quantization:
|
||||
print(
|
||||
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
|
||||
)
|
||||
t2 = time.time()
|
||||
|
||||
# Merge per-model LoRAs with shared LoRAs
|
||||
_loras_low = (loras or []) + (loras_low or []) or None
|
||||
_loras_high = (loras or []) + (loras_high or []) or None
|
||||
_loras_single = loras
|
||||
|
||||
if is_dual:
|
||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||
low_noise_model = load_wan_model(
|
||||
low_noise_path, config, quantization, loras=_loras_low
|
||||
)
|
||||
high_noise_model = load_wan_model(
|
||||
high_noise_path, config, quantization, loras=_loras_high
|
||||
)
|
||||
else:
|
||||
single_model = load_wan_model(
|
||||
model_dir / "model.safetensors", config, quantization, loras=_loras_single
|
||||
)
|
||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||
|
||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||
# Each model has its own text_embedding weights, so dual models need separate embeddings
|
||||
if cfg_disabled:
|
||||
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context])
|
||||
context_emb_high = high_noise_model.embed_text([context])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cond_low = context_emb_low[0:1]
|
||||
context_cond_high = context_emb_high[0:1]
|
||||
else:
|
||||
context_emb = single_model.embed_text([context])
|
||||
mx.eval(context_emb)
|
||||
context_cond = context_emb[0:1]
|
||||
else:
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate(
|
||||
[context_emb_low[0:1], context_emb_low[1:2]], axis=0
|
||||
)
|
||||
context_cfg_high = mx.concatenate(
|
||||
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
|
||||
)
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
||||
|
||||
# Precompute cross-attention K/V caches (constant across all steps)
|
||||
if cfg_disabled:
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cond)
|
||||
mx.eval(cross_kv)
|
||||
else:
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cfg)
|
||||
mx.eval(cross_kv)
|
||||
|
||||
# Precompute RoPE frequencies (grid sizes are constant across all steps)
|
||||
f_grid = t_latent // patch_size[0]
|
||||
h_grid = h_latent // patch_size[1]
|
||||
w_grid = w_latent // patch_size[2]
|
||||
if cfg_disabled:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
|
||||
else:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
||||
if is_dual:
|
||||
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
|
||||
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||
else:
|
||||
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin)
|
||||
|
||||
# Setup scheduler
|
||||
_schedulers = {
|
||||
"euler": FlowMatchEulerScheduler,
|
||||
"dpm++": FlowDPMPP2MScheduler,
|
||||
"unipc": FlowUniPCScheduler,
|
||||
}
|
||||
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
|
||||
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
|
||||
sched.set_timesteps(steps, shift=shift)
|
||||
|
||||
# Generate initial noise
|
||||
noise = mx.random.normal(target_shape)
|
||||
|
||||
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
# Boundary for model switching (dual model only)
|
||||
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
|
||||
|
||||
# Diffusion loop
|
||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||
t3 = time.time()
|
||||
|
||||
# Compile model forward for faster denoising
|
||||
if not no_compile:
|
||||
models_to_compile = (
|
||||
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||
)
|
||||
for m in models_to_compile:
|
||||
m._compiled = mx.compile(m)
|
||||
|
||||
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
# Select model, cached K/V, and precomputed RoPE
|
||||
if is_dual:
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
kv = cross_kv_high
|
||||
rcs = rope_cos_sin_high
|
||||
else:
|
||||
model = low_noise_model
|
||||
kv = cross_kv_low
|
||||
rcs = rope_cos_sin_low
|
||||
else:
|
||||
model = single_model
|
||||
kv = cross_kv
|
||||
rcs = rope_cos_sin
|
||||
|
||||
# Use compiled forward when available (faster after first trace)
|
||||
_call = getattr(model, "_compiled", model)
|
||||
|
||||
if cfg_disabled:
|
||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = t_tokens # [1, L]
|
||||
else:
|
||||
t_batch = mx.array([timestep_val])
|
||||
|
||||
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
if is_dual:
|
||||
ctx = (
|
||||
context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
)
|
||||
else:
|
||||
ctx = context_cond
|
||||
preds = _call(
|
||||
[latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred = preds[0]
|
||||
del preds
|
||||
else:
|
||||
# CFG: batch cond + uncond into single B=2 forward pass
|
||||
if is_dual:
|
||||
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||
else:
|
||||
gs = (
|
||||
guide_scale
|
||||
if isinstance(guide_scale, (int, float))
|
||||
else guide_scale[0]
|
||||
)
|
||||
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
|
||||
else:
|
||||
t_batch = mx.array([timestep_val, timestep_val])
|
||||
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
ctx = (
|
||||
context_cfg
|
||||
if not is_dual
|
||||
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
|
||||
)
|
||||
preds = _call(
|
||||
[latents, latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
del noise_pred_cond, noise_pred_uncond, preds
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
# TI2V-5B: re-apply mask to keep first frame frozen
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred
|
||||
mx.eval(latents)
|
||||
|
||||
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
|
||||
|
||||
# Diagnostic: per-temporal-position latent statistics
|
||||
if debug_latents:
|
||||
lat_np = np.array(latents) # [C, T, H, W]
|
||||
n_t = lat_np.shape[1]
|
||||
print(
|
||||
f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
|
||||
)
|
||||
for t_pos in range(min(n_t, 8)):
|
||||
frame = lat_np[:, t_pos, :, :]
|
||||
print(
|
||||
f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
|
||||
)
|
||||
if n_t > 8:
|
||||
interior = lat_np[:, 4:, :, :]
|
||||
print(
|
||||
f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Free transformer models and text embeddings
|
||||
if is_dual:
|
||||
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
||||
if cfg_disabled:
|
||||
del context_cond_low, context_cond_high
|
||||
else:
|
||||
del context_cfg_low, context_cfg_high
|
||||
else:
|
||||
del single_model, cross_kv
|
||||
if cfg_disabled:
|
||||
del context_cond
|
||||
else:
|
||||
del context_cfg
|
||||
del model, kv, context
|
||||
if context_null is not None:
|
||||
del context_null
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Load VAE and decode
|
||||
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
|
||||
t4 = time.time()
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
vae = load_vae_decoder(vae_path, config)
|
||||
|
||||
is_wan22_vae = config.vae_z_dim == 48
|
||||
|
||||
# Temporal extend: prepend reflected latent frames to the VAE input so that
|
||||
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
|
||||
# This gives the first real frame a full temporal receptive field of real data.
|
||||
# Select tiling configuration
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig
|
||||
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
elif tiling == "auto":
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
elif tiling == "default":
|
||||
tiling_config = TilingConfig.default()
|
||||
elif tiling == "aggressive":
|
||||
tiling_config = TilingConfig.aggressive()
|
||||
elif tiling == "conservative":
|
||||
tiling_config = TilingConfig.conservative()
|
||||
elif tiling == "spatial":
|
||||
tiling_config = TilingConfig.spatial_only()
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(
|
||||
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
|
||||
)
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = (
|
||||
f"{tiling_config.spatial_config.tile_size_in_pixels}px"
|
||||
if tiling_config.spatial_config
|
||||
else "none"
|
||||
)
|
||||
temporal_info = (
|
||||
f"{tiling_config.temporal_config.tile_size_in_frames}f"
|
||||
if tiling_config.temporal_config
|
||||
else "none"
|
||||
)
|
||||
print(
|
||||
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
|
||||
)
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan_2.vae22 import denormalize_latents
|
||||
|
||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||
z = latents.transpose(1, 2, 3, 0)[None]
|
||||
z = denormalize_latents(z)
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(z, tiling_config)
|
||||
else:
|
||||
video = vae(z)
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [T', H', W', 3]
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(latents[None], tiling_config)
|
||||
else:
|
||||
video = vae.decode(latents[None])
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [3, T', H, W]
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
||||
|
||||
# Trim first N temporal chunks if requested (avoids first-frame artifacts)
|
||||
if trim_first_frames > 0:
|
||||
trim_pixels = trim_first_frames * 4
|
||||
video = video[trim_pixels:]
|
||||
print(
|
||||
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
|
||||
)
|
||||
|
||||
save_video(video, output_path, fps=config.sample_fps)
|
||||
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument(
|
||||
"--model-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to converted MLX model directory",
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-negative-prompt",
|
||||
action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width", type=int, default=1280, help="Video width (default: 1280)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=704,
|
||||
help="Video height (default: 704; 720p models use 704)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of diffusion steps (default: from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide-scale",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Guidance scale: single float or low,high pair",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Noise schedule shift (default: from config)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--output-path", type=str, default="output.mp4", help="Output video path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="unipc",
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-high",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-low",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiling",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=[
|
||||
"auto",
|
||||
"none",
|
||||
"default",
|
||||
"aggressive",
|
||||
"conservative",
|
||||
"spatial",
|
||||
"temporal",
|
||||
],
|
||||
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-compile",
|
||||
action="store_true",
|
||||
help="Disable mx.compile on models (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim-first-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
metavar="N",
|
||||
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-latents",
|
||||
action="store_true",
|
||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse guide scale
|
||||
guide_scale = None
|
||||
if args.guide_scale is not None:
|
||||
parts = [float(x) for x in args.guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
|
||||
neg_prompt = args.negative_prompt
|
||||
if args.no_negative_prompt:
|
||||
neg_prompt = ""
|
||||
|
||||
# Parse LoRA configs: convert [path, strength_str] → (path, float)
|
||||
def _parse_lora_args(lora_list):
|
||||
if not lora_list:
|
||||
return None
|
||||
return [(path, float(strength)) for path, strength in lora_list]
|
||||
|
||||
generate_video(
|
||||
model_dir=args.model_dir,
|
||||
prompt=args.prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
image=args.image,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
num_frames=args.num_frames,
|
||||
steps=args.steps,
|
||||
guide_scale=guide_scale,
|
||||
shift=args.shift,
|
||||
seed=args.seed,
|
||||
output_path=args.output_path,
|
||||
scheduler=args.scheduler,
|
||||
loras=_parse_lora_args(args.lora),
|
||||
loras_high=_parse_lora_args(args.lora_high),
|
||||
loras_low=_parse_lora_args(args.lora_low),
|
||||
tiling=args.tiling,
|
||||
no_compile=args.no_compile,
|
||||
trim_first_frames=args.trim_first_frames,
|
||||
debug_latents=args.debug_latents,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
60
mlx_video/models/wan_2/i2v_utils.py
Normal file
60
mlx_video/models/wan_2/i2v_utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Image-to-Video utility functions for Wan2.2."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
|
||||
"""Load, resize, center-crop, and normalize an image for I2V.
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
width: Target width
|
||||
height: Target height
|
||||
|
||||
Returns:
|
||||
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Resize so that the image covers the target size (LANCZOS)
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize(
|
||||
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||
)
|
||||
|
||||
# Center crop
|
||||
x1 = (img.width - width) // 2
|
||||
y1 = (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
|
||||
# To tensor: [H, W, 3] float32 in [-1, 1]
|
||||
arr = np.array(img, dtype=np.float32) / 255.0
|
||||
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
|
||||
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
|
||||
|
||||
|
||||
def build_i2v_mask(z_shape, patch_size):
|
||||
"""Build temporal mask for I2V: first frame = 0, rest = 1.
|
||||
|
||||
Args:
|
||||
z_shape: Latent shape (C, T, H, W) in channels-first
|
||||
patch_size: (pt, ph, pw) patch size
|
||||
|
||||
Returns:
|
||||
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
|
||||
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
|
||||
"""
|
||||
C, T, H, W = z_shape
|
||||
mask = mx.ones(z_shape)
|
||||
# Zero out the first temporal position
|
||||
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
|
||||
|
||||
# Token-level mask for per-token timesteps: subsample to patch grid
|
||||
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
|
||||
pt, ph, pw = patch_size
|
||||
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
|
||||
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
|
||||
return mask, mask_tokens
|
||||
41
mlx_video/models/wan_2/postprocess.py
Normal file
41
mlx_video/models/wan_2/postprocess.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||
"""Save video frames to MP4.
|
||||
|
||||
Args:
|
||||
frames: Video frames [T, H, W, 3] uint8
|
||||
output_path: Output file path
|
||||
fps: Frames per second
|
||||
"""
|
||||
try:
|
||||
import imageio
|
||||
|
||||
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
except ImportError:
|
||||
try:
|
||||
import cv2
|
||||
|
||||
h, w = frames.shape[1], frames.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
||||
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
||||
for frame in frames:
|
||||
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
writer.release()
|
||||
except (ImportError, Exception):
|
||||
# Last resort: save as individual PNGs
|
||||
from PIL import Image
|
||||
|
||||
out_dir = Path(output_path).parent / Path(output_path).stem
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, frame in enumerate(frames):
|
||||
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
||||
print(
|
||||
f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)"
|
||||
)
|
||||
176
mlx_video/models/wan_2/rope.py
Normal file
176
mlx_video/models/wan_2/rope.py
Normal file
@@ -0,0 +1,176 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
|
||||
"""Precompute RoPE frequency parameters as complex numbers.
|
||||
|
||||
Returns:
|
||||
Complex frequency tensor of shape [max_seq_len, dim // 2].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = (
|
||||
np.arange(max_seq_len, dtype=np.float64)[:, None]
|
||||
* (
|
||||
1.0
|
||||
/ np.power(
|
||||
theta,
|
||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||
)
|
||||
)[None, :]
|
||||
)
|
||||
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
|
||||
cos_freqs = np.cos(freqs).astype(np.float32)
|
||||
sin_freqs = np.sin(freqs).astype(np.float32)
|
||||
return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))
|
||||
|
||||
|
||||
def rope_apply(
|
||||
x: mx.array,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
precomputed_cos_sin: tuple | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply 3-way factorized RoPE to Q or K tensor.
|
||||
|
||||
Args:
|
||||
x: Shape [B, L, num_heads, head_dim]
|
||||
grid_sizes: List of (F, H, W) tuples per batch element
|
||||
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
|
||||
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
|
||||
"""
|
||||
b, s, n, d = x.shape
|
||||
half_d = d // 2
|
||||
|
||||
if precomputed_cos_sin is not None:
|
||||
cos_f, sin_f = precomputed_cos_sin
|
||||
# Check if all batch elements have the same grid (common for CFG B=2)
|
||||
f0, h0, w0 = grid_sizes[0]
|
||||
seq_len = f0 * h0 * w0
|
||||
all_same_grid = (
|
||||
all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
|
||||
)
|
||||
|
||||
if all_same_grid:
|
||||
# Vectorized path: apply RoPE to all batch elements at once
|
||||
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
|
||||
x_real = x_seq[..., 0]
|
||||
x_imag = x_seq[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
|
||||
b, seq_len, n, d
|
||||
)
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
||||
return x_rotated
|
||||
else:
|
||||
# Per-element path for mixed grid sizes
|
||||
outputs = []
|
||||
for i in range(b):
|
||||
f, h, w = grid_sizes[i]
|
||||
sl = f * h * w
|
||||
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
|
||||
x_real = x_i[..., 0]
|
||||
x_imag = x_i[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
|
||||
if sl < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
|
||||
outputs.append(x_rotated)
|
||||
return mx.stack(outputs)
|
||||
|
||||
# Cast freqs to input dtype to prevent float32 promotion cascade
|
||||
if freqs.dtype != x.dtype:
|
||||
freqs = freqs.astype(x.dtype)
|
||||
|
||||
# Split frequency dimensions: temporal gets more capacity
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
|
||||
# Split freqs along dim axis
|
||||
freqs_t = freqs[:, :d_t] # [1024, d_t, 2]
|
||||
freqs_h = freqs[:, d_t : d_t + d_h] # [1024, d_h, 2]
|
||||
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] # [1024, d_w, 2]
|
||||
|
||||
outputs = []
|
||||
for i in range(b):
|
||||
f, h, w = grid_sizes[i]
|
||||
seq_len = f * h * w
|
||||
|
||||
# Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
|
||||
x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)
|
||||
|
||||
# Build per-position frequencies by expanding along grid dims
|
||||
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
|
||||
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
|
||||
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
|
||||
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||
|
||||
# Concatenate: [f*h*w, half_d, 2]
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
|
||||
# Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
|
||||
cos_f = freqs_i[..., 0] # [seq_len, 1, half_d]
|
||||
sin_f = freqs_i[..., 1] # [seq_len, 1, half_d]
|
||||
|
||||
x_real = x_i[..., 0] # [seq_len, n, half_d]
|
||||
x_imag = x_i[..., 1] # [seq_len, n, half_d]
|
||||
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
|
||||
# Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)
|
||||
|
||||
# Handle padding: keep non-rotated tokens after seq_len
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)
|
||||
|
||||
outputs.append(x_rotated)
|
||||
|
||||
return mx.stack(outputs)
|
||||
|
||||
|
||||
def rope_precompute_cos_sin(
|
||||
grid_sizes: list, freqs: mx.array, dtype: type = mx.float32
|
||||
) -> tuple:
|
||||
"""Precompute cos/sin frequency tensors for constant grid sizes.
|
||||
|
||||
Call once before the diffusion loop. Pass result as precomputed_cos_sin
|
||||
to rope_apply to skip per-step broadcast/concat.
|
||||
|
||||
Args:
|
||||
grid_sizes: List of (F, H, W) tuples (must be same for all batch elements)
|
||||
freqs: Precomputed frequencies [1024, d//2, 2]
|
||||
dtype: Target dtype for the output tensors
|
||||
|
||||
Returns:
|
||||
(cos_f, sin_f) each [seq_len, 1, half_d]
|
||||
"""
|
||||
if freqs.dtype != dtype:
|
||||
freqs = freqs.astype(dtype)
|
||||
|
||||
f, h, w = grid_sizes[0]
|
||||
seq_len = f * h * w
|
||||
half_d = freqs.shape[1]
|
||||
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
|
||||
freqs_t = freqs[:, :d_t]
|
||||
freqs_h = freqs[:, d_t : d_t + d_h]
|
||||
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]
|
||||
|
||||
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
return freqs_i[..., 0], freqs_i[..., 1]
|
||||
447
mlx_video/models/wan_2/scheduler.py
Normal file
447
mlx_video/models/wan_2/scheduler.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Flow matching schedulers for Wan2.2 inference.
|
||||
|
||||
Provides Euler, DPM++2M, and UniPC solvers for flow matching diffusion.
|
||||
Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps
|
||||
for the same quality as Euler.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _compute_sigmas(
|
||||
num_steps: int, shift: float = 1.0, num_train_timesteps: int = 1000
|
||||
) -> np.ndarray:
|
||||
"""Compute shifted sigma schedule matching official Wan2.2 scheduler.
|
||||
|
||||
The reference creates FlowUniPCMultistepScheduler with shift=1 (identity)
|
||||
in the constructor, deriving sigma_max/sigma_min from the unshifted
|
||||
training schedule. Then set_timesteps() builds a linspace between those
|
||||
unshifted bounds and applies the actual shift once.
|
||||
|
||||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||||
"""
|
||||
# sigma bounds from unshifted training schedule (constructor uses shift=1)
|
||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
|
||||
# Interpolate, then apply shift once (matching set_timesteps)
|
||||
sigmas = np.linspace(sigma_max, sigma_min, num_steps + 1)[:-1]
|
||||
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||||
|
||||
return np.append(sigmas, 0.0).astype(np.float32)
|
||||
|
||||
|
||||
class FlowMatchEulerScheduler:
|
||||
"""1st-order Euler scheduler for flow matching diffusion."""
|
||||
|
||||
def __init__(self, num_train_timesteps: int = 1000):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
# Integer timesteps to match reference (model trained with int timesteps)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
# Store as Python floats to avoid .item() sync in step()
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||
dt = (
|
||||
self._sigmas_float[self._step_index + 1]
|
||||
- self._sigmas_float[self._step_index]
|
||||
)
|
||||
x_next = sample + dt * model_output
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
|
||||
|
||||
class FlowDPMPP2MScheduler:
|
||||
"""DPM-Solver++(2M) for flow matching diffusion.
|
||||
|
||||
2nd-order multistep solver that reuses the previous step's model output
|
||||
for a correction term. Falls back to 1st order on the first and
|
||||
(optionally) last step. Reference: Wan2.2 fm_solvers.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
lower_order_final: bool = True,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.lower_order_final = lower_order_final
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
# Store sigmas as Python floats for scalar math
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
self._num_steps = num_steps
|
||||
self._prev_x0 = None # previous x0 prediction for 2nd-order correction
|
||||
|
||||
@staticmethod
|
||||
def _lambda(sigma: float) -> float:
|
||||
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
|
||||
|
||||
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
|
||||
matching torch.log behavior in the official code.
|
||||
"""
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log((1.0 - sigma) / sigma)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""DPM++(2M) step for flow matching.
|
||||
|
||||
Converts velocity prediction to x0, then applies 1st or 2nd order
|
||||
update depending on available history.
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_cur = s[i]
|
||||
sigma_next = s[i + 1]
|
||||
|
||||
# Convert velocity -> x0 prediction: x0 = sample - sigma * v
|
||||
x0 = sample - sigma_cur * model_output
|
||||
|
||||
# Decide order: 1st for first step, last step (if lower_order_final
|
||||
# and few steps), otherwise 2nd
|
||||
use_first_order = self._prev_x0 is None or (
|
||||
self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
|
||||
)
|
||||
|
||||
if use_first_order or sigma_next == 0.0:
|
||||
# 1st order DPM++ (equivalent to DDIM):
|
||||
# x_next = (σ_next/σ_cur)*x - (α_next*(exp(-h)-1))*x0
|
||||
if sigma_next == 0.0:
|
||||
x_next = x0
|
||||
else:
|
||||
lambda_cur = self._lambda(sigma_cur)
|
||||
lambda_next = self._lambda(sigma_next)
|
||||
h = lambda_next - lambda_cur
|
||||
alpha_next = 1.0 - sigma_next
|
||||
coeff_x = sigma_next / sigma_cur
|
||||
coeff_x0 = alpha_next * math.expm1(-h)
|
||||
x_next = coeff_x * sample - coeff_x0 * x0
|
||||
else:
|
||||
# 2nd order DPM++(2M) with midpoint correction
|
||||
sigma_prev = s[i - 1]
|
||||
lambda_prev = self._lambda(sigma_prev)
|
||||
lambda_cur = self._lambda(sigma_cur)
|
||||
lambda_next = self._lambda(sigma_next)
|
||||
|
||||
h = lambda_next - lambda_cur
|
||||
h_0 = lambda_cur - lambda_prev
|
||||
r0 = h_0 / h
|
||||
|
||||
# D0 = current x0, D1 = correction from previous x0
|
||||
D0 = x0
|
||||
D1 = (1.0 / r0) * (x0 - self._prev_x0)
|
||||
|
||||
alpha_next = 1.0 - sigma_next
|
||||
exp_neg_h_m1 = math.expm1(-h) # exp(-h) - 1
|
||||
|
||||
x_next = (
|
||||
(sigma_next / sigma_cur) * sample
|
||||
- (alpha_next * exp_neg_h_m1) * D0
|
||||
- 0.5 * (alpha_next * exp_neg_h_m1) * D1
|
||||
)
|
||||
|
||||
self._prev_x0 = x0
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
self._prev_x0 = None
|
||||
|
||||
|
||||
class FlowUniPCScheduler:
|
||||
"""UniPC (Unified Predictor-Corrector) for flow matching diffusion.
|
||||
|
||||
Multi-step predictor-corrector solver with configurable order.
|
||||
The corrector refines each step using the model output that was already
|
||||
computed, costing no extra model evaluations. Official Wan2.2 default.
|
||||
Reference: Wan2.2 fm_solvers_unipc.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
solver_order: int = 2,
|
||||
lower_order_final: bool = True,
|
||||
disable_corrector: list | None = None,
|
||||
use_corrector: bool = True,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.solver_order = solver_order
|
||||
self.lower_order_final = lower_order_final
|
||||
self._use_corrector = use_corrector
|
||||
self.disable_corrector = set(disable_corrector or [])
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
self._num_steps = num_steps
|
||||
self._lower_order_nums = 0
|
||||
# Model output (x0) history for multi-step, stored newest-last
|
||||
self._model_outputs = [None] * self.solver_order
|
||||
self._last_sample = None # sample before prediction (for corrector)
|
||||
self._this_order = 1
|
||||
|
||||
@staticmethod
|
||||
def _lambda(sigma: float) -> float:
|
||||
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
|
||||
|
||||
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
|
||||
matching torch.log behavior in the official code.
|
||||
"""
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log((1.0 - sigma) / sigma)
|
||||
|
||||
def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array:
|
||||
"""Convert velocity prediction to x0: x0 = sample - sigma * v."""
|
||||
sigma = self._sigmas_float[self._step_index]
|
||||
return sample - sigma * velocity
|
||||
|
||||
def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array:
|
||||
"""UniP predictor with B(h)=expm1(-h) basis (bh2 variant).
|
||||
|
||||
Matches official multistep_uni_p_bh_update: computes rhos_p via
|
||||
linalg.solve for order >= 3; order <= 2 uses analytic rhos_p=[0.5].
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_s0 = s[i]
|
||||
sigma_t = s[i + 1]
|
||||
|
||||
if sigma_t == 0.0:
|
||||
return x0
|
||||
|
||||
lambda_s0 = self._lambda(sigma_s0)
|
||||
lambda_t = self._lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h # negated for predict_x0
|
||||
|
||||
alpha_t = 1.0 - sigma_t
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
|
||||
m0 = self._model_outputs[-1]
|
||||
# Base prediction
|
||||
x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0
|
||||
|
||||
if order >= 2 and m0 is not None:
|
||||
rks = []
|
||||
D1s = []
|
||||
for k in range(1, order):
|
||||
si_idx = i - k
|
||||
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
|
||||
break
|
||||
mk = self._model_outputs[-(k + 1)]
|
||||
sigma_sk = s[si_idx]
|
||||
lambda_sk = self._lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
if math.isinf(rk):
|
||||
break
|
||||
rks.append(rk)
|
||||
D1s.append((mk - m0) / rk)
|
||||
|
||||
if D1s:
|
||||
effective_order = len(D1s) + 1
|
||||
if effective_order <= 2:
|
||||
# Analytic solution for order 2
|
||||
rhos_p = [0.5]
|
||||
else:
|
||||
rks_arr = np.array(rks, dtype=np.float64)
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows = []
|
||||
b_vals = []
|
||||
for j in range(1, effective_order):
|
||||
R_rows.append(rks_arr ** (j - 1))
|
||||
b_vals.append(float(h_phi_k * factorial_i / B_h))
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_p = np.linalg.solve(R, b).tolist()
|
||||
|
||||
pred_res = sum(r * d for r, d in zip(rhos_p, D1s))
|
||||
x_t = x_t - (alpha_t * B_h) * pred_res
|
||||
|
||||
return x_t
|
||||
|
||||
def _uni_c_bh2(
|
||||
self,
|
||||
model_x0: mx.array,
|
||||
last_sample: mx.array,
|
||||
this_sample: mx.array,
|
||||
order: int,
|
||||
) -> mx.array:
|
||||
"""UniC corrector with B(h)=expm1(-h) basis (bh2 variant).
|
||||
|
||||
Matches official multistep_uni_c_bh_update: computes rhos_c via
|
||||
linalg.solve for order >= 2 (not hardcoded 0.5).
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_s0 = s[i - 1]
|
||||
sigma_t = s[i]
|
||||
|
||||
if sigma_t == 0.0:
|
||||
return this_sample
|
||||
|
||||
lambda_s0 = self._lambda(sigma_s0)
|
||||
lambda_t = self._lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h # negated for predict_x0
|
||||
|
||||
alpha_t = 1.0 - sigma_t
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
|
||||
m0 = self._model_outputs[-1]
|
||||
# Re-derive base from last_sample
|
||||
x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0
|
||||
|
||||
D1_t = model_x0 - m0
|
||||
|
||||
# Gather rks and D1s from history
|
||||
rks = []
|
||||
D1s = []
|
||||
for k in range(1, order):
|
||||
si_idx = i - (k + 1)
|
||||
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
|
||||
break
|
||||
mk = self._model_outputs[-(k + 1)]
|
||||
sigma_sk = s[si_idx]
|
||||
lambda_sk = self._lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
if math.isinf(rk):
|
||||
break # History references sigma=1.0 boundary; reduce order
|
||||
rks.append(rk)
|
||||
D1s.append((mk - m0) / rk)
|
||||
rks.append(1.0)
|
||||
effective_order = len(rks) # = len(D1s) + 1
|
||||
|
||||
# Compute rhos_c coefficients
|
||||
if effective_order == 1:
|
||||
rhos_c = [0.5]
|
||||
else:
|
||||
rks_arr = np.array(rks, dtype=np.float64)
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows = []
|
||||
b_vals = []
|
||||
for j in range(1, effective_order + 1):
|
||||
R_rows.append(rks_arr ** (j - 1))
|
||||
b_vals.append(float(h_phi_k * factorial_i / B_h))
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_c = np.linalg.solve(R, b).tolist()
|
||||
|
||||
# Apply correction
|
||||
corr_res = mx.zeros_like(D1_t)
|
||||
for k_idx, d1 in enumerate(D1s):
|
||||
corr_res = corr_res + rhos_c[k_idx] * d1
|
||||
x_t = x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t)
|
||||
return x_t
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""UniPC step: correct current, then predict next."""
|
||||
i = self._step_index
|
||||
|
||||
# Convert velocity -> x0
|
||||
x0 = self._convert_output(model_output, sample)
|
||||
|
||||
# 1. Corrector: refine current sample if we have history
|
||||
use_corrector = (
|
||||
self._use_corrector
|
||||
and i > 0
|
||||
and (i - 1) not in self.disable_corrector
|
||||
and self._last_sample is not None
|
||||
)
|
||||
if use_corrector:
|
||||
sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order)
|
||||
|
||||
# 2. Shift model output history
|
||||
for k in range(self.solver_order - 1):
|
||||
self._model_outputs[k] = self._model_outputs[k + 1]
|
||||
self._model_outputs[-1] = x0
|
||||
|
||||
# 3. Determine prediction order
|
||||
if self.lower_order_final:
|
||||
this_order = min(self.solver_order, self._num_steps - i)
|
||||
else:
|
||||
this_order = self.solver_order
|
||||
self._this_order = min(this_order, self._lower_order_nums + 1)
|
||||
|
||||
# 4. Predict next sample
|
||||
self._last_sample = sample
|
||||
x_next = self._uni_p_bh2(x0, sample, self._this_order)
|
||||
|
||||
if self._lower_order_nums < self.solver_order:
|
||||
self._lower_order_nums += 1
|
||||
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
self._lower_order_nums = 0
|
||||
self._model_outputs = [None] * self.solver_order
|
||||
self._last_sample = None
|
||||
self._this_order = 1
|
||||
239
mlx_video/models/wan_2/text_encoder.py
Normal file
239
mlx_video/models/wan_2/text_encoder.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
"""RMS-based layer normalization (T5 style)."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
"""T5-style relative position bias with bucketing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_buckets: int,
|
||||
num_heads: int,
|
||||
bidirectional: bool = True,
|
||||
max_dist: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
|
||||
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
rel_pos_f = rel_pos.astype(mx.float32)
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
rel_pos_large = mx.minimum(
|
||||
rel_pos_large,
|
||||
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
||||
)
|
||||
|
||||
rel_buckets = rel_buckets + mx.where(
|
||||
is_small, rel_pos.astype(mx.int32), rel_pos_large
|
||||
)
|
||||
return rel_buckets
|
||||
|
||||
def __call__(self, lq: int, lk: int) -> mx.array:
|
||||
positions_k = mx.arange(lk)[None, :] # [1, lk]
|
||||
positions_q = mx.arange(lq)[:, None] # [lq, 1]
|
||||
rel_pos = positions_k - positions_q # [lq, lk]
|
||||
|
||||
buckets = self._relative_position_bucket(rel_pos)
|
||||
embeds = self.embedding(buckets) # [lq, lk, num_heads]
|
||||
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
|
||||
return embeds
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
"""T5-style multi-head attention (no scaling)."""
|
||||
|
||||
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
assert dim_attn % num_heads == 0
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array | None = None,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
context = x if context is None else context
|
||||
b, n, c = x.shape[0], self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
|
||||
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
|
||||
v = self.v(context).reshape(b, -1, n, c)
|
||||
|
||||
# T5 uses no scaling — compute attention manually with float32 softmax
|
||||
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
# Using SDPA with bfloat16 inputs causes precision loss in softmax
|
||||
# since unscaled logits can be very large (no 1/sqrt(d) division).
|
||||
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# QK^T (no scaling) — compute in float32 for precision
|
||||
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
|
||||
|
||||
# Add position bias
|
||||
if pos_bias is not None:
|
||||
attn = attn + pos_bias.astype(mx.float32)
|
||||
|
||||
# Apply attention mask (use dtype min like official, not -1e9)
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
|
||||
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
|
||||
attn = attn + additive_mask
|
||||
|
||||
# Softmax in float32 (matches official), then cast back
|
||||
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
|
||||
|
||||
# Attention @ V
|
||||
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
"""Gated feed-forward: gate(x) * fc1(x) -> fc2."""
|
||||
|
||||
def __init__(self, dim: int, dim_ffn: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = nn.GELU(approx="tanh")
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
|
||||
|
||||
|
||||
class T5SelfAttentionBlock(nn.Module):
|
||||
"""T5 encoder block: self-attention + FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_attn: int,
|
||||
dim_ffn: int,
|
||||
num_heads: int,
|
||||
num_buckets: int,
|
||||
shared_pos: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_pos = shared_pos
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn)
|
||||
self.pos_embedding = (
|
||||
None
|
||||
if shared_pos
|
||||
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
|
||||
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
|
||||
x = x + self.ffn(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
"""T5 Encoder (UMT5-XXL configuration)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 256384,
|
||||
dim: int = 4096,
|
||||
dim_attn: int = 4096,
|
||||
dim_ffn: int = 10240,
|
||||
num_heads: int = 64,
|
||||
num_layers: int = 24,
|
||||
num_buckets: int = 32,
|
||||
shared_pos: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.pos_embedding = (
|
||||
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
if shared_pos
|
||||
else None
|
||||
)
|
||||
self.blocks = [
|
||||
T5SelfAttentionBlock(
|
||||
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
ids: Token IDs [B, L]
|
||||
mask: Attention mask [B, L]
|
||||
|
||||
Returns:
|
||||
Hidden states [B, L, dim]
|
||||
"""
|
||||
x = self.token_embedding(ids)
|
||||
|
||||
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask=mask, pos_bias=e)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
338
mlx_video/models/wan_2/tiling.py
Normal file
338
mlx_video/models/wan_2/tiling.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Wan-specific tiled VAE decoding.
|
||||
|
||||
Re-exports all tiling utilities from the LTX VAE tiling module and provides
|
||||
a Wan-specific ``decode_with_tiling`` that adds ``causal_temporal`` support
|
||||
for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale
|
||||
output frames rather than LTX's 1+(T-1)*scale mapping).
|
||||
|
||||
# TODO: This function can be refactored to consolidate with
|
||||
# mlx_video.models.ltx_2.video_vae.tiling.decode_with_tiling once the
|
||||
# causal_temporal generalisation is accepted upstream.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
TilingConfig,
|
||||
map_spatial_slice,
|
||||
map_temporal_slice,
|
||||
split_in_spatial,
|
||||
split_in_temporal,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SpatialTilingConfig",
|
||||
"TemporalTilingConfig",
|
||||
"TilingConfig",
|
||||
"decode_with_tiling",
|
||||
"map_spatial_slice",
|
||||
"map_temporal_slice",
|
||||
"split_in_spatial",
|
||||
"split_in_temporal",
|
||||
]
|
||||
|
||||
|
||||
def decode_with_tiling(
|
||||
decoder_fn,
|
||||
latents: mx.array,
|
||||
tiling_config: TilingConfig,
|
||||
spatial_scale: int = 32,
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
causal_temporal: bool = True,
|
||||
timestep: Optional[mx.array] = None,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
Args:
|
||||
decoder_fn: Decoder function to call for each tile.
|
||||
latents: Input latents of shape (B, C, F, H, W).
|
||||
tiling_config: Tiling configuration.
|
||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
causal_temporal: Whether the decoder uses causal temporal mapping where
|
||||
T input frames produce 1+(T-1)*scale output frames. When False, uses
|
||||
simple scaling where T frames produce T*scale output frames.
|
||||
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
|
||||
timestep: Optional timestep for conditioning.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
start_idx: Starting frame index in the full video.
|
||||
|
||||
Returns:
|
||||
Decoded video.
|
||||
"""
|
||||
import gc
|
||||
|
||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||
|
||||
# Compute output shape
|
||||
out_f = (
|
||||
(1 + (f_latent - 1) * temporal_scale)
|
||||
if causal_temporal
|
||||
else (f_latent * temporal_scale)
|
||||
)
|
||||
out_h = h_latent * spatial_scale
|
||||
out_w = w_latent * spatial_scale
|
||||
|
||||
# Get tile size and overlap in latent space
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_cfg = tiling_config.spatial_config
|
||||
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
|
||||
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
|
||||
else:
|
||||
spatial_tile_size = max(h_latent, w_latent)
|
||||
spatial_overlap = 0
|
||||
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_cfg = tiling_config.temporal_config
|
||||
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
|
||||
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
|
||||
else:
|
||||
temporal_tile_size = f_latent
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
if causal_temporal:
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
else:
|
||||
temporal_intervals = split_in_spatial(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
num_t_tiles = len(temporal_intervals.starts)
|
||||
num_h_tiles = len(height_intervals.starts)
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles # noqa: F841
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
|
||||
mx.eval(output, weights)
|
||||
|
||||
tile_idx = 0
|
||||
for t_idx in range(num_t_tiles):
|
||||
t_start = temporal_intervals.starts[t_idx]
|
||||
t_end = temporal_intervals.ends[t_idx]
|
||||
t_left = temporal_intervals.left_ramps[t_idx]
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
if causal_temporal:
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
else:
|
||||
out_t_slice, t_mask = map_spatial_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
h_end = height_intervals.ends[h_idx]
|
||||
h_left = height_intervals.left_ramps[h_idx]
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
w_end = width_intervals.ends[w_idx]
|
||||
w_left = width_intervals.left_ramps[w_idx]
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
del tile_latents
|
||||
|
||||
# Get actual decoded dimensions
|
||||
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
|
||||
expected_t = out_t_slice.stop - out_t_slice.start
|
||||
expected_h = out_h_slice.stop - out_h_slice.start
|
||||
expected_w = out_w_slice.stop - out_w_slice.start
|
||||
|
||||
# Handle potential size mismatches (use minimum)
|
||||
actual_t = min(decoded_t, expected_t)
|
||||
actual_h = min(decoded_h, expected_h)
|
||||
actual_w = min(decoded_w, expected_w)
|
||||
|
||||
# Build blend mask
|
||||
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
|
||||
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
|
||||
# Compute output coordinates
|
||||
t_out_start = out_t_slice.start
|
||||
t_out_end = t_out_start + actual_t
|
||||
h_out_start = out_h_slice.start
|
||||
h_out_end = h_out_start + actual_h
|
||||
w_out_start = out_w_slice.start
|
||||
w_out_end = w_out_start + actual_w
|
||||
|
||||
# Weighted accumulation
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
mx.eval(output, weights)
|
||||
|
||||
# Clean up tile-specific arrays
|
||||
del tile_output_slice, weighted_tile, blend_mask
|
||||
del t_mask_slice, h_mask_slice, w_mask_slice
|
||||
|
||||
tile_idx += 1
|
||||
|
||||
# Periodic garbage collection and cache clearing
|
||||
if tile_idx % 4 == 0:
|
||||
gc.collect()
|
||||
try:
|
||||
mx.clear_cache()
|
||||
except Exception:
|
||||
pass # May not be available on all platforms
|
||||
|
||||
# After completing all spatial tiles for this temporal tile,
|
||||
# check if any frames are now finalized (no future tiles will contribute)
|
||||
if on_frames_ready is not None and num_t_tiles > 1:
|
||||
# Determine the finalized frame boundary
|
||||
# Frames before the start of the next tile's output region are finalized
|
||||
if t_idx < num_t_tiles - 1:
|
||||
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||
# Map to output frame index (first frame of next tile's contribution)
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
elif causal_temporal:
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
else:
|
||||
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
if next_tile_start_out > emitted:
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
del finalized_output, finalized_weights
|
||||
gc.collect()
|
||||
|
||||
# Normalize by weights
|
||||
weights = mx.maximum(weights, 1e-8)
|
||||
output = output / weights
|
||||
mx.eval(output)
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
104
mlx_video/models/wan_2/transformer.py
Normal file
104
mlx_video/models/wan_2/transformer.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Self-attention
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||
|
||||
# Cross-attention (with optional norm on context)
|
||||
self.norm3 = (
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
|
||||
)
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
||||
|
||||
# Feed-forward
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
e: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
# Modulation: compute in float32 for precision, matching the reference
|
||||
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
|
||||
# By keeping modulation in float32, type promotion ensures the residual
|
||||
# stream stays float32 throughout all 30 layers (gate * output + x → float32).
|
||||
mod = self.modulation + e # float32
|
||||
e0, e1, e2, e3, e4, e5 = (
|
||||
mod[:, :, 0, :], # shift for self-attn
|
||||
mod[:, :, 1, :], # scale for self-attn
|
||||
mod[:, :, 2, :], # gate for self-attn
|
||||
mod[:, :, 3, :], # shift for ffn
|
||||
mod[:, :, 4, :], # scale for ffn
|
||||
mod[:, :, 5, :], # gate for ffn
|
||||
)
|
||||
|
||||
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(
|
||||
x_mod,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y * e5
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanFFN(nn.Module):
|
||||
"""Gated feed-forward network with GELU(tanh) activation."""
|
||||
|
||||
def __init__(self, dim: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, ffn_dim)
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(_linear_dtype(self.fc1))
|
||||
return self.fc2(self.act(self.fc1(x_w)))
|
||||
191
mlx_video/models/wan_2/utils.py
Normal file
191
mlx_video/models/wan_2/utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Wan model loading utilities."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def load_wan_model(
|
||||
model_path: Path,
|
||||
config,
|
||||
quantization: dict | None = None,
|
||||
loras: list | None = None,
|
||||
):
|
||||
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||
|
||||
Args:
|
||||
model_path: Path to model safetensors file
|
||||
config: WanModelConfig
|
||||
quantization: Optional dict with 'bits' and 'group_size' keys.
|
||||
If provided, creates QuantizedLinear stubs before loading.
|
||||
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||
"""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
|
||||
if quantization:
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=quantization["group_size"],
|
||||
bits=quantization["bits"],
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
|
||||
if loras:
|
||||
if quantization:
|
||||
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
|
||||
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
|
||||
from mlx_video.models.wan_2.convert import _load_lora_configs
|
||||
from mlx_video.lora import apply_loras_to_model
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
module_to_loras = _load_lora_configs(loras)
|
||||
apply_loras_to_model(model, module_to_loras)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
else:
|
||||
# Weight merging: fold LoRA into bf16 weights before loading
|
||||
from mlx_video.models.wan_2.convert import load_and_apply_loras
|
||||
|
||||
weights = load_and_apply_loras(dict(weights), loras)
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
|
||||
def load_t5_encoder(model_path: Path, config):
|
||||
"""Load T5 text encoder.
|
||||
|
||||
Weights are upcast to float32 for maximum precision — the T5 encoder
|
||||
only runs once per generation, so performance impact is negligible.
|
||||
This matches the official which computes softmax in float32 explicitly.
|
||||
"""
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=config.t5_vocab_size,
|
||||
dim=config.t5_dim,
|
||||
dim_attn=config.t5_dim_attn,
|
||||
dim_ffn=config.t5_dim_ffn,
|
||||
num_heads=config.t5_num_heads,
|
||||
num_layers=config.t5_num_layers,
|
||||
num_buckets=config.t5_num_buckets,
|
||||
shared_pos=False,
|
||||
)
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
encoder.load_weights(list(weights.items()))
|
||||
mx.eval(encoder.parameters())
|
||||
return encoder
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: Path, config=None):
|
||||
"""Load VAE decoder (skips encoder weights with strict=False).
|
||||
|
||||
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
|
||||
For Wan2.1 (vae_z_dim=16), uses WanVAE.
|
||||
"""
|
||||
is_wan22 = config is not None and config.vae_z_dim == 48
|
||||
|
||||
if is_wan22:
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
|
||||
|
||||
vae = Wan22VAEDecoder(z_dim=48)
|
||||
else:
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def load_vae_encoder(model_path: Path, config=None):
|
||||
"""Load VAE encoder for I2V image encoding.
|
||||
|
||||
For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder.
|
||||
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
|
||||
"""
|
||||
if config is not None and config.vae_z_dim == 16:
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
else:
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
|
||||
|
||||
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
"""Clean text matching official Wan2.2 tokenizer preprocessing.
|
||||
|
||||
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
|
||||
double HTML unescape, and whitespace normalization. Critical for
|
||||
correct tokenization of the Chinese negative prompt.
|
||||
"""
|
||||
import html
|
||||
import re
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
|
||||
text = ftfy.fix_text(text)
|
||||
except ImportError:
|
||||
pass
|
||||
text = html.unescape(html.unescape(text))
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def encode_text(
|
||||
encoder,
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
text_len: int = 512,
|
||||
) -> mx.array:
|
||||
"""Encode text prompt using T5 encoder.
|
||||
|
||||
Args:
|
||||
encoder: T5Encoder model
|
||||
tokenizer: HuggingFace tokenizer
|
||||
prompt: Text prompt
|
||||
text_len: Maximum text length
|
||||
|
||||
Returns:
|
||||
Text embeddings [L, dim]
|
||||
"""
|
||||
prompt = _clean_text(prompt)
|
||||
tokens = tokenizer(
|
||||
prompt,
|
||||
max_length=text_len,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
ids = mx.array(tokens["input_ids"])
|
||||
mask = mx.array(tokens["attention_mask"])
|
||||
|
||||
embeddings = encoder(ids, mask=mask)
|
||||
|
||||
# Return only non-padding tokens
|
||||
seq_len = int(mask.sum().item())
|
||||
return embeddings[0, :seq_len]
|
||||
629
mlx_video/models/wan_2/vae.py
Normal file
629
mlx_video/models/wan_2/vae.py
Normal file
@@ -0,0 +1,629 @@
|
||||
"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8).
|
||||
|
||||
Module structure mirrors original PyTorch checkpoint key hierarchy
|
||||
so weights load directly without key sanitization.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization statistics for z_dim=16
|
||||
VAE_MEAN = [
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
0.1075,
|
||||
-0.1745,
|
||||
0.9653,
|
||||
-0.1517,
|
||||
1.5508,
|
||||
0.4134,
|
||||
-0.0715,
|
||||
0.5517,
|
||||
-0.3632,
|
||||
-0.1922,
|
||||
-0.9497,
|
||||
0.2503,
|
||||
-0.2921,
|
||||
]
|
||||
VAE_STD = [
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
2.6558,
|
||||
1.2196,
|
||||
1.7708,
|
||||
2.6052,
|
||||
2.0743,
|
||||
3.2687,
|
||||
2.1526,
|
||||
2.8652,
|
||||
1.5579,
|
||||
1.6382,
|
||||
1.1253,
|
||||
2.8251,
|
||||
1.9160,
|
||||
]
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""3D convolution with causal temporal padding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple,
|
||||
stride: int | tuple = 1,
|
||||
padding: int | tuple = 0,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
# Causal padding: match reference formula dilation*(k-1) + (1-stride)
|
||||
# With dilation=1: k-stride (pads left only, no future context)
|
||||
self._causal_pad_t = kernel_size[0] - stride[0]
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# MLX Conv3d: weight shape [O, D, H, W, I]
|
||||
self.weight = mx.zeros(
|
||||
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||
)
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
||||
"""x: [B, C, T, H, W] (channel-first)"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
causal_pad = self._causal_pad_t
|
||||
if cache_x is not None and causal_pad > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=2)
|
||||
causal_pad = max(0, causal_pad - cache_x.shape[2])
|
||||
|
||||
if causal_pad > 0:
|
||||
pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype)
|
||||
x = mx.concatenate([pad_t, x], axis=2)
|
||||
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(
|
||||
x,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w),
|
||||
],
|
||||
)
|
||||
|
||||
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
|
||||
out = self._conv3d(x)
|
||||
return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W']
|
||||
|
||||
def _conv3d(self, x: mx.array) -> mx.array:
|
||||
"""3D conv via sliding window + 2D conv per time step.
|
||||
x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
|
||||
"""
|
||||
b, t, h, w, c_in = x.shape
|
||||
kt, kh, kw = self.kernel_size
|
||||
st, sh, sw = self.stride
|
||||
t_out = (t - kt) // st + 1
|
||||
|
||||
# Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
|
||||
w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
|
||||
self.weight.shape[0], kh, kw, kt * c_in
|
||||
)
|
||||
outputs = []
|
||||
for t_i in range(t_out):
|
||||
t_start = t_i * st
|
||||
window = x[:, t_start : t_start + kt]
|
||||
window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
|
||||
out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
|
||||
outputs.append(out_2d)
|
||||
return mx.stack(outputs, axis=1)
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
"""Channel-first L2 normalization matching original Wan VAE.
|
||||
|
||||
Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
|
||||
images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
|
||||
images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
|
||||
super().__init__()
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
if channel_first:
|
||||
broadcastable = (1, 1) if images else (1, 1, 1)
|
||||
self.gamma = mx.ones((dim, *broadcastable))
|
||||
else:
|
||||
self.gamma = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
norm_dim = 1 if self.channel_first else -1
|
||||
# L2 normalize along channel dim (matches F.normalize)
|
||||
norm = mx.sqrt(
|
||||
mx.clip(
|
||||
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
|
||||
)
|
||||
)
|
||||
return (x / norm) * self.scale * self.gamma
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
"""Residual block with causal 3D convolutions.
|
||||
|
||||
Uses `residual` list with None gaps to match original PyTorch
|
||||
nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
|
||||
[4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.residual = [
|
||||
RMS_norm(in_dim, images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
|
||||
RMS_norm(out_dim, images=False), # [3]
|
||||
None, # [4] SiLU
|
||||
None, # [5] Dropout
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
|
||||
]
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
h = x if self.shortcut is None else self.shortcut(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
# First conv: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.residual[0](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.residual[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
# Second conv: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.residual[3](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.residual[6](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = nn.silu(self.residual[0](x))
|
||||
x = self.residual[2](x)
|
||||
x = nn.silu(self.residual[3](x))
|
||||
x = self.residual[6](x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Single-head spatial self-attention."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.norm = RMS_norm(dim, images=True)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
identity = x
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
# [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
|
||||
x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(0, 2, 3, 1) # [BT, H, W, C]
|
||||
|
||||
qkv = self.to_qkv(x) # [BT, H, W, 3C]
|
||||
qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q[:, None, :, :] # [BT, 1, HW, C]
|
||||
k = k[:, None, :, :]
|
||||
v = v[:, None, :, :]
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
|
||||
out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
|
||||
out = self.proj(out) # [BT, H, W, C]
|
||||
out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W]
|
||||
return out + identity
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Resample block matching original Wan VAE structure.
|
||||
|
||||
Supports both upsampling (decoder) and downsampling (encoder).
|
||||
Uses list-based param storage to match original nn.Sequential key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str):
|
||||
super().__init__()
|
||||
assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d")
|
||||
self.mode = mode
|
||||
self.dim = dim
|
||||
|
||||
if mode.startswith("upsample"):
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||
if mode == "upsample3d":
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
|
||||
)
|
||||
else:
|
||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
|
||||
if mode == "downsample3d":
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
if self.mode == "upsample3d":
|
||||
# Temporal upsample via learned conv
|
||||
x_t = self.time_conv(x) # [B, 2C, T, H, W]
|
||||
x_t = x_t.reshape(b, 2, c, t, h, w)
|
||||
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
|
||||
t = t * 2
|
||||
|
||||
if self.mode.startswith("upsample"):
|
||||
# Per-frame spatial upsample: nearest 2x + Conv2d
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.repeat(x, 2, axis=1)
|
||||
x = mx.repeat(x, 2, axis=2)
|
||||
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
|
||||
c_out = x.shape[-1]
|
||||
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||
else:
|
||||
# Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2)
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1)
|
||||
x = self.resample[1](x) # Conv2d stride=2
|
||||
c_out = x.shape[-1]
|
||||
h_out, w_out = x.shape[1], x.shape[2]
|
||||
x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
# First chunk: save x, skip time_conv
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
# Subsequent chunks: use cached frame as temporal context
|
||||
cache_x = x[:, :, -1:]
|
||||
x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.time_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
"""3D VAE Decoder matching Wan2.1 architecture.
|
||||
|
||||
Uses flat `middle` and `upsamples` lists to match original
|
||||
PyTorch nn.Sequential weight key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: list = None,
|
||||
num_res_blocks: int = 2,
|
||||
temporal_upsample: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dim_mult is None:
|
||||
dim_mult = [1, 2, 4, 4]
|
||||
if temporal_upsample is None:
|
||||
temporal_upsample = [True, True, False]
|
||||
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# Middle: [ResBlock, AttentionBlock, ResBlock]
|
||||
self.middle = [
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
]
|
||||
|
||||
# Flat upsample list matching original nn.Sequential indexing
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
if i in (1, 2, 3):
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim))
|
||||
in_dim = out_dim
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
self.upsamples = upsamples
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.middle:
|
||||
x = layer(x)
|
||||
|
||||
for layer in self.upsamples:
|
||||
x = layer(x)
|
||||
|
||||
x = nn.silu(self.head[0](x))
|
||||
x = self.head[2](x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
"""3D VAE Encoder matching Wan2.1 architecture.
|
||||
|
||||
Mirror of Decoder3d with downsampling instead of upsampling.
|
||||
Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: list = None,
|
||||
num_res_blocks: int = 2,
|
||||
temporal_downsample: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dim_mult is None:
|
||||
dim_mult = [1, 2, 4, 4]
|
||||
if temporal_downsample is None:
|
||||
temporal_downsample = [False, True, True]
|
||||
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# Flat downsample list matching original nn.Sequential indexing
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim))
|
||||
in_dim = out_dim
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
self.downsamples = downsamples
|
||||
|
||||
# Middle: [ResBlock, AttentionBlock, ResBlock]
|
||||
self.middle = [
|
||||
ResidualBlock(dims[-1], dims[-1]),
|
||||
AttentionBlock(dims[-1]),
|
||||
ResidualBlock(dims[-1], dims[-1]),
|
||||
]
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False),
|
||||
None, # SiLU
|
||||
CausalConv3d(dims[-1], z_dim, 3, padding=1),
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]"""
|
||||
if feat_cache is not None:
|
||||
# conv1 with caching
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
for layer in self.middle:
|
||||
if feat_cache is not None and isinstance(layer, ResidualBlock):
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
# Head: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.head[0](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.head[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = nn.silu(self.head[0](x))
|
||||
x = self.head[2](x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
"""Wan2.1 VAE wrapper with per-channel normalization.
|
||||
|
||||
Supports both encode (for I2V) and decode (for all models).
|
||||
"""
|
||||
|
||||
def __init__(self, z_dim: int = 16, encoder: bool = False):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.mean = mx.array(VAE_MEAN)
|
||||
self.std = mx.array(VAE_STD)
|
||||
self.inv_std = 1.0 / self.std
|
||||
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
|
||||
|
||||
if encoder:
|
||||
self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
|
||||
def encode(self, x: mx.array) -> mx.array:
|
||||
"""Encode video to normalized latent using chunked encoding.
|
||||
|
||||
Uses chunked encoding with temporal caching to match reference behavior.
|
||||
First frame encoded alone, then 4-frame chunks with cached context.
|
||||
|
||||
Args:
|
||||
x: Video [B, 3, T, H, W] in [-1, 1]
|
||||
|
||||
Returns:
|
||||
Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
|
||||
"""
|
||||
# Count cacheable CausalConv3d slots in encoder
|
||||
num_slots = self._count_encoder_cache_slots()
|
||||
feat_cache = [None] * num_slots
|
||||
|
||||
t = x.shape[2]
|
||||
num_chunks = 1 + (t - 1) // 4
|
||||
|
||||
out = None
|
||||
for i in range(num_chunks):
|
||||
feat_idx = [0]
|
||||
if i == 0:
|
||||
chunk = x[:, :, :1]
|
||||
else:
|
||||
chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i]
|
||||
|
||||
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
if out is None:
|
||||
out = chunk_out
|
||||
else:
|
||||
out = mx.concatenate([out, chunk_out], axis=2)
|
||||
|
||||
mu, _ = mx.split(self.conv1(out), 2, axis=1)
|
||||
|
||||
# Normalize: (mu - mean) * inv_std
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
return (mu - mean) * inv_std
|
||||
|
||||
def _count_encoder_cache_slots(self) -> int:
|
||||
"""Count CausalConv3d that participate in chunked encoding cache."""
|
||||
count = 1 # encoder.conv1
|
||||
for layer in self.encoder.downsamples:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2 # two convs in residual path
|
||||
elif isinstance(layer, Resample) and layer.mode == "downsample3d":
|
||||
count += 1 # time_conv
|
||||
for layer in self.encoder.middle:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2
|
||||
count += 1 # encoder.head CausalConv3d
|
||||
return count
|
||||
|
||||
def decode(self, z: mx.array) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
Args:
|
||||
z: Normalized latent [B, z_dim, T, H, W]
|
||||
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
z = z / inv_std + mean
|
||||
|
||||
x = self.conv2(z)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
|
||||
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
|
||||
"""Decode latent to video using tiling to reduce memory usage.
|
||||
|
||||
Splits the latent tensor into overlapping spatial/temporal tiles,
|
||||
decodes each tile independently, and blends them with trapezoidal
|
||||
masks. Reuses the LTX-2 tiling infrastructure.
|
||||
|
||||
Args:
|
||||
z: Normalized latent [B, z_dim, T, H, W]
|
||||
tiling_config: Optional TilingConfig. If None, uses default.
|
||||
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
# Check if tiling is actually needed
|
||||
_, _, f, h, w = z.shape
|
||||
needs_tiling = False
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
|
||||
if h > s_tile or w > s_tile:
|
||||
needs_tiling = True
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
|
||||
if f > t_tile:
|
||||
needs_tiling = True
|
||||
|
||||
if not needs_tiling:
|
||||
return self.decode(z)
|
||||
|
||||
# Denormalize once (small tensor), then tile the denormalized latents
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
z_denorm = z / inv_std + mean
|
||||
|
||||
def tile_decode(tile_latents, **kwargs):
|
||||
x = self.conv2(tile_latents)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_denorm,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
||||
)
|
||||
1150
mlx_video/models/wan_2/vae22.py
Normal file
1150
mlx_video/models/wan_2/vae22.py
Normal file
File diff suppressed because it is too large
Load Diff
388
mlx_video/models/wan_2/wan_2.py
Normal file
388
mlx_video/models/wan_2/wan_2.py
Normal file
@@ -0,0 +1,388 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .attention import WanLayerNorm, _linear_dtype
|
||||
from .config import WanModelConfig
|
||||
from .rope import rope_params, rope_precompute_cos_sin
|
||||
from .transformer import WanAttentionBlock
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
||||
"""Compute sinusoidal positional embeddings.
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (must be even).
|
||||
position: Tensor of positions — 1D [L] or 2D [B, L].
|
||||
|
||||
Returns:
|
||||
Embeddings of shape [L, dim] or [B, L, dim].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
pos = position.astype(mx.float32)
|
||||
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
|
||||
sinusoid = pos[..., None] * inv_freq # [..., half]
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
"""Output projection head with learned modulation."""
|
||||
|
||||
def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
proj_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, proj_dim)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
x: [B, L, dim]
|
||||
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
# Compute modulation in float32 (matching reference's autocast(float32))
|
||||
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x)
|
||||
x_mod = x_norm * (1 + e1) + e0
|
||||
return self.head(x_mod)
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
"""Wan2.2 diffusion backbone for text-to-video generation."""
|
||||
|
||||
def __init__(self, config: WanModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
dim = config.dim
|
||||
self.dim = dim
|
||||
self.num_heads = config.num_heads
|
||||
self.out_dim = config.out_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.text_len = config.text_len
|
||||
self.freq_dim = config.freq_dim
|
||||
|
||||
# Patch embedding: Conv3d implemented as a reshaped linear
|
||||
# For kernel (1,2,2) and stride (1,2,2): reshape input then linear
|
||||
patch_dim = config.in_dim * math.prod(config.patch_size)
|
||||
self.patch_embedding_proj = nn.Linear(patch_dim, dim)
|
||||
self._patch_size = config.patch_size
|
||||
|
||||
# Text embedding MLP
|
||||
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
|
||||
self.text_embedding_act = nn.GELU(approx="tanh")
|
||||
self.text_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time embedding MLP
|
||||
self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
|
||||
self.time_embedding_act = nn.SiLU()
|
||||
self.time_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time projection for modulation (6x dim)
|
||||
self.time_projection_act = nn.SiLU()
|
||||
self.time_projection = nn.Linear(dim, dim * 6)
|
||||
|
||||
# Transformer blocks
|
||||
self.blocks = [
|
||||
WanAttentionBlock(
|
||||
dim=dim,
|
||||
ffn_dim=config.ffn_dim,
|
||||
num_heads=config.num_heads,
|
||||
window_size=config.window_size,
|
||||
qk_norm=config.qk_norm,
|
||||
cross_attn_norm=config.cross_attn_norm,
|
||||
eps=config.eps,
|
||||
)
|
||||
for _ in range(config.num_layers)
|
||||
]
|
||||
|
||||
# Output head
|
||||
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
|
||||
|
||||
# Precompute RoPE frequencies — three separate tables concatenated.
|
||||
# Reference computes three rope_params with different dim normalizations
|
||||
# so each axis (temporal/height/width) gets its own full frequency range.
|
||||
d = dim // config.num_heads
|
||||
self.freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding.
|
||||
half = config.freq_dim // 2
|
||||
self._inv_freq = mx.array(
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
|
||||
np.float32
|
||||
)
|
||||
)
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
"""Convert video tensor to patch embeddings.
|
||||
|
||||
Args:
|
||||
x: Video latent [C, F, H, W]
|
||||
|
||||
Returns:
|
||||
(patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
|
||||
"""
|
||||
c, f, h, w = x.shape
|
||||
pt, ph, pw = self._patch_size
|
||||
|
||||
f_out = f // pt
|
||||
h_out = h // ph
|
||||
w_out = w // pw
|
||||
|
||||
# Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
|
||||
# Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
|
||||
x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
|
||||
x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw]
|
||||
x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw]
|
||||
|
||||
# Project and cast to model dtype to prevent float32 cascade from input latents
|
||||
patches = self.patch_embedding_proj(x) # [L, dim]
|
||||
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
|
||||
patches = patches[None, :, :] # [1, L, dim]
|
||||
|
||||
return patches, (f_out, h_out, w_out)
|
||||
|
||||
def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
|
||||
"""Reconstruct video from patch embeddings.
|
||||
|
||||
Args:
|
||||
x: [B, L, out_dim * prod(patch_size)]
|
||||
grid_sizes: List of (F', H', W') per batch element
|
||||
|
||||
Returns:
|
||||
List of tensors [C, F, H, W]
|
||||
"""
|
||||
c = self.out_dim
|
||||
pt, ph, pw = self.patch_size
|
||||
out = []
|
||||
for i, (f, h, w) in enumerate(grid_sizes):
|
||||
seq_len = f * h * w
|
||||
u = x[i, :seq_len] # [L, out_dim * pt * ph * pw]
|
||||
u = u.reshape(f, h, w, pt, ph, pw, c)
|
||||
# Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
|
||||
u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw]
|
||||
u = u.reshape(c, f * pt, h * ph, w * pw)
|
||||
out.append(u)
|
||||
return out
|
||||
|
||||
def embed_text(self, context: list) -> mx.array:
|
||||
"""Precompute text embeddings (call once, reuse across steps).
|
||||
|
||||
Args:
|
||||
context: List of text embeddings [L_text, text_dim]
|
||||
|
||||
Returns:
|
||||
Embedded context [B, text_len, dim] in model dtype
|
||||
"""
|
||||
model_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
context_padded = []
|
||||
for ctx in context:
|
||||
pad_len = self.text_len - ctx.shape[0]
|
||||
if pad_len > 0:
|
||||
ctx = mx.concatenate(
|
||||
[ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
|
||||
axis=0,
|
||||
)
|
||||
context_padded.append(ctx)
|
||||
context_batch = mx.stack(context_padded) # [B, text_len, text_dim]
|
||||
context_batch = self.text_embedding_1(
|
||||
self.text_embedding_act(self.text_embedding_0(context_batch))
|
||||
)
|
||||
return context_batch.astype(model_dtype)
|
||||
|
||||
def prepare_cross_kv(self, context: mx.array) -> list:
|
||||
"""Pre-compute cross-attention K/V for all blocks.
|
||||
|
||||
Call once before the diffusion loop to cache K/V projections,
|
||||
eliminating redundant computation at each denoising step.
|
||||
|
||||
Args:
|
||||
context: Pre-embedded text [B, text_len, dim]
|
||||
|
||||
Returns:
|
||||
List of (k, v) tuples, one per block
|
||||
"""
|
||||
kv_caches = []
|
||||
for block in self.blocks:
|
||||
kv_caches.append(block.cross_attn.prepare_kv(context))
|
||||
return kv_caches
|
||||
|
||||
def prepare_rope(self, grid_sizes: list) -> tuple:
|
||||
"""Pre-compute RoPE cos/sin for constant grid sizes.
|
||||
|
||||
Call once before the diffusion loop when grid sizes don't change
|
||||
across steps. Eliminates per-step broadcast/concat overhead.
|
||||
|
||||
Args:
|
||||
grid_sizes: List of (F, H, W) tuples per batch element
|
||||
|
||||
Returns:
|
||||
(cos_f, sin_f) precomputed frequency tensors
|
||||
"""
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x_list: list,
|
||||
t: mx.array,
|
||||
context: list | mx.array,
|
||||
seq_len: int,
|
||||
cross_kv_caches: list | None = None,
|
||||
y: list | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
) -> list:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x_list: List of video latent tensors [C, F, H, W]
|
||||
t: Timestep tensor [B]
|
||||
context: List of raw text embeddings, OR pre-embedded tensor
|
||||
from embed_text() [B, text_len, dim]
|
||||
seq_len: Maximum sequence length for padding
|
||||
cross_kv_caches: Optional list of (k, v) tuples from
|
||||
prepare_cross_kv(), one per block.
|
||||
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
|
||||
Channel-concatenated with x before patchify.
|
||||
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
|
||||
|
||||
Returns:
|
||||
List of denoised tensors [C, F, H, W]
|
||||
"""
|
||||
# Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
|
||||
# Check BEFORE I2V concat since concat creates new array objects.
|
||||
batch_size = len(x_list)
|
||||
all_same = batch_size > 1 and all(
|
||||
x_list[i] is x_list[0] for i in range(1, batch_size)
|
||||
)
|
||||
if all_same and y is not None:
|
||||
all_same = all(y[i] is y[0] for i in range(1, len(y)))
|
||||
|
||||
# I2V: channel-concatenate conditioning y with noise x
|
||||
if y is not None:
|
||||
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
|
||||
|
||||
if all_same:
|
||||
# Patchify once and broadcast — saves a Linear projection per step
|
||||
p, gs = self._patchify(x_list[0]) # [1, L, dim]
|
||||
grid_sizes = [gs] * batch_size
|
||||
seq_lens_list = [p.shape[1]] * batch_size
|
||||
# Pad and broadcast
|
||||
if p.shape[1] < seq_len:
|
||||
p = mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
)
|
||||
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
|
||||
else:
|
||||
patches = []
|
||||
grid_sizes = []
|
||||
seq_lens_list = []
|
||||
for vid in x_list:
|
||||
p, gs = self._patchify(vid) # [1, L, dim]
|
||||
patches.append(p)
|
||||
grid_sizes.append(gs)
|
||||
seq_lens_list.append(p.shape[1])
|
||||
x = mx.concatenate(
|
||||
[
|
||||
(
|
||||
mx.concatenate(
|
||||
[
|
||||
p,
|
||||
mx.zeros(
|
||||
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
|
||||
),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
)
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding: sinusoidal from precomputed inv_freq.
|
||||
# inv_freq was computed in float64 for precision, stored as float32.
|
||||
# With integer timesteps (matching reference), float32 sin/cos is fine.
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
|
||||
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim)
|
||||
else:
|
||||
# I2V: per-token timesteps [B, L]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, L, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
||||
e0 = e0.reshape(batch_size, -1, 6, self.dim)
|
||||
|
||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||
if isinstance(context, mx.array):
|
||||
# Pre-embedded: expand to batch size if needed
|
||||
context_batch = context
|
||||
if context_batch.shape[0] == 1 and batch_size > 1:
|
||||
context_batch = mx.broadcast_to(
|
||||
context_batch, (batch_size,) + context_batch.shape[1:]
|
||||
)
|
||||
else:
|
||||
context_batch = self.embed_text(context)
|
||||
|
||||
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
||||
attn_mask = None
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
if any(sl < seq_len for sl in seq_lens_list):
|
||||
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
||||
for i, sl in enumerate(seq_lens_list):
|
||||
attn_mask[i, :, :, sl:] = -1e9
|
||||
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens_list,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=self.freqs,
|
||||
context=context_batch,
|
||||
context_lens=None,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
# Run transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
|
||||
x = block(x, cross_kv_cache=kv, **kwargs)
|
||||
|
||||
# Output head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
outputs = self.unpatchify(x, grid_sizes)
|
||||
return [u.astype(mx.float32) for u in outputs]
|
||||
Reference in New Issue
Block a user