feat(wan): Add I2V-14B dual-model support
383
docs/DIAGNOSTICS.md
Normal file
@@ -0,0 +1,383 @@
# Wan2.2 I2V-14B Diagnostic Report

This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix.

## Table of Contents

- [Overview](#overview)
- [Architecture Summary](#architecture-summary)
- [Diagnostic Methodology](#diagnostic-methodology)
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
- [Bug 3: RoPE Frequency Computation](#bug-3-rope-frequency-computation)
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
- [Verified Correct Components](#verified-correct-components)
- [Performance Optimizations](#performance-optimizations)
- [Open Investigation: CFG Effectiveness](#open-investigation-cfg-effectiveness)
- [Reference Implementation](#reference-implementation)
- [Useful Diagnostic Commands](#useful-diagnostic-commands)

---

## Overview

The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output: the first frame showed the image, while subsequent frames degraded to noise, checkerboard artifacts, or flat grey.

Through a systematic component-by-component comparison against the reference PyTorch implementation, **five bugs** were found and fixed. The approach was to verify each component numerically in isolation, then narrow failures down to the subsystem level.

### Timeline of Symptoms

| Stage | Symptom | Root Cause |
|-------|---------|------------|
| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) |
| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) |
| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) |
| After VAE fix | Image in frames 0-3, grey frames 4+ | CFG effectiveness issue (open investigation) |

---

## Architecture Summary

```
I2V-14B Pipeline:
  Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat]
                    ↓
  Mask Construction → [4, T_lat, H_lat, W_lat]
                    ↓
  y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat]
                    ↓
  Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat]
                    ↓
  Dual DiT (40 layers, 5120 dim) × 40 denoising steps
                    ↓
  Denoised Latent [16, T_lat, H_lat, W_lat]
                    ↓
  VAE Decoder → Video [3, F, H, W]
```

**Key parameters:**
- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16`
- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900)
- 40 steps, shift=5.0, guide_scale=(3.5, 3.5)
- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8)
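
The shapes in the diagram follow from the VAE stride and the channel counts listed above. The following is a minimal sketch of that arithmetic, not code from mlx-video: `latent_shape` and `temporal_mask` are hypothetical helper names. The temporal formula assumes a causal layout in which the first frame gets its own latent frame. The mask fold (first-frame flag repeated 4 times, then grouped by the temporal stride) is an assumption about how the 4 mask channels arise; the text above only states the resulting shape.

```python
def latent_shape(num_frames, height, width, stride=(4, 8, 8)):
    """Return (T_lat, H_lat, W_lat) for the given video size and VAE stride."""
    st, sh, sw = stride
    # Assumption: the first frame maps to its own latent frame (causal VAE).
    return (num_frames - 1) // st + 1, height // sh, width // sw


def temporal_mask(num_frames, stride_t=4):
    """Return a [4][T_lat] mask over time (spatial dims omitted for brevity)."""
    flags = [1] + [0] * (num_frames - 1)         # 1 marks the conditioned first frame
    padded = [flags[0]] * stride_t + flags[1:]   # repeat the first-frame flag 4 times
    t_lat = len(padded) // stride_t
    # Fold time by the stride: channel c of latent frame t reads padded[stride_t * t + c].
    return [[padded[stride_t * t + c] for t in range(t_lat)] for c in range(stride_t)]


t_lat, h_lat, w_lat = latent_shape(17, 384, 640)
mask = temporal_mask(17)
in_dim = 16 + len(mask) + 16   # noise + mask + image latents
print((t_lat, h_lat, w_lat), mask[0], in_dim)
```

For the 17-frame, 640×384 run used later in this document, this gives 5 latent frames on a 48×80 grid, with only the first latent frame marked in the mask.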

---

## Diagnostic Methodology

### 1. Component-Level Numerical Verification

Each component was tested in isolation against the reference PyTorch implementation:

1. **Load identical inputs** (same random seed, same image, same prompt)
2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy`
3. **Run through MLX** with the same inputs
4. **Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics

Components tested this way:
- RoPE frequency parameters and rotation output
- Time embedding (sinusoidal → MLP → projection)
- Patchify (reshape+Linear vs Conv3d)
- Unpatchify (transpose-based vs einsum)
- Scheduler (UniPC) timesteps and step formulas
- VAE encoder output (frame-by-frame comparison)
- Text embeddings (per-model MLP output)
- Cross-attention K/V cache shapes
- Mask construction values
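
The comparison in step 4 can be sketched as below. This is an illustration written against flat Python sequences so that it runs without dependencies; on NumPy arrays loaded from the saved `.npy` files the same quantities are `np.abs(ours - ref).max()` and `np.abs(ours - ref).mean() / np.abs(ref).mean()`. The exact definition of the "relative difference" used in this report is not stated, so the mean-based ratio here is an assumption, and both function names are hypothetical.

```python
def max_abs_diff(ours, ref):
    """Largest element-wise absolute difference."""
    return max(abs(a - b) for a, b in zip(ours, ref))


def rel_mean_diff(ours, ref, eps=1e-12):
    """Mean absolute difference divided by the mean magnitude of the reference."""
    num = sum(abs(a - b) for a, b in zip(ours, ref)) / len(ref)
    den = sum(abs(b) for b in ref) / len(ref)
    return num / (den + eps)


# Placeholder values standing in for flattened tensors.
ref = [0.50, -1.25, 2.00, 0.75]
ours = [0.50, -1.24, 2.01, 0.75]

print(f"max abs diff = {max_abs_diff(ours, ref):.4f}")
print(f"relative mean diff = {rel_mean_diff(ours, ref):.2%}")
```

Reporting both numbers helps separate rounding noise (small on both metrics) from structural mismatches, which tend to show up as a large relative difference even when individual values are small.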

### 2. Artifact Analysis

When visual artifacts appeared, quantitative metrics were used to characterize them:

- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard.
- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts.
- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation.
- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content.
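
Two of these metrics are easy to illustrate on a single row of pixels. The sketch below is not the analysis script used for this report: `period_power` evaluates one DFT bin directly (the bin with an 8px period) instead of running a full FFT, and both function names are hypothetical. It only shows what "power at the 8px spatial frequency" and "frame difference" measure.

```python
import cmath


def period_power(row, period=8):
    """Power of the spatial frequency with the given pixel period in one row."""
    n = len(row)
    mean = sum(row) / n  # remove the DC component first
    acc = sum((v - mean) * cmath.exp(-2j * cmath.pi * x / period)
              for x, v in enumerate(row))
    return abs(acc) ** 2 / n


def frame_difference(prev, cur):
    """mean(|cur - prev|) over flattened pixel values."""
    return sum(abs(a - b) for a, b in zip(prev, cur)) / len(cur)


flat = [128.0] * 64
# Synthetic stripe pattern that repeats every 8 pixels, amplitude ±20.
striped = [128.0 + (20.0 if (x // 4) % 2 == 0 else -20.0) for x in range(64)]

print(period_power(flat), period_power(striped))
print(frame_difference(flat, striped))
```

A flat row has no energy at the 8px period, while the striped row concentrates it there, which is the signature the FFT analysis looks for.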

### 3. Isolation Testing

- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source.
- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness.
- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering.
- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence.
- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs.

### 4. Weight Comparison

Direct comparison of converted MLX weights against original PyTorch weights:
```python
import mlx.core as mx
import numpy as np
from safetensors.torch import load_file

# Load both weight sets (paths are placeholders; torch.load cannot read .safetensors)
pt_weights = load_file("model.safetensors")
mlx_weights = mx.load("model.safetensors")

# Compare each key in float32, since NumPy has no bfloat16 dtype
for key in pt_weights:
    ours = np.array(mlx_weights[key].astype(mx.float32))
    diff = np.abs(pt_weights[key].float().numpy() - ours).max()
    print(f"{key}: max diff = {diff:.6f}")
```
Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys.

---

## Bug 1: Text Embedding Cross-Contamination

**Symptom:** Model ignores the text prompt; generated frames lack semantic content.

**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights, which produced a 42% relative mean difference in output.

**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices.

**Fix:** Compute separate text embeddings per model:
```python
# Before (broken):
context_emb = low_noise_model.embed_text([context, context_null])
cross_kv = low_noise_model.prepare_cross_kv(context_emb)  # used for BOTH models

# After (correct):
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high)
```

**File:** `mlx_video/generate_wan.py` (lines 333–349)
**Commit:** `a85b1c21`

---

## Bug 2: VAE Encoder Weights Excluded from Conversion

**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion).

**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, missing encoder weights were silently ignored, resulting in random initialization.

**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning).

**Fix:**
```python
# Before:
include_encoder = config.model_type == "ti2v"
# After:
include_encoder = config.model_type in ("ti2v", "i2v")
```

**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions.

**File:** `mlx_video/convert_wan.py` (line 424)

---

## Bug 3: RoPE Frequency Computation

**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.

**Root Cause:** The reference creates **one** frequency table via `rope_params(1024, head_dim=128)` producing 64 frequency exponents, which `rope_apply` then splits into temporal (22), height (21), and width (21) portions. This gives the temporal axis the lowest exponents and the spatial axes progressively higher ones.

Our code called `rope_params` **three times** with different normalizations:
```python
# WRONG: each axis gets full frequency range [0, 1)
freqs_t = rope_params(1024, d_t=44)  # 22 exponents normalized by 44
freqs_h = rope_params(1024, d_h=42)  # 21 exponents normalized by 42
freqs_w = rope_params(1024, d_w=42)  # 21 exponents normalized by 42
```

The max frequency difference was ~1.0. This was not a precision issue but a fundamental design bug, and it affected **all** Wan models (T2V, I2V, TI2V).

**How Found:** Line-by-line comparison of `rope_params` usage between reference `model.py` (single call) and our `model.py` (three calls). Printed the actual frequency exponents to confirm the numerical divergence.
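
The divergence described above can be reproduced with a few lines of arithmetic. The sketch below follows this section's description rather than quoting either codebase: it assumes exponents of the form `2k / dim`, inverse frequencies `theta ** -exponent` with `theta = 10000`, and the 22/21/21 split stated in the root cause. `rope_exponents` and `inv_freqs` are hypothetical helper names.

```python
def rope_exponents(dim):
    """Exponents 2k / dim for k in [0, dim / 2)."""
    return [k / dim for k in range(0, dim, 2)]


def inv_freqs(exponents, theta=10000.0):
    """Inverse frequencies theta ** -e for each exponent e."""
    return [theta ** -e for e in exponents]


head_dim = 128
c = head_dim // 2                            # 64 frequency slots
split = [c - 2 * (c // 3), c // 3, c // 3]   # temporal, height, width

unified = rope_exponents(head_dim)
t_part = unified[:split[0]]
w_part = unified[split[0] + split[1]:]

# Three-call variant: the width axis restarts at exponent 0.
separate_w = rope_exponents(42)

max_gap = max(abs(a - b)
              for a, b in zip(inv_freqs(w_part), inv_freqs(separate_w)))
print(split, t_part[-1], w_part[0], round(max_gap, 3))
```

Under these assumptions the width axis starts at exponent 86/128 in the unified table but at 0 in the three-call variant, so its first inverse frequency differs by almost exactly 1.0, consistent with the figure quoted above.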

**Fix:**
```python
# Single unified frequency table, split by rope_apply
self.freqs = rope_params(1024, dim // config.num_heads)
```

**Impact:** ~35% reduction in checkerboard metric, 55% reduction in FFT 8px-frequency power.

**File:** `mlx_video/models/wan/model.py` (lines 154–156)
**Commit:** `3da4a637`

---

## Bug 4: VAE Encoder Temporal Downsample Order

**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 1–4 showed lower std (0.37→1.19) than the reference (0.95→1.34).

**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters:

```
Reference: [False, True, True]  → stage 0: 2d, stage 1: 3d, stage 2: 3d
Ours:      [True, True, False]  → stage 0: 3d, stage 1: 3d, stage 2: 2d  ← WRONG
```

This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation.

**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output:
- Frame 0 matched exactly (diff=0.0000), so spatial-only processing was correct
- Frames 1–4 had massive differences, which proved temporal processing was broken

Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]`, while our `Encoder3d` class used the wrong default `[True, True, False]`.
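
A small sketch of how the flag list maps onto stage modes, using the `downsample2d` / `downsample3d` names from the root cause above. `downsample_modes` is a hypothetical helper, not a function in either codebase. It also shows that both orderings contain two temporal stages, so the overall temporal stride is 4 either way and the output shape alone cannot distinguish them.

```python
def downsample_modes(temporal_downsample):
    """Map per-stage flags to the resample mode used at each encoder stage."""
    return ["downsample3d" if flag else "downsample2d"
            for flag in temporal_downsample]


def temporal_stride(temporal_downsample):
    """Each temporal stage halves the frame count."""
    return 2 ** sum(temporal_downsample)


reference_flags = [False, True, True]
broken_flags = [True, True, False]

print(downsample_modes(reference_flags), temporal_stride(reference_flags))
print(downsample_modes(broken_flags), temporal_stride(broken_flags))
```

Because the shapes agree, only a value-level comparison against the reference (as in the frame-by-frame check above) exposes the wrong ordering.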

**Fix:**
```python
# In Encoder3d.__init__, change default:
temporal_downsample = [False, True, True]  # was [True, True, False]
```

**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 60–80 to 0.1–7.7.

**File:** `mlx_video/models/wan/vae.py` (line 370)
**Commit:** `3da4a637`

---

## Bug 5: Non-Chunked VAE Encoding

**Symptom:** First 4–5 frames grey, then blurred version of image appears.

**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`):
1. Encode first frame alone (1 frame)
2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks
3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input

Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because:
- Chunked: real features from previous chunks serve as causal context
- Non-chunked: zeros serve as causal context for the start
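
The chunk schedule in steps 1 and 2 can be sketched as follows. This only computes frame ranges; it does not model the feature cache. `encode_chunks` is a hypothetical helper name, and the sketch assumes the frame count has the form 4k + 1 (17 in the runs described here).

```python
def encode_chunks(num_frames, chunk=4):
    """Return (start, end) frame ranges: the first frame alone, then groups of `chunk`."""
    ranges = [(0, 1)]
    start = 1
    while start < num_frames:
        ranges.append((start, min(start + chunk, num_frames)))
        start += chunk
    return ranges


for start, end in encode_chunks(17):
    print(f"encode frames [{start}, {end})")
```

For 17 frames this yields five chunks, one per latent frame, and every chunk after the first is encoded with cached features from the previous one instead of zero padding.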

**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values.

**Fix:** Implemented full chunked encoding with temporal caching:
- Added `cache_x` parameter to `CausalConv3d.__call__`
- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d`
- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks)
- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head)

**File:** `mlx_video/models/wan/vae.py` (multiple methods)
**Commit:** `b6a94c4c`

---

## Verified Correct Components

These components were numerically verified against the reference and are **not** sources of bugs:

| Component | Method | Max Diff | Notes |
|-----------|--------|----------|-------|
| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only |
| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent |
| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative |
| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative |
| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation |
| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match |
| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 |
| CFG formula | Code audit | n/a | `uncond + gs * (cond - uncond)` correct order |
| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output |
| Cross-attention K/V | Shape and value audit | n/a | Batch dimension preserved correctly |

---

## Performance Optimizations

Applied alongside bug fixes to improve inference speed:

### Pre-Computation (Before Diffusion Loop)
- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once
- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat
- **Attention mask precomputation**: Build padding mask once, pass via kwargs
- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing
- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync

### Per-Step Optimizations
- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection
- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically, so 240 unnecessary graph nodes per step were removed (6 casts × 40 blocks)
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync

### TeaCache Integration
- Polynomial rescaling stays in MLX lazy graph (Horner's method)
- Single `.item()` call on the accumulated distance for the skip/compute decision
- Configurable threshold, retention steps, and cutoff steps
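
Horner's method, mentioned above for the polynomial rescaling, evaluates a degree-n polynomial with n multiplications and n additions and needs no explicit powers. The sketch below is plain Python for illustration only. The coefficients are placeholders, not the TeaCache coefficients, and `horner` is a hypothetical helper name. The same loop works unchanged on lazy array values, since it uses only `*` and `+`.

```python
def horner(coeffs, x):
    """Evaluate coeffs[0] * x**n + ... + coeffs[n], highest degree first."""
    result = 0.0
    for c in coeffs:
        result = result * x + c
    return result


# Placeholder polynomial: x**2 + 2x + 3
coeffs = [1.0, 2.0, 3.0]
print(horner(coeffs, 2.0))
```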

---

## Open Investigation: CFG Effectiveness

**Current symptom:** After all bug fixes, generated video shows the input image in frames 0–3 (latent frame 0), then grey/flat frames for the rest.

**Finding:** A single forward pass at t=1000 shows cond and uncond predictions are nearly identical (|diff| mean = 0.01–0.035). With `guide_scale=3.5`, the CFG guidance term barely changes anything.

**Possible causes under investigation:**
1. Cross-attention context flow: both cond and uncond may be receiving equivalent context
2. The model may genuinely produce small cond/uncond differences for I2V (since both share the same y conditioning)
3. The `embed_text` method or `prepare_cross_kv` may not properly separate B=2 batch elements
4. There may be an issue with how cross-attention K/V caches index into batch elements

**Diagnostic approach:** Compare cross-attention K/V cache values between cond (index 0) and uncond (index 1) to confirm they contain different embeddings.
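
To make the finding concrete, the sketch below applies the CFG formula from the verified-components table to placeholder predictions whose mean gap is in the reported range. The values and the helper names `cfg_combine` and `mean_abs_diff` are hypothetical. The guided output moves away from the conditional prediction by exactly `(guide_scale - 1)` times the cond/uncond gap, so a tiny gap yields a tiny guidance effect.

```python
def cfg_combine(cond, uncond, guide_scale):
    """Classifier-free guidance: uncond + gs * (cond - uncond), element-wise."""
    return [u + guide_scale * (c - u) for c, u in zip(cond, uncond)]


def mean_abs_diff(a, b):
    """Mean absolute element-wise difference."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


# Placeholder predictions with a mean gap of about 0.02.
cond = [0.52, -0.31, 0.10, 0.85]
uncond = [0.50, -0.30, 0.13, 0.83]
guided = cfg_combine(cond, uncond, guide_scale=3.5)

gap = mean_abs_diff(cond, uncond)
shift = mean_abs_diff(guided, cond)   # equals (3.5 - 1) * gap
print(f"cond/uncond gap={gap:.3f}, guidance shift={shift:.3f}")
```

The same `mean_abs_diff` can be applied to flattened K/V cache slices for batch index 0 and index 1; a result of exactly 0 would mean both entries hold the same embedding.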

---

## Reference Implementation

The reference PyTorch implementation is at `/Users/daniel/Projects/Wan2.2/`:

| File | Contents |
|------|----------|
| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) |
| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) |
| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding |
| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler |
| `wan/configs/wan_i2v_A14B.py` | Model configuration |

Key structural differences between reference and our implementation:
- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2
- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype
- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear`
- Reference casts timesteps to `int64`; we keep as float (diff < 1.0)

---

## Useful Diagnostic Commands

### Run I2V-14B generation
```bash
python -m mlx_video.generate_wan \
    --prompt "A woman smiles at camera" \
    --image start.png \
    --model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \
    --num-frames 17 --steps 40 \
    --height 384 --width 640 \
    --output output_i2v.mp4
```

### Check VAE encoder output
```python
import mlx.core as mx, numpy as np
from mlx_video.models.wan.vae import WanVAE

# Load VAE and encode an image (`vae` is a loaded WanVAE, `video_tensor` its input)
latents = vae.encode(video_tensor)  # [1, 16, T_lat, H_lat, W_lat]
for t in range(latents.shape[2]):
    # Cast to float32 first, since NumPy has no bfloat16 dtype
    frame = np.array(latents[0, :, t].astype(mx.float32))
    print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}")
```

### Analyze video frame quality
```python
import cv2, numpy as np

cap = cv2.VideoCapture("output.mp4")
while True:
    ret, frame = cap.read()
    if not ret:
        break
    # Checkerboard metric: high values indicate patch-boundary artifacts
    # Crop to even dimensions so both strided views have the same shape
    h, w = frame.shape[0] // 2 * 2, frame.shape[1] // 2 * 2
    f = frame[:h, :w].astype(float)
    checker = np.abs(f[::2, ::2] - f[1::2, 1::2]).mean()
    print(f"std={frame.std():.1f} checker={checker:.1f}")
cap.release()
```

### Compare weights between PyTorch and MLX
```python
import torch, mlx.core as mx, numpy as np

pt = torch.load("model.pt", map_location="cpu")
mlx_w = mx.load("model.safetensors")
for key in sorted(pt.keys()):
    if key in mlx_w:
        # Compare in float32, since NumPy has no bfloat16 dtype
        ours = np.array(mlx_w[key].astype(mx.float32))
        diff = np.abs(pt[key].float().numpy() - ours).max()
        if diff > 0.01:
            print(f"LARGE DIFF {key}: {diff:.6f}")
```
@@ -1,6 +1,6 @@
 # Wan2.2 MLX Implementation Notes
 
-> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / T2V-1.3B) to Apple MLX.
+> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.
 
 ## Architecture Overview
 
@@ -8,11 +8,12 @@ Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early repo
 
 ### Key Parameters
 
-| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride |
-|-------|-----|-------|--------|----------|-----------|------------|
-| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) |
-| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) |
-| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) |
+| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
+|-------|-----|-------|--------|----------|-----------|------------|--------|
+| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
+| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
+| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
+| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |
 
 ### Codebase Structure (~3900 lines of Wan2.2 code)
 
@@ -139,9 +140,11 @@ Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
 
 ---
 
-## Image-to-Video (I2V) Pipeline
+## Image-to-Video (I2V) Pipelines
 
-### Per-Token Timesteps
+Wan2.2 supports two distinct I2V approaches:
+
+### TI2V-5B: Per-Token Timestep Masking
 
 I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:
 
@@ -152,7 +155,7 @@ t_tokens = mask_tokens * current_timestep # first-frame → t=0
 
 The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
 
-### Mask Re-application
+#### Mask Re-application
 
 After each scheduler step, the first-frame latent is re-injected to prevent drift:
 
@@ -160,7 +163,7 @@ After each scheduler step, the first-frame latent is re-injected to prevent drif
 latents = (1.0 - mask) * z_img + mask * latents
 ```
 
-### VAE Encoder Temporal Downsample Order
+#### VAE Encoder Temporal Downsample Order
 
 The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
 - Stage 0: Spatial-only downsampling
@@ -168,6 +171,22 @@ The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
 
 This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
 
+### I2V-14B: Channel Concatenation
+
+The I2V-14B model uses a fundamentally different approach: channel concatenation via a `y` tensor.
+
+1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
+2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
+3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])` → `[20, T_lat, H_lat, W_lat]`
+4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
+
+Key differences from TI2V-5B:
+- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
+- Requires the **VAE encoder** (for encoding the reference image)
+- Uses **scalar timesteps** (same as T2V), with no per-token masking
+- **Dual model** pipeline with boundary=0.900
+- Both conditional and unconditional predictions receive the same `y` tensor
+
 ---
 
 ## Dimension Constraints
@@ -233,7 +252,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s
 
 ## Testing Strategy
 
-260 tests across 9 files, all running in ~4 seconds:
+332 tests across 10 files, all running in ~5 seconds:
 
 | File | Focus |
 |------|-------|
@@ -246,6 +265,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s
 | test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
 | test_wan_convert.py | Weight sanitization and conversion |
 | test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
+| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |
 
 Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.