feat(wan): Add I2V-14B dual-model support
@@ -1,6 +1,6 @@
 # Wan2.2 MLX Implementation Notes

-> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / T2V-1.3B) to Apple MLX.
+> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.

 ## Architecture Overview

@@ -8,11 +8,12 @@ Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early repo

 ### Key Parameters

-| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride |
-|-------|-----|-------|--------|----------|-----------|------------|
-| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) |
-| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) |
-| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) |
+| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
+|-------|-----|-------|--------|----------|-----------|------------|--------|
+| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
+| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
+| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
+| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |

 ### Codebase Structure (~3900 lines of Wan2.2 code)

@@ -139,9 +140,11 @@ Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.

 ---

-## Image-to-Video (I2V) Pipeline
+## Image-to-Video (I2V) Pipelines

-### Per-Token Timesteps
+Wan2.2 supports two distinct I2V approaches:
+
+### TI2V-5B: Per-Token Timestep Masking

 I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:

@@ -152,7 +155,7 @@ t_tokens = mask_tokens * current_timestep # first-frame → t=0

 The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
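
A minimal sketch of how the per-token timesteps could be built for one sample, using plain Python lists in place of MLX arrays. The helper name `per_token_timesteps`, the frame-by-frame token order, and the example sizes are assumptions for illustration; only the `mask_tokens * current_timestep` product comes from the notes.

```python
def per_token_timesteps(t_lat, tokens_per_frame, current_timestep):
    """Return one timestep per latent token for a single sample (length L)."""
    # Mask is 0 for tokens of the first latent frame and 1 for all others
    # (assumes tokens are laid out frame by frame).
    mask_tokens = [0.0] * tokens_per_frame + [1.0] * ((t_lat - 1) * tokens_per_frame)
    # First-frame tokens get t=0 (clean); the rest get the current timestep.
    return [m * current_timestep for m in mask_tokens]


# Example: 3 latent frames, 4 tokens per frame, diffusion timestep 800.
t_tokens = per_token_timesteps(t_lat=3, tokens_per_frame=4, current_timestep=800.0)
print(t_tokens[:4])  # first-frame tokens
print(t_tokens[4:])  # remaining tokens
```

Stacking one such list per batch element would give the `[B, L]` input described above.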

-### Mask Re-application
+#### Mask Re-application

 After each scheduler step, the first-frame latent is re-injected to prevent drift:

@@ -160,7 +163,7 @@ After each scheduler step, the first-frame latent is re-injected to prevent drif
 latents = (1.0 - mask) * z_img + mask * latents
 ```
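
The blend above can be checked on a toy example. This sketch uses one float per latent frame and plain Python lists instead of MLX arrays; the helper name `reapply_first_frame` and the sample values are hypothetical.

```python
def reapply_first_frame(latents, z_img, mask):
    """Blend per element: keep z_img where mask == 0, keep latents where mask == 1."""
    return [(1.0 - m) * z + m * x for x, z, m in zip(latents, z_img, mask)]


# Toy example with one value per latent frame (3 frames).
z_img = [0.5, 0.0, 0.0]      # encoded reference image lives in frame 0
mask = [0.0, 1.0, 1.0]       # 0 = conditioned first frame, 1 = generated frames
latents = [0.7, -1.2, 0.3]   # state after a scheduler step; frame 0 has drifted

latents = reapply_first_frame(latents, z_img, mask)
print(latents)
```

Frame 0 is reset to the reference latent while the generated frames pass through unchanged.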

-### VAE Encoder Temporal Downsample Order
+#### VAE Encoder Temporal Downsample Order

 The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
 - Stage 0: Spatial-only downsampling
@@ -168,6 +171,22 @@ The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:

 This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
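
One reason an ordering mistake like this may slip past shape checks: if each enabled stage halves the frame axis (an assumption, not stated in the notes), both orderings give the same overall temporal factor, which matches the temporal stride of 4 in the Key Parameters table. The helper `temporal_factor` below is hypothetical.

```python
def temporal_factor(temporal_downsample):
    """Overall temporal reduction, assuming each enabled stage halves the frame axis."""
    return 2 ** sum(temporal_downsample)


correct = (False, True, True)
initial = (True, True, False)

# Same total reduction, different stages doing the temporal merge.
print(temporal_factor(correct), temporal_factor(initial))
print(correct.index(True), initial.index(True))  # first stage that merges frames
```

Output shapes therefore agree in both cases, so a comparison against reference activations is a more likely way to catch the difference than a shape assertion.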

+### I2V-14B: Channel Concatenation
+
+The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor:
+
+1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
+2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
+3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])` → `[20, T_lat, H_lat, W_lat]`
+4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
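
The channel bookkeeping in the steps above can be sketched with spatial dimensions omitted. The repeat-then-fold rearrangement of the mask is an assumption about how the 4-channel layout is produced, since the notes give only the resulting shape. The helper name `build_frame_mask` and the 81-frame example are illustrative, and the sketch assumes a frame count of the form 4k + 1.

```python
def build_frame_mask(num_frames, temporal_stride=4):
    """Per-frame mask folded into [temporal_stride][T_lat] (spatial dims omitted)."""
    # 1 for the reference first frame, 0 for the frames to be generated.
    mask = [1] + [0] * (num_frames - 1)
    # Repeat the first entry so the length becomes temporal_stride * T_lat.
    mask = [mask[0]] * temporal_stride + mask[1:]
    t_lat = len(mask) // temporal_stride
    # Group consecutive entries per latent frame, then move the group index
    # to the channel axis: result[c][t] = mask[t * temporal_stride + c].
    return [[mask[t * temporal_stride + c] for t in range(t_lat)]
            for c in range(temporal_stride)]


mask_4ch = build_frame_mask(num_frames=81)

# Channel bookkeeping for steps 3 and 4.
mask_ch, encoded_ch, noise_ch = len(mask_4ch), 16, 16
y_ch = mask_ch + encoded_ch   # channels in y
in_dim = noise_ch + y_ch      # channels entering patchify
print(len(mask_4ch), len(mask_4ch[0]), y_ch, in_dim)
```

Under these assumptions, each of the four mask channels marks only the first latent frame, and the totals agree with the `in_dim` column for I2V-14B.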
+
+Key differences from TI2V-5B:
+- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
+- Requires the **VAE encoder** (for encoding the reference image)
+- Uses **scalar timesteps** (same as T2V) — no per-token masking
+- **Dual model** pipeline with boundary=0.900
+- Both conditional and unconditional predictions receive the same `y` tensor
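
The last two points can be sketched as follows. The helper names, the 1000-step training range, the `>=` comparison at the boundary, and the toy denoiser are all assumptions made so the example runs; the notes state only the boundary value and that both guidance passes share `y`.

```python
def select_expert(timestep, boundary=0.900, num_train_timesteps=1000):
    """Pick which of the two models handles this denoising step."""
    # Assumption: the boundary is a fraction of the training timestep range,
    # and timesteps at or above it go to the high-noise model.
    return "high_noise" if timestep >= boundary * num_train_timesteps else "low_noise"


def guided_prediction(model, x, t, cond_ctx, uncond_ctx, y, guide_scale):
    """Classifier-free guidance where both passes see the same image conditioning y."""
    cond = model(x, t, cond_ctx, y)
    uncond = model(x, t, uncond_ctx, y)
    return uncond + guide_scale * (cond - uncond)


def toy_model(x, t, ctx, y):
    """Toy stand-in for a denoiser: returns a single float."""
    return x + ctx + y


experts = [select_expert(t) for t in (999, 900, 899, 0)]
pred = guided_prediction(toy_model, x=1.0, t=950, cond_ctx=2.0,
                         uncond_ctx=0.0, y=0.5, guide_scale=3.0)
print(experts, pred)
```

Because `y` enters both passes identically, guidance amplifies only the difference caused by the text context, not the image conditioning.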
+
 ---

 ## Dimension Constraints
@@ -233,7 +252,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s

 ## Testing Strategy

-260 tests across 9 files, all running in ~4 seconds:
+332 tests across 10 files, all running in ~5 seconds:

 | File | Focus |
 |------|-------|
@@ -246,6 +265,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s
 | test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
 | test_wan_convert.py | Weight sanitization and conversion |
 | test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
+| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |

 Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.