feat(wan): Add I2V-14B dual-model support

This commit is contained in:
Daniel
2026-02-27 23:43:42 +01:00
parent 2bb95c61ed
commit f4195f0118
14 changed files with 1332 additions and 152 deletions

View File

@@ -20,7 +20,7 @@ Supported models:
- [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks - [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks
- [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline) - [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline)
- [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — 14B parameter T2V model (dual-model pipeline) - [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — T2V-14B, TI2V-5B, and I2V-14B models (dual-model pipeline)
## Features ## Features
@@ -82,13 +82,15 @@ python -m mlx_video.generate \
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline: Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline:
| | Wan2.1 | Wan2.2 | | | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B |
|---|--------|--------| |---|--------|--------|--------|
| **Pipeline** | Single model | Dual model (high-noise + low-noise) | | **Task** | Text-to-Video | Text-to-Video | Image-to-Video |
| **Sizes** | 1.3B, 14B | 14B | | **Pipeline** | Single model | Dual model | Dual model |
| **Steps** | 50 | 40 | | **Sizes** | 1.3B, 14B | 14B | 14B |
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 (low/high noise) | | **Steps** | 50 | 40 | 40 |
| **Shift** | 5.0 | 12.0 | | **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 |
| **Shift** | 5.0 | 12.0 | 5.0 |
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder |
### Step 1: Download Weights ### Step 1: Download Weights
@@ -117,9 +119,11 @@ Download the original PyTorch checkpoints:
# └── high_noise_model/ # safetensors # └── high_noise_model/ # safetensors
``` ```
**Wan2.2 I2V-14B** — same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V-14B from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`).
### Step 2: Convert to MLX Format ### Step 2: Convert to MLX Format
The conversion script auto-detects whether the checkpoint is Wan2.1 or Wan2.2 based on the directory structure (presence of `low_noise_model/` subdirectory). The conversion script auto-detects the model version based on the directory structure (presence of `low_noise_model/` subdirectory) and model type (`model_type` in source config.json for I2V vs T2V).
```bash ```bash
# Auto-detect version # Auto-detect version
@@ -157,6 +161,7 @@ wan_mlx/
├── config.json # Model configuration ├── config.json # Model configuration
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder ├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
├── vae.safetensors # 3D VAE decoder ├── vae.safetensors # 3D VAE decoder
├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only)
├── model.safetensors # (Wan2.1) Single transformer ├── model.safetensors # (Wan2.1) Single transformer
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer ├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer └── high_noise_model.safetensors # (Wan2.2) High-noise transformer
@@ -195,12 +200,27 @@ python -m mlx_video.generate_wan \
The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model). You can also override any parameter via CLI flags. The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model). You can also override any parameter via CLI flags.
#### Image-to-Video (I2V-14B)
```bash
# Generate video from an input image
python -m mlx_video.generate_wan \
--model-dir wan22_i2v_mlx \
--prompt "The camera slowly zooms in as the subject begins to move" \
--image start.png \
--num-frames 81 \
--output-path my_video.mp4
```
The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and uses channel concatenation (`y` tensor with 4 mask + 16 image latent channels) to condition generation on the first frame.
#### Generation CLI Options #### Generation CLI Options
| Option | Default | Description | | Option | Default | Description |
|--------|---------|-------------| |--------|---------|-------------|
| `--model-dir` | (required) | Path to converted MLX model directory | | `--model-dir` | (required) | Path to converted MLX model directory |
| `--prompt` | (required) | Text description of the video | | `--prompt` | (required) | Text description of the video |
| `--image` | `None` | Input image path (for I2V models) |
| `--negative-prompt` | `""` | Negative prompt for guidance | | `--negative-prompt` | `""` | Negative prompt for guidance |
| `--width` | 1280 | Video width | | `--width` | 1280 | Video width |
| `--height` | 720 | Video height | | `--height` | 720 | Video height |
@@ -237,6 +257,7 @@ python -m mlx_video.generate_wan \
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed. > **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
### Wan Model Specifications ### Wan Model Specifications
**Transformer (14B)** **Transformer (14B)**

383
docs/DIAGNOSTICS.md Normal file
View File

@@ -0,0 +1,383 @@
# Wan2.2 I2V-14B Diagnostic Report
This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix.
## Table of Contents
- [Overview](#overview)
- [Architecture Summary](#architecture-summary)
- [Diagnostic Methodology](#diagnostic-methodology)
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
- [Bug 3: RoPE Frequency Computation](#bug-3-rope-frequency-computation)
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
- [Verified Correct Components](#verified-correct-components)
- [Performance Optimizations](#performance-optimizations)
- [Open Investigation: CFG Effectiveness](#open-investigation-cfg-effectiveness)
- [Reference Implementation](#reference-implementation)
- [Useful Diagnostic Commands](#useful-diagnostic-commands)
---
## Overview
The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output — first frame showed the image, subsequent frames degraded to noise, checkerboard artifacts, or flat grey.
Through a systematic component-by-component comparison against the reference PyTorch implementation, **five bugs** were found and fixed. The approach was to verify each component in isolation numerically, then narrow down failures to the subsystem level.
### Timeline of Symptoms
| Stage | Symptom | Root Cause |
|-------|---------|------------|
| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) |
| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) |
| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) |
| After VAE fix | Image in frames 0-3, grey frames 4+ | CFG effectiveness issue (open investigation) |
---
## Architecture Summary
```
I2V-14B Pipeline:
Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat]
Mask Construction → [4, T_lat, H_lat, W_lat]
y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat]
Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat]
Dual DiT (40 layers, 5120 dim) × 40 denoising steps
Denoised Latent [16, T_lat, H_lat, W_lat]
VAE Decoder → Video [3, F, H, W]
```
**Key parameters:**
- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16`
- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900)
- 40 steps, shift=5.0, guide_scale=(3.5, 3.5)
- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8)
---
## Diagnostic Methodology
### 1. Component-Level Numerical Verification
Each component was tested in isolation against the reference PyTorch implementation:
1. **Load identical inputs** (same random seed, same image, same prompt)
2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy`
3. **Run through MLX** with the same inputs
4. **Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics
Components tested this way:
- RoPE frequency parameters and rotation output
- Time embedding (sinusoidal → MLP → projection)
- Patchify (reshape+Linear vs Conv3d)
- Unpatchify (transpose-based vs einsum)
- Scheduler (UniPC) timesteps and step formulas
- VAE encoder output (frame-by-frame comparison)
- Text embeddings (per-model MLP output)
- Cross-attention K/V cache shapes
- Mask construction values
### 2. Artifact Analysis
When visual artifacts appeared, quantitative metrics were used to characterize them:
- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard.
- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts.
- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation.
- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content.
### 3. Isolation Testing
- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source.
- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness.
- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering.
- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence.
- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs.
### 4. Weight Comparison
Direct comparison of converted MLX weights against original PyTorch weights:
```python
# Load both weight sets
pt_weights = torch.load("model.safetensors")
mlx_weights = mx.load("model.safetensors")
# Compare each key
for key in pt_weights:
diff = np.abs(np.array(pt_weights[key]) - np.array(mlx_weights[key])).max()
```
Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys.
---
## Bug 1: Text Embedding Cross-Contamination
**Symptom:** Model ignores text prompt, generated frames lack semantic content.
**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights — 42% relative mean difference in output.
**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices.
**Fix:** Compute separate text embeddings per model:
```python
# Before (broken):
context_emb = low_noise_model.embed_text([context, context_null])
cross_kv = low_noise_model.prepare_cross_kv(context_emb) # used for BOTH models
# After (correct):
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high)
```
**File:** `mlx_video/generate_wan.py` (lines 333349)
**Commit:** `a85b1c21`
---
## Bug 2: VAE Encoder Weights Excluded from Conversion
**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion).
**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, missing encoder weights were silently ignored, resulting in random initialization.
**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning).
**Fix:**
```python
# Before:
include_encoder = config.model_type == "ti2v"
# After:
include_encoder = config.model_type in ("ti2v", "i2v")
```
**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions.
**File:** `mlx_video/convert_wan.py` (line 424)
---
## Bug 3: RoPE Frequency Computation
**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.
**Root Cause:** The reference creates **one** frequency table via `rope_params(1024, head_dim=128)` producing 64 frequency exponents, which `rope_apply` then splits into temporal (22), height (21), and width (21) portions. This gives temporal axes LOW frequencies and spatial axes progressively HIGHER frequencies.
Our code called `rope_params` **three times** with different normalizations:
```python
# WRONG: each axis gets full frequency range [0, 1)
freqs_t = rope_params(1024, d_t=44) # 22 exponents normalized by 44
freqs_h = rope_params(1024, d_h=42) # 21 exponents normalized by 42
freqs_w = rope_params(1024, d_w=42) # 21 exponents normalized by 42
```
The max frequency difference was ~1.0 (not a precision issue — a fundamental design bug). This affected **all** Wan models (T2V, I2V, TI2V).
**How Found:** Line-by-line comparison of `rope_params` usage between reference `model.py` (single call) and our `model.py` (three calls). Printed the actual frequency exponents to confirm the numerical divergence.
**Fix:**
```python
# Single unified frequency table, split by rope_apply
self.freqs = rope_params(1024, dim // config.num_heads)
```
**Impact:** ~35% reduction in checkerboard metric, 55% reduction in FFT 8px-frequency power.
**File:** `mlx_video/models/wan/model.py` (lines 154156)
**Commit:** `3da4a637`
---
## Bug 4: VAE Encoder Temporal Downsample Order
**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 14 showed decreasing std (0.37→1.19) while reference showed stable std (0.95→1.34).
**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters:
```
Reference: [False, True, True] → stage 0: 2d, stage 1: 3d, stage 2: 3d
Ours: [True, True, False] → stage 0: 3d, stage 1: 3d, stage 2: 2d ← WRONG
```
This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation.
**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output:
- Frame 0 matched exactly (diff=0.0000) — spatial-only processing was correct
- Frames 14 had massive differences — proved temporal processing was broken
Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]`, while our `Encoder3d` class used the wrong default `[True, True, False]`.
**Fix:**
```python
# In Encoder3d.__init__, change default:
temporal_downsample = [False, True, True] # was [True, True, False]
```
**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 6080 to 0.17.7.
**File:** `mlx_video/models/wan/vae.py` (line 370)
**Commit:** `3da4a637`
---
## Bug 5: Non-Chunked VAE Encoding
**Symptom:** First 45 frames grey, then blurred version of image appears.
**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`):
1. Encode first frame alone (1 frame)
2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks
3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input
Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because:
- Chunked: real features from previous chunks serve as causal context
- Non-chunked: zeros serve as causal context for the start
**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values.
**Fix:** Implemented full chunked encoding with temporal caching:
- Added `cache_x` parameter to `CausalConv3d.__call__`
- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d`
- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks)
- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head)
**File:** `mlx_video/models/wan/vae.py` (multiple methods)
**Commit:** `b6a94c4c`
---
## Verified Correct Components
These components were numerically verified against the reference and are **not** sources of bugs:
| Component | Method | Max Diff | Notes |
|-----------|--------|----------|-------|
| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only |
| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent |
| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative |
| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative |
| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation |
| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match |
| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 |
| CFG formula | Code audit | — | `uncond + gs * (cond - uncond)` correct order |
| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output |
| Cross-attention K/V | Shape and value audit | — | Batch dimension preserved correctly |
---
## Performance Optimizations
Applied alongside bug fixes to improve inference speed:
### Pre-Computation (Before Diffusion Loop)
- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once
- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat
- **Attention mask precomputation**: Build padding mask once, pass via kwargs
- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing
- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync
### Per-Step Optimizations
- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection
- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
### TeaCache Integration
- Polynomial rescaling stays in MLX lazy graph (Horner's method)
- Single `.item()` call on the accumulated distance for the skip/compute decision
- Configurable threshold, retention steps, and cutoff steps
---
## Open Investigation: CFG Effectiveness
**Current symptom:** After all bug fixes, generated video shows the input image in frames 03 (latent frame 0), then grey/flat frames for the rest.
**Finding:** A single forward pass at t=1000 shows cond and uncond predictions are nearly identical (|diff| mean = 0.010.035). With `guide_scale=3.5`, the CFG guidance term barely changes anything.
**Possible causes under investigation:**
1. Cross-attention context flow — both cond and uncond may be receiving equivalent context
2. The model may genuinely produce small cond/uncond differences for I2V (since both share the same y conditioning)
3. The `embed_text` method or `prepare_cross_kv` may not properly separate B=2 batch elements
4. There may be an issue with how cross-attention K/V caches index into batch elements
**Diagnostic approach:** Compare cross-attention K/V cache values between cond (index 0) and uncond (index 1) to confirm they contain different embeddings.
---
## Reference Implementation
The reference PyTorch implementation is at `/Users/daniel/Projects/Wan2.2/`:
| File | Contents |
|------|----------|
| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) |
| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) |
| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding |
| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler |
| `wan/configs/wan_i2v_A14B.py` | Model configuration |
Key structural differences between reference and our implementation:
- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2
- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype
- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear`
- Reference casts timesteps to `int64`; we keep as float (diff < 1.0)
---
## Useful Diagnostic Commands
### Run I2V-14B generation
```bash
python -m mlx_video.generate_wan \
--prompt "A woman smiles at camera" \
--image start.png \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \
--num-frames 17 --steps 40 \
--height 384 --width 640 \
--output output_i2v.mp4
```
### Check VAE encoder output
```python
import mlx.core as mx, numpy as np
from mlx_video.models.wan.vae import WanVAE
# Load VAE and encode an image
latents = vae.encode(video_tensor) # [1, 16, T_lat, H_lat, W_lat]
for t in range(latents.shape[2]):
frame = np.array(latents[0, :, t])
print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}")
```
### Analyze video frame quality
```python
import cv2, numpy as np
cap = cv2.VideoCapture("output.mp4")
while True:
ret, frame = cap.read()
if not ret: break
# Checkerboard metric: high values indicate patch-boundary artifacts
checker = np.abs(frame[::2, ::2].astype(float) - frame[1::2, 1::2].astype(float)).mean()
print(f"std={frame.std():.1f} checker={checker:.1f}")
```
### Compare weights between PyTorch and MLX
```python
import torch, mlx.core as mx, numpy as np
pt = torch.load("model.pt", map_location="cpu")
mlx_w = mx.load("model.safetensors")
for key in sorted(pt.keys()):
if key in mlx_w:
diff = np.abs(pt[key].float().numpy() - np.array(mlx_w[key])).max()
if diff > 0.01:
print(f"LARGE DIFF {key}: {diff:.6f}")
```

View File

@@ -1,6 +1,6 @@
# Wan2.2 MLX Implementation Notes # Wan2.2 MLX Implementation Notes
> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / T2V-1.3B) to Apple MLX. > Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.
## Architecture Overview ## Architecture Overview
@@ -8,11 +8,12 @@ Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early repo
### Key Parameters ### Key Parameters
| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | | Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
|-------|-----|-------|--------|----------|-----------|------------| |-------|-----|-------|--------|----------|-----------|------------|--------|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | | T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | | I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | | TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |
### Codebase Structure (~3900 lines of Wan2.2 code) ### Codebase Structure (~3900 lines of Wan2.2 code)
@@ -139,9 +140,11 @@ Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
--- ---
## Image-to-Video (I2V) Pipeline ## Image-to-Video (I2V) Pipelines
### Per-Token Timesteps Wan2.2 supports two distinct I2V approaches:
### TI2V-5B: Per-Token Timestep Masking
I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep: I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:
@@ -152,7 +155,7 @@ t_tokens = mask_tokens * current_timestep # first-frame → t=0
The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels. The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
### Mask Re-application #### Mask Re-application
After each scheduler step, the first-frame latent is re-injected to prevent drift: After each scheduler step, the first-frame latent is re-injected to prevent drift:
@@ -160,7 +163,7 @@ After each scheduler step, the first-frame latent is re-injected to prevent drif
latents = (1.0 - mask) * z_img + mask * latents latents = (1.0 - mask) * z_img + mask * latents
``` ```
### VAE Encoder Temporal Downsample Order #### VAE Encoder Temporal Downsample Order
The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`: The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
- Stage 0: Spatial-only downsampling - Stage 0: Spatial-only downsampling
@@ -168,6 +171,22 @@ The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths. This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
### I2V-14B: Channel Concatenation
The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor:
1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])``[20, T_lat, H_lat, W_lat]`
4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
Key differences from TI2V-5B:
- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
- Requires the **VAE encoder** (for encoding the reference image)
- Uses **scalar timesteps** (same as T2V) — no per-token masking
- **Dual model** pipeline with boundary=0.900
- Both conditional and unconditional predictions receive the same `y` tensor
--- ---
## Dimension Constraints ## Dimension Constraints
@@ -233,7 +252,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s
## Testing Strategy ## Testing Strategy
260 tests across 9 files, all running in ~4 seconds: 332 tests across 10 files, all running in ~5 seconds:
| File | Focus | | File | Focus |
|------|-------| |------|-------|
@@ -246,6 +265,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence | | test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
| test_wan_convert.py | Weight sanitization and conversion | | test_wan_convert.py | Weight sanitization and conversion |
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment | | test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |
Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise. Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.

View File

@@ -316,6 +316,14 @@ def convert_wan_checkpoint(
def _detect_config(): def _detect_config():
"""Detect config from source config.json or transformer weight shapes.""" """Detect config from source config.json or transformer weight shapes."""
if is_dual: if is_dual:
# Check source config.json for model_type (I2V vs T2V)
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
src_model_type = src_config.get("model_type", "t2v")
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
return WanModelConfig.wan22_i2v_14b()
return WanModelConfig.wan22_t2v_14b() return WanModelConfig.wan22_t2v_14b()
# Try reading source config.json first (most reliable) # Try reading source config.json first (most reliable)
@@ -413,7 +421,7 @@ def convert_wan_checkpoint(
weights = load_torch_weights(str(vae_path)) weights = load_torch_weights(str(vae_path))
if is_wan22_vae: if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type == "ti2v" include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder) weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
else: else:
weights = sanitize_wan_vae_weights(weights) weights = sanitize_wan_vae_weights(weights)

View File

@@ -245,24 +245,71 @@ def generate_video(
z_img = None z_img = None
i2v_mask = None i2v_mask = None
i2v_mask_tokens = None i2v_mask_tokens = None
y_i2v = None
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
if is_i2v: if is_i2v:
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}") print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
t_img = time.time() t_img = time.time()
vae_path = model_dir / "vae.safetensors"
if is_i2v_channel_concat:
# I2V-14B: encode full video (first frame = image, rest = zeros)
# and construct y tensor with mask + encoded latents
from PIL import Image
img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
], axis=1)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config)
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
mx.eval(z_video)
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
], axis=1)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
y_i2v = mx.concatenate([msk, z_video], axis=0)
mx.eval(y_i2v)
del vae_enc, img_arr, img_chw, video, z_video, msk
else:
# TI2V-5B: encode single image, blend with noise via mask
img_tensor = preprocess_image(image, width, height) img_tensor = preprocess_image(image, width, height)
mx.eval(img_tensor) mx.eval(img_tensor)
vae_path = model_dir / "vae.safetensors"
vae_enc = load_vae_encoder(vae_path, config) vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim] z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img) mx.eval(z_img)
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
# Convert to channels-first: [z_dim, 1, H_lat, W_lat]
z_img = z_img[0].transpose(3, 0, 1, 2)
# Build I2V mask
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size) i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
del vae_enc, img_tensor del vae_enc, img_tensor
gc.collect(); mx.clear_cache() gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}") print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
@@ -282,23 +329,40 @@ def generate_video(
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}") print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step) # Precompute text embeddings once (avoids redundant MLP in every step)
ref_model = single_model if not is_dual else low_noise_model # Each model has its own text_embedding weights, so dual models need separate embeddings
context_emb = ref_model.embed_text([context, context_null]) if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb) mx.eval(context_emb)
context_cond = context_emb[0:1] # [1, text_len, dim] context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
context_uncond = context_emb[1:2] # [1, text_len, dim]
# Stack for batched CFG: [2, text_len, dim]
context_cfg = mx.concatenate([context_cond, context_uncond], axis=0)
# Precompute cross-attention K/V caches (constant across all steps) # Precompute cross-attention K/V caches (constant across all steps)
if is_dual: if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg) cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg) cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
mx.eval(cross_kv_low, cross_kv_high) mx.eval(cross_kv_low, cross_kv_high)
else: else:
cross_kv = single_model.prepare_cross_kv(context_cfg) cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv) mx.eval(cross_kv)
# Precompute RoPE frequencies (grid sizes are constant across all steps)
f_grid = t_latent // patch_size[0]
h_grid = h_latent // patch_size[1]
w_grid = w_latent // patch_size[2]
cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if is_dual:
rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else:
rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes)
mx.eval(rope_cos_sin)
# Setup scheduler # Setup scheduler
_schedulers = { _schedulers = {
"euler": FlowMatchEulerScheduler, "euler": FlowMatchEulerScheduler,
@@ -312,9 +376,8 @@ def generate_video(
# Generate initial noise # Generate initial noise
noise = mx.random.normal(target_shape) noise = mx.random.normal(target_shape)
# I2V: blend first-frame latent into noise # I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
if is_i2v: if is_i2v_mask_blend:
# Broadcast z_img [z_dim, 1, H, W] across T for first-frame conditioning
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
else: else:
latents = noise latents = noise
@@ -326,26 +389,32 @@ def generate_video(
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}") print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time() t3 = time.time()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")): # Pre-convert timesteps to Python list to avoid .item() sync each step
timestep_val = sched.timesteps[i].item() timestep_list = sched.timesteps.tolist()
# Select model, guide scale, and cached K/V for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = timestep_list[i]
# Select model, guide scale, cached K/V, and precomputed RoPE
if is_dual: if is_dual:
if timestep_val >= boundary: if timestep_val >= boundary:
model = high_noise_model model = high_noise_model
gs = guide_scale[1] gs = guide_scale[1]
kv = cross_kv_high kv = cross_kv_high
rcs = rope_cos_sin_high
else: else:
model = low_noise_model model = low_noise_model
gs = guide_scale[0] gs = guide_scale[0]
kv = cross_kv_low kv = cross_kv_low
rcs = rope_cos_sin_low
else: else:
model = single_model model = single_model
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
kv = cross_kv kv = cross_kv
rcs = rope_cos_sin
# Build per-token timesteps for I2V (first-frame patches get t=0) # Build per-token timesteps for TI2V-5B (first-frame patches get t=0)
if is_i2v: if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val # [1, L] t_tokens = i2v_mask_tokens * timestep_val # [1, L]
# Pad to seq_len if needed # Pad to seq_len if needed
pad_len = seq_len - t_tokens.shape[1] pad_len = seq_len - t_tokens.shape[1]
@@ -358,22 +427,31 @@ def generate_video(
else: else:
t_batch = mx.array([timestep_val, timestep_val]) t_batch = mx.array([timestep_val, timestep_val])
# I2V-14B: pass y conditioning to model (same y for cond and uncond)
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
# CFG: batch cond + uncond into single B=2 forward pass # CFG: batch cond + uncond into single B=2 forward pass
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = model( preds = model(
[latents, latents], [latents, latents],
t=t_batch, t=t_batch,
context=context_cfg, context=ctx,
seq_len=seq_len, seq_len=seq_len,
cross_kv_caches=kv, cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
) )
noise_pred_cond, noise_pred_uncond = preds[0], preds[1] noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
# Classifier-free guidance + scheduler step # Classifier-free guidance + scheduler step
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond) noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# I2V: re-apply mask to keep first frame frozen # TI2V-5B: re-apply mask to keep first frame frozen
if is_i2v: if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution # Release temporaries before eval to free memory for graph execution
@@ -385,9 +463,11 @@ def generate_video(
# Free transformer models and text embeddings # Free transformer models and text embeddings
if is_dual: if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
del context_cfg_low, context_cfg_high
else: else:
del single_model, cross_kv del single_model, cross_kv
del model, kv, context, context_null, context_cfg del context_cfg
del model, kv, context, context_null
gc.collect(); mx.clear_cache() gc.collect(); mx.clear_cache()
# Load VAE and decode # Load VAE and decode

View File

@@ -67,6 +67,8 @@ class WanSelfAttention(nn.Module):
seq_lens: list, seq_lens: list,
grid_sizes: list, grid_sizes: list,
freqs: mx.array, freqs: mx.array,
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array: ) -> mx.array:
b, s, _ = x.shape b, s, _ = x.shape
n, d = self.num_heads, self.head_dim n, d = self.num_heads, self.head_dim
@@ -87,19 +89,18 @@ class WanSelfAttention(nn.Module):
v = self.v(x_w).reshape(b, s, n, d) v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64) # RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs) q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs) k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype)) # Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3) q = q.astype(w_dtype).transpose(0, 2, 1, 3)
k = k.astype(w_dtype).transpose(0, 2, 1, 3) k = k.astype(w_dtype).transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3)
# Build attention mask from seq_lens # Use precomputed mask or build from seq_lens
max_len = s mask = attn_mask
mask = None if mask is None and any(sl < s for sl in seq_lens):
if any(sl < max_len for sl in seq_lens): mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
for i, sl in enumerate(seq_lens): for i, sl in enumerate(seq_lens):
mask[i, :, :, sl:] = -1e9 mask[i, :, :, sl:] = -1e9

View File

@@ -91,6 +91,19 @@ class WanModelConfig(BaseModelConfig):
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default).""" """Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
return cls() return cls()
@classmethod
def wan22_i2v_14b(cls) -> "WanModelConfig":
"""Wan2.2 I2V 14B: dual model, image-to-video, 40 layers, dim=5120."""
return cls(
model_type="i2v",
in_dim=36,
out_dim=16,
dual_model=True,
boundary=0.900,
sample_shift=5.0,
sample_guide_scale=(3.5, 3.5),
)
@classmethod @classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig": def wan22_ti2v_5b(cls) -> "WanModelConfig":
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072.""" """Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""

View File

@@ -87,16 +87,23 @@ def load_vae_decoder(model_path: Path, config=None):
def load_vae_encoder(model_path: Path, config=None): def load_vae_encoder(model_path: Path, config=None):
"""Load VAE encoder for I2V image encoding. """Load VAE encoder for I2V image encoding.
Only supports Wan2.2 (vae_z_dim=48). For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder.
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
""" """
if config is not None and config.vae_z_dim == 16:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
else:
from mlx_video.models.wan.vae22 import Wan22VAEEncoder from mlx_video.models.wan.vae22 import Wan22VAEEncoder
encoder = Wan22VAEEncoder(z_dim=config.vae_z_dim) vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
weights = mx.load(str(model_path)) weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()} weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()), strict=False) vae.load_weights(list(weights.items()), strict=False)
mx.eval(encoder.parameters()) mx.eval(vae.parameters())
return encoder return vae
def _clean_text(text: str) -> str: def _clean_text(text: str) -> str:

View File

@@ -6,7 +6,7 @@ import numpy as np
from .attention import WanLayerNorm from .attention import WanLayerNorm
from .config import WanModelConfig from .config import WanModelConfig
from .rope import rope_params from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock from .transformer import WanAttentionBlock
@@ -38,7 +38,7 @@ class Head(nn.Module):
proj_dim = math.prod(patch_size) * out_dim proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps) self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim) self.head = nn.Linear(dim, proj_dim)
self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5) self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
def __call__(self, x: mx.array, e: mx.array) -> mx.array: def __call__(self, x: mx.array, e: mx.array) -> mx.array:
""" """
@@ -48,14 +48,13 @@ class Head(nn.Module):
""" """
if e.ndim == 2: if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim] e = e[:, None, :] # [B, 1, dim]
e_f32 = e.astype(mx.float32) # modulation already float32; e already float32 from model forward
# modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim]
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x).astype(mx.float32) x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1 x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
return self.head(x_mod.astype(x.dtype)) return self.head(x_mod.astype(self.head.weight.dtype))
class WanModel(nn.Module): class WanModel(nn.Module):
@@ -109,17 +108,16 @@ class WanModel(nn.Module):
# Output head # Output head
self.head = Head(dim, config.out_dim, config.patch_size, config.eps) self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
# Precompute RoPE frequencies # Precompute RoPE frequencies — single table, split by rope_apply
d = dim // config.num_heads # Reference computes one rope_params(head_dim) and splits into t/h/w.
d_t = d - 4 * (d // 6) self.freqs = rope_params(1024, dim // config.num_heads)
d_h = 2 * (d // 6)
d_w = 2 * (d // 6) # Precompute sinusoidal inv_freq for time embedding
# Each rope_params returns [1024, d_x//2, 2] half = config.freq_dim // 2
freqs_t = rope_params(1024, d_t) self._inv_freq = mx.power(
freqs_h = rope_params(1024, d_h) 10000.0, -mx.arange(half).astype(mx.float32) / half
freqs_w = rope_params(1024, d_w) )
# Concatenate along the frequency dimension: [1024, d//2, 2]
self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1)
def _patchify(self, x: mx.array) -> tuple: def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings. """Convert video tensor to patch embeddings.
@@ -215,6 +213,21 @@ class WanModel(nn.Module):
kv_caches.append(block.cross_attn.prepare_kv(context)) kv_caches.append(block.cross_attn.prepare_kv(context))
return kv_caches return kv_caches
def prepare_rope(self, grid_sizes: list) -> tuple:
"""Pre-compute RoPE cos/sin for constant grid sizes.
Call once before the diffusion loop when grid sizes don't change
across steps. Eliminates per-step broadcast/concat overhead.
Args:
grid_sizes: List of (F, H, W) tuples per batch element
Returns:
(cos_f, sin_f) precomputed frequency tensors
"""
w_dtype = self.patch_embedding_proj.weight.dtype
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__( def __call__(
self, self,
x_list: list, x_list: list,
@@ -222,6 +235,8 @@ class WanModel(nn.Module):
context: list | mx.array, context: list | mx.array,
seq_len: int, seq_len: int,
cross_kv_caches: list | None = None, cross_kv_caches: list | None = None,
y: list | None = None,
rope_cos_sin: tuple | None = None,
) -> list: ) -> list:
"""Forward pass. """Forward pass.
@@ -233,11 +248,39 @@ class WanModel(nn.Module):
seq_len: Maximum sequence length for padding seq_len: Maximum sequence length for padding
cross_kv_caches: Optional list of (k, v) tuples from cross_kv_caches: Optional list of (k, v) tuples from
prepare_cross_kv(), one per block. prepare_cross_kv(), one per block.
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
Channel-concatenated with x before patchify.
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
Returns: Returns:
List of denoised tensors [C, F, H, W] List of denoised tensors [C, F, H, W]
""" """
# Patchify each video # Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
# Check BEFORE I2V concat since concat creates new array objects.
batch_size = len(x_list)
all_same = batch_size > 1 and all(
x_list[i] is x_list[0] for i in range(1, batch_size)
)
if all_same and y is not None:
all_same = all(y[i] is y[0] for i in range(1, len(y)))
# I2V: channel-concatenate conditioning y with noise x
if y is not None:
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
if all_same:
# Patchify once and broadcast — saves a Linear projection per step
p, gs = self._patchify(x_list[0]) # [1, L, dim]
grid_sizes = [gs] * batch_size
seq_lens_list = [p.shape[1]] * batch_size
# Pad and broadcast
if p.shape[1] < seq_len:
p = mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
else:
patches = [] patches = []
grid_sizes = [] grid_sizes = []
seq_lens_list = [] seq_lens_list = []
@@ -246,9 +289,6 @@ class WanModel(nn.Module):
patches.append(p) patches.append(p)
grid_sizes.append(gs) grid_sizes.append(gs)
seq_lens_list.append(p.shape[1]) seq_lens_list.append(p.shape[1])
# Pad and batch
batch_size = len(patches)
x = mx.concatenate( x = mx.concatenate(
[ [
mx.concatenate( mx.concatenate(
@@ -262,13 +302,16 @@ class WanModel(nn.Module):
axis=0, axis=0,
) # [B, seq_len, dim] ) # [B, seq_len, dim]
# Time embedding # Time embedding (use cached inv_freq to avoid recomputing each step)
if t.ndim == 0: if t.ndim == 0:
t = t[None] t = t[None]
pos = t.astype(mx.float32)
sinusoid = pos[..., None] * self._inv_freq
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
if t.ndim == 1: if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B] # Standard T2V: scalar timestep per batch element [B]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
e = self.time_embedding_1( e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb)) self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim] ) # [B, dim]
@@ -278,7 +321,6 @@ class WanModel(nn.Module):
e = e.astype(mx.float32) e = e.astype(mx.float32)
else: else:
# I2V: per-token timesteps [B, L] # I2V: per-token timesteps [B, L]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
e = self.time_embedding_1( e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb)) self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim] ) # [B, L, dim]
@@ -298,7 +340,15 @@ class WanModel(nn.Module):
else: else:
context_batch = self.embed_text(context) context_batch = self.embed_text(context)
# Run transformer blocks # Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None
w_dtype = self.patch_embedding_proj.weight.dtype
if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list):
attn_mask[i, :, :, sl:] = -1e9
kwargs = dict( kwargs = dict(
e=e0, e=e0,
seq_lens=seq_lens_list, seq_lens=seq_lens_list,
@@ -306,8 +356,11 @@ class WanModel(nn.Module):
freqs=self.freqs, freqs=self.freqs,
context=context_batch, context=context_batch,
context_lens=None, context_lens=None,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
) )
# Run transformer blocks
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
kv = cross_kv_caches[i] if cross_kv_caches is not None else None kv = cross_kv_caches[i] if cross_kv_caches is not None else None
x = block(x, cross_kv_cache=kv, **kwargs) x = block(x, cross_kv_cache=kv, **kwargs)

View File

@@ -28,6 +28,7 @@ def rope_apply(
x: mx.array, x: mx.array,
grid_sizes: list, grid_sizes: list,
freqs: mx.array, freqs: mx.array,
precomputed_cos_sin: tuple | None = None,
) -> mx.array: ) -> mx.array:
"""Apply 3-way factorized RoPE to Q or K tensor. """Apply 3-way factorized RoPE to Q or K tensor.
@@ -35,10 +36,48 @@ def rope_apply(
x: Shape [B, L, num_heads, head_dim] x: Shape [B, L, num_heads, head_dim]
grid_sizes: List of (F, H, W) tuples per batch element grid_sizes: List of (F, H, W) tuples per batch element
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
""" """
b, s, n, d = x.shape b, s, n, d = x.shape
half_d = d // 2 half_d = d // 2
if precomputed_cos_sin is not None:
cos_f, sin_f = precomputed_cos_sin
# Check if all batch elements have the same grid (common for CFG B=2)
f0, h0, w0 = grid_sizes[0]
seq_len = f0 * h0 * w0
all_same_grid = all(
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
) if b > 1 else True
if all_same_grid:
# Vectorized path: apply RoPE to all batch elements at once
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
x_real = x_seq[..., 0]
x_imag = x_seq[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
return x_rotated
else:
# Per-element path for mixed grid sizes
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
sl = f * h * w
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
if sl < s:
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)
# Cast freqs to input dtype to prevent float32 promotion cascade # Cast freqs to input dtype to prevent float32 promotion cascade
if freqs.dtype != x.dtype: if freqs.dtype != x.dtype:
freqs = freqs.astype(x.dtype) freqs = freqs.astype(x.dtype)
@@ -98,3 +137,42 @@ def rope_apply(
outputs.append(x_rotated) outputs.append(x_rotated)
return mx.stack(outputs) return mx.stack(outputs)
def rope_precompute_cos_sin(
grid_sizes: list, freqs: mx.array, dtype: type = mx.float32
) -> tuple:
"""Precompute cos/sin frequency tensors for constant grid sizes.
Call once before the diffusion loop. Pass result as precomputed_cos_sin
to rope_apply to skip per-step broadcast/concat.
Args:
grid_sizes: List of (F, H, W) tuples (must be same for all batch elements)
freqs: Precomputed frequencies [1024, d//2, 2]
dtype: Target dtype for the output tensors
Returns:
(cos_f, sin_f) each [seq_len, 1, half_d]
"""
if freqs.dtype != dtype:
freqs = freqs.astype(dtype)
f, h, w = grid_sizes[0]
seq_len = f * h * w
half_d = freqs.shape[1]
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
freqs_t = freqs[:, :d_t]
freqs_h = freqs[:, d_t : d_t + d_h]
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
return freqs_i[..., 0], freqs_i[..., 1]

View File

@@ -34,6 +34,8 @@ class FlowMatchEulerScheduler:
sigmas = _compute_sigmas(num_steps, shift) sigmas = _compute_sigmas(num_steps, shift)
self.sigmas = mx.array(sigmas) self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps) self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
# Store as Python floats to avoid .item() sync in step()
self._sigmas_float = sigmas.tolist()
self._step_index = 0 self._step_index = 0
def step( def step(
@@ -43,9 +45,7 @@ class FlowMatchEulerScheduler:
sample: mx.array, sample: mx.array,
) -> mx.array: ) -> mx.array:
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v.""" """Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
dt = float(self.sigmas[self._step_index + 1].item()) - float( dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
self.sigmas[self._step_index].item()
)
x_next = sample + dt * model_output x_next = sample + dt * model_output
self._step_index += 1 self._step_index += 1
return x_next return x_next

View File

@@ -35,8 +35,8 @@ class WanAttentionBlock(nn.Module):
self.norm2 = WanLayerNorm(dim, eps) self.norm2 = WanLayerNorm(dim, eps)
self.ffn = WanFFN(dim, ffn_dim) self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate # Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5) self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
def __call__( def __call__(
self, self,
@@ -48,10 +48,11 @@ class WanAttentionBlock(nn.Module):
context: mx.array, context: mx.array,
context_lens: list | None = None, context_lens: list | None = None,
cross_kv_cache: tuple | None = None, cross_kv_cache: tuple | None = None,
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array: ) -> mx.array:
# Modulation in float32 (matching official torch.amp.autocast float32) # Modulation in float32 (e is already float32 from model forward)
e_f32 = e.astype(mx.float32) mod = self.modulation + e
mod = self.modulation.astype(mx.float32) + e_f32
e0 = mod[:, :, 0, :] # shift for self-attn e0 = mod[:, :, 0, :] # shift for self-attn
e1 = mod[:, :, 1, :] # scale for self-attn e1 = mod[:, :, 1, :] # scale for self-attn
e2 = mod[:, :, 2, :] # gate for self-attn e2 = mod[:, :, 2, :] # gate for self-attn
@@ -59,19 +60,20 @@ class WanAttentionBlock(nn.Module):
e4 = mod[:, :, 4, :] # scale for ffn e4 = mod[:, :, 4, :] # scale for ffn
e5 = mod[:, :, 5, :] # gate for ffn e5 = mod[:, :, 5, :] # gate for ffn
# Self-attention with modulation (norm output in float32) # Self-attention with modulation
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0 # Type promotion handles bf16→f32 automatically when multiplied with f32 modulation
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs) x_mod = self.norm1(x) * (1 + e1) + e0
x = x.astype(mx.float32) + y.astype(mx.float32) * e2 y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
x = x + y * e2
# Cross-attention (no modulation, just norm) # Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache) x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation (norm output in float32) # FFN with modulation
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3 x_mod = self.norm2(x) * (1 + e4) + e3
y = self.ffn(x_mod) y = self.ffn(x_mod)
x = x + y.astype(mx.float32) * e5 x = x + y * e5
return x return x

View File

@@ -43,7 +43,9 @@ class CausalConv3d(nn.Module):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.stride = stride self.stride = stride
self._causal_pad_t = 2 * padding[0] # Causal padding: match reference formula dilation*(k-1) + (1-stride)
# With dilation=1: k-stride (pads left only, no future context)
self._causal_pad_t = kernel_size[0] - stride[0]
self._pad_h = padding[1] self._pad_h = padding[1]
self._pad_w = padding[2] self._pad_w = padding[2]
@@ -51,12 +53,17 @@ class CausalConv3d(nn.Module):
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)) self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
self.bias = mx.zeros((out_channels,)) self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
"""x: [B, C, T, H, W] (channel-first)""" """x: [B, C, T, H, W] (channel-first)"""
b, c, t, h, w = x.shape b, c, t, h, w = x.shape
if self._causal_pad_t > 0: causal_pad = self._causal_pad_t
pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype) if cache_x is not None and causal_pad > 0:
x = mx.concatenate([cache_x, x], axis=2)
causal_pad = max(0, causal_pad - cache_x.shape[2])
if causal_pad > 0:
pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype)
x = mx.concatenate([pad_t, x], axis=2) x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0: if self._pad_h > 0 or self._pad_w > 0:
@@ -136,12 +143,35 @@ class ResidualBlock(nn.Module):
] ]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
h = x if self.shortcut is None else self.shortcut(x) h = x if self.shortcut is None else self.shortcut(x)
if feat_cache is not None:
# First conv: norm -> silu -> [cache] -> conv
x = nn.silu(self.residual[0](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.residual[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
# Second conv: norm -> silu -> [cache] -> conv
x = nn.silu(self.residual[3](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.residual[6](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = nn.silu(self.residual[0](x)) x = nn.silu(self.residual[0](x))
x = self.residual[2](x) x = self.residual[2](x)
x = nn.silu(self.residual[3](x)) x = nn.silu(self.residual[3](x))
x = self.residual[6](x) x = self.residual[6](x)
return x + h return x + h
@@ -180,23 +210,31 @@ class AttentionBlock(nn.Module):
class Resample(nn.Module): class Resample(nn.Module):
"""Upsample block matching original Wan VAE structure. """Resample block matching original Wan VAE structure.
Uses `resample` list with [None, Conv2d] to match original Supports both upsampling (decoder) and downsampling (encoder).
nn.Sequential(Upsample, Conv2d) where index 1 has the conv params. Uses list-based param storage to match original nn.Sequential key hierarchy.
""" """
def __init__(self, dim: int, mode: str): def __init__(self, dim: int, mode: str):
super().__init__() super().__init__()
assert mode in ("upsample2d", "upsample3d") assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d")
self.mode = mode self.mode = mode
self.dim = dim self.dim = dim
if mode.startswith("upsample"):
# resample.0 = Upsample (no params), resample.1 = Conv2d # resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)] self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d": if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
else:
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
if mode == "downsample3d":
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, C, T, H, W]""" """x: [B, C, T, H, W]"""
b, c, t, h, w = x.shape b, c, t, h, w = x.shape
@@ -204,10 +242,10 @@ class Resample(nn.Module):
# Temporal upsample via learned conv # Temporal upsample via learned conv
x_t = self.time_conv(x) # [B, 2C, T, H, W] x_t = self.time_conv(x) # [B, 2C, T, H, W]
x_t = x_t.reshape(b, 2, c, t, h, w) x_t = x_t.reshape(b, 2, c, t, h, w)
# Interleave along time: [B, C, 2T, H, W]
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w) x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
t = t * 2 t = t * 2
if self.mode.startswith("upsample"):
# Per-frame spatial upsample: nearest 2x + Conv2d # Per-frame spatial upsample: nearest 2x + Conv2d
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C] x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.repeat(x, 2, axis=1) x = mx.repeat(x, 2, axis=1)
@@ -215,6 +253,32 @@ class Resample(nn.Module):
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2] x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
c_out = x.shape[-1] c_out = x.shape[-1]
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3) return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
else:
# Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2)
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1)
x = self.resample[1](x) # Conv2d stride=2
c_out = x.shape[-1]
h_out, w_out = x.shape[1], x.shape[2]
x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
# First chunk: save x, skip time_conv
feat_cache[idx] = x
feat_idx[0] += 1
else:
# Subsequent chunks: use cached frame as temporal context
cache_x = x[:, :, -1:]
x = self.time_conv(
x, cache_x=feat_cache[idx][:, :, -1:])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.time_conv(x)
return x
class Decoder3d(nn.Module): class Decoder3d(nn.Module):
@@ -284,10 +348,108 @@ class Decoder3d(nn.Module):
return x return x
class WanVAE(nn.Module): class Encoder3d(nn.Module):
"""Wan2.1 VAE wrapper with per-channel normalization.""" """3D VAE Encoder matching Wan2.1 architecture.
def __init__(self, z_dim: int = 16): Mirror of Decoder3d with downsampling instead of upsampling.
Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy.
"""
def __init__(
self,
dim: int = 96,
z_dim: int = 16,
dim_mult: list = None,
num_res_blocks: int = 2,
temporal_downsample: list = None,
):
super().__init__()
if dim_mult is None:
dim_mult = [1, 2, 4, 4]
if temporal_downsample is None:
temporal_downsample = [False, True, True]
dims = [dim * u for u in [1] + dim_mult]
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# Flat downsample list matching original nn.Sequential indexing
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = downsamples
# Middle: [ResBlock, AttentionBlock, ResBlock]
self.middle = [
ResidualBlock(dims[-1], dims[-1]),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1]),
]
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False),
None, # SiLU
CausalConv3d(dims[-1], z_dim, 3, padding=1),
]
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]"""
if feat_cache is not None:
# conv1 with caching
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.downsamples:
if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
for layer in self.middle:
if feat_cache is not None and isinstance(layer, ResidualBlock):
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
if feat_cache is not None:
# Head: norm -> silu -> [cache] -> conv
x = nn.silu(self.head[0](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.head[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = nn.silu(self.head[0](x))
x = self.head[2](x)
return x
class WanVAE(nn.Module):
"""Wan2.1 VAE wrapper with per-channel normalization.
Supports both encode (for I2V) and decode (for all models).
"""
def __init__(self, z_dim: int = 16, encoder: bool = False):
super().__init__() super().__init__()
self.z_dim = z_dim self.z_dim = z_dim
self.mean = mx.array(VAE_MEAN) self.mean = mx.array(VAE_MEAN)
@@ -297,6 +459,65 @@ class WanVAE(nn.Module):
self.conv2 = CausalConv3d(z_dim, z_dim, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim=96, z_dim=z_dim) self.decoder = Decoder3d(dim=96, z_dim=z_dim)
if encoder:
self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
def encode(self, x: mx.array) -> mx.array:
"""Encode video to normalized latent using chunked encoding.
Uses chunked encoding with temporal caching to match reference behavior.
First frame encoded alone, then 4-frame chunks with cached context.
Args:
x: Video [B, 3, T, H, W] in [-1, 1]
Returns:
Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
"""
# Count cacheable CausalConv3d slots in encoder
num_slots = self._count_encoder_cache_slots()
feat_cache = [None] * num_slots
t = x.shape[2]
num_chunks = 1 + (t - 1) // 4
out = None
for i in range(num_chunks):
feat_idx = [0]
if i == 0:
chunk = x[:, :, :1]
else:
chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
else:
out = mx.concatenate([out, chunk_out], axis=2)
mu, _ = mx.split(self.conv1(out), 2, axis=1)
# Normalize: (mu - mean) * inv_std
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
return (mu - mean) * inv_std
def _count_encoder_cache_slots(self) -> int:
"""Count CausalConv3d that participate in chunked encoding cache."""
count = 1 # encoder.conv1
for layer in self.encoder.downsamples:
if isinstance(layer, ResidualBlock):
count += 2 # two convs in residual path
elif isinstance(layer, Resample) and layer.mode == "downsample3d":
count += 1 # time_conv
for layer in self.encoder.middle:
if isinstance(layer, ResidualBlock):
count += 2
count += 1 # encoder.head CausalConv3d
return count
def decode(self, z: mx.array) -> mx.array: def decode(self, z: mx.array) -> mx.array:
"""Decode latent to video. """Decode latent to video.

293
tests/test_wan_i2v.py Normal file
View File

@@ -0,0 +1,293 @@
"""Tests for Wan2.2 I2V-14B support."""
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
def _make_tiny_i2v_config():
"""Create a tiny I2V-14B config for testing."""
config = _make_tiny_config()
config.model_type = "i2v"
config.in_dim = 9 # 4 noise + 4 image + 1 mask (scaled down from 16+16+4=36)
config.out_dim = 4
config.vae_z_dim = 4
config.vae_stride = (4, 8, 8)
config.dual_model = True
config.boundary = 0.900
config.sample_shift = 5.0
config.sample_guide_scale = (3.5, 3.5)
config.teacache_coefficients = None
return config
class TestI2VConfig:
"""Test I2V-14B config preset."""
def test_wan22_i2v_14b_preset(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
assert config.model_type == "i2v"
assert config.in_dim == 36
assert config.out_dim == 16
assert config.dim == 5120
assert config.num_layers == 40
assert config.dual_model is True
assert config.boundary == 0.900
assert config.sample_shift == 5.0
assert config.sample_guide_scale == (3.5, 3.5)
assert config.vae_stride == (4, 8, 8)
assert config.vae_z_dim == 16
assert config.teacache_coefficients is None
def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan.config import WanModelConfig
i2v = WanModelConfig.wan22_i2v_14b()
t2v = WanModelConfig.wan22_t2v_14b()
assert i2v.model_type == "i2v"
assert t2v.model_type == "t2v"
assert i2v.in_dim == 36 and t2v.in_dim == 16
assert i2v.boundary == 0.900 and t2v.boundary == 0.875
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
def test_i2v_serialization_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
d = config.to_dict()
restored = WanModelConfig.from_dict(d)
assert restored.model_type == "i2v"
assert restored.in_dim == 36
assert restored.boundary == 0.900
class TestModelYParameter:
"""Test y parameter channel concatenation in WanModel."""
def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x_list = [mx.random.normal((C, F, H, W))]
t = mx.array([500.0])
context = [mx.random.normal((6, config.text_dim))]
out = model(x_list, t, context, seq_len)
mx.eval(out[0])
assert out[0].shape == (C, F, H, W)
def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
C_noise = 4 # noise channels
C_y = 5 # mask (1) + image (4)
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x_list = [mx.random.normal((C_noise, F, H, W))]
y_list = [mx.random.normal((C_y, F, H, W))]
t = mx.array([500.0])
context = [mx.random.normal((6, config.text_dim))]
out = model(x_list, t, context, seq_len, y=y_list)
mx.eval(out[0])
# Output should match noise channels (out_dim), not concatenated in_dim
assert out[0].shape == (config.out_dim, F, H, W)
def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
mx.random.seed(42)
x = mx.random.normal((C, F, H, W))
t = mx.array([500.0])
ctx = [mx.random.normal((6, config.text_dim))]
out1 = model([x], t, ctx, seq_len)[0]
out2 = model([x], t, ctx, seq_len, y=None)[0]
mx.eval(out1, out2)
assert mx.allclose(out1, out2, atol=1e-5).item()
def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
C_noise, C_y = 4, 5
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
latents = mx.random.normal((C_noise, F, H, W))
y = mx.random.normal((C_y, F, H, W))
t = mx.array([500.0, 500.0])
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
mx.eval(out[0], out[1])
assert len(out) == 2
assert out[0].shape == (config.out_dim, F, H, W)
assert out[1].shape == (config.out_dim, F, H, W)
class TestVAEEncoder:
"""Test Wan2.1 VAE encoder."""
def test_encoder3d_instantiation(self):
from mlx_video.models.wan.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
assert enc.conv1 is not None
assert len(enc.downsamples) > 0
assert len(enc.middle) == 3
def test_encoder3d_output_shape(self):
"""Encoder should downsample spatially by 8x and temporally by 4x."""
from mlx_video.models.wan.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8)
# Random input: [B=1, 3, T=5, H=32, W=32]
x = mx.random.normal((1, 3, 5, 32, 32))
out = enc(x)
mx.eval(out)
# With default dim_mult=[1,2,4,4] and temporal_downsample=[True,True,False]:
# Spatial: 32 -> 16 -> 8 -> 4 (3 spatial downsamples)
# Temporal: 5 -> 3 -> 2 (2 temporal downsamples: downsample3d stride 2)
assert out.shape[0] == 1
assert out.shape[1] == 8 # z_dim
assert out.shape[3] == 32 // 8 # spatial /8
assert out.shape[4] == 32 // 8
def test_wan_vae_encode(self):
"""WanVAE with encoder=True should produce normalized latents."""
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
# Input: [B=1, 3, T=5, H=32, W=32]
x = mx.random.normal((1, 3, 5, 32, 32))
z = vae.encode(x)
mx.eval(z)
assert z.shape[0] == 1
assert z.shape[1] == 16 # z_dim
def test_wan_vae_encoder_flag(self):
"""WanVAE without encoder flag should not have encoder attribute."""
from mlx_video.models.wan.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, 'encoder')
vae_enc = WanVAE(z_dim=4, encoder=True)
assert hasattr(vae_enc, 'encoder')
class TestResampleDownsample:
"""Test downsample modes in Resample."""
def test_downsample2d(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="downsample2d")
x = mx.random.normal((1, 16, 2, 8, 8))
out = r(x)
mx.eval(out)
# Spatial /2, temporal unchanged, channels same
assert out.shape == (1, 16, 2, 4, 4)
def test_downsample3d(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="downsample3d")
x = mx.random.normal((1, 16, 4, 8, 8))
out = r(x)
mx.eval(out)
# Spatial /2, temporal /2, channels same
assert out.shape == (1, 16, 2, 4, 4)
def test_upsample2d_still_works(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="upsample2d")
x = mx.random.normal((1, 16, 2, 4, 4))
out = r(x)
mx.eval(out)
assert out.shape == (1, 8, 2, 8, 8)
def test_upsample3d_still_works(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="upsample3d")
x = mx.random.normal((1, 16, 2, 4, 4))
out = r(x)
mx.eval(out)
assert out.shape == (1, 8, 4, 8, 8)
class TestI2VMaskConstruction:
"""Test mask construction for I2V-14B."""
def test_mask_shape(self):
"""I2V-14B mask should have 4 channels with correct temporal structure."""
num_frames = 81
h_latent, w_latent = 10, 18 # example latent dims
t_latent = (num_frames - 1) // 4 + 1 # = 21
# Build mask following reference logic
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
assert msk.shape == (4, t_latent, h_latent, w_latent)
def test_mask_values(self):
"""First temporal position should be 1, rest 0."""
num_frames = 9
h_latent, w_latent = 4, 4
t_latent = (num_frames - 1) // 4 + 1 # = 3
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0]
mx.eval(msk)
# First temporal position: all 4 channels should be 1
assert mx.all(msk[:, 0] == 1.0).item()
# Rest: all should be 0
assert mx.all(msk[:, 1:] == 0.0).item()
def test_y_tensor_shape(self):
"""y = concat([mask_4ch, encoded_video_16ch]) should be 20 channels."""
mask = mx.zeros((4, 5, 10, 18))
encoded = mx.zeros((16, 5, 10, 18))
y = mx.concatenate([mask, encoded], axis=0)
assert y.shape == (20, 5, 10, 18)