feat(wan): Add I2V-14B dual-model support

This commit is contained in:
Daniel
2026-02-27 23:43:42 +01:00
parent 2bb95c61ed
commit f4195f0118
14 changed files with 1332 additions and 152 deletions

View File

@@ -1,6 +1,6 @@
# Wan2.2 MLX Implementation Notes
> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / T2V-1.3B) to Apple MLX.
> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.
## Architecture Overview
@@ -8,11 +8,12 @@ Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early repo
### Key Parameters
| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride |
|-------|-----|-------|--------|----------|-----------|------------|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) |
| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
|-------|-----|-------|--------|----------|-----------|------------|--------|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |
### Codebase Structure (~3900 lines of Wan2.2 code)
@@ -139,9 +140,11 @@ Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
---
## Image-to-Video (I2V) Pipeline
## Image-to-Video (I2V) Pipelines
### Per-Token Timesteps
Wan2.2 supports two distinct I2V approaches:
### TI2V-5B: Per-Token Timestep Masking
I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:
@@ -152,7 +155,7 @@ t_tokens = mask_tokens * current_timestep # first-frame → t=0
The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
### Mask Re-application
#### Mask Re-application
After each scheduler step, the first-frame latent is re-injected to prevent drift:
@@ -160,7 +163,7 @@ After each scheduler step, the first-frame latent is re-injected to prevent drif
latents = (1.0 - mask) * z_img + mask * latents
```
### VAE Encoder Temporal Downsample Order
#### VAE Encoder Temporal Downsample Order
The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
- Stage 0: Spatial-only downsampling
@@ -168,6 +171,22 @@ The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
### I2V-14B: Channel Concatenation
The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor:
1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])``[20, T_lat, H_lat, W_lat]`
4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
Key differences from TI2V-5B:
- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
- Requires the **VAE encoder** (for encoding the reference image)
- Uses **scalar timesteps** (same as T2V) — no per-token masking
- **Dual model** pipeline with boundary=0.900
- Both conditional and unconditional predictions receive the same `y` tensor
---
## Dimension Constraints
@@ -233,7 +252,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s
## Testing Strategy
260 tests across 9 files, all running in ~4 seconds:
332 tests across 10 files, all running in ~5 seconds:
| File | Focus |
|------|-------|
@@ -246,6 +265,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
| test_wan_convert.py | Weight sanitization and conversion |
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |
Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.