From f4195f0118fb5caf4fbe2eeacfe12e163dcba7a0 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 27 Feb 2026 23:43:42 +0100 Subject: [PATCH] feat(wan): Add I2V-14B dual-model support --- README.md | 39 ++- docs/DIAGNOSTICS.md | 383 ++++++++++++++++++++++++++++ docs/wan22-implementation-notes.md | 42 ++- mlx_video/convert_wan.py | 10 +- mlx_video/generate_wan.py | 142 ++++++++--- mlx_video/models/wan/attention.py | 15 +- mlx_video/models/wan/config.py | 13 + mlx_video/models/wan/loading.py | 19 +- mlx_video/models/wan/model.py | 139 ++++++---- mlx_video/models/wan/rope.py | 78 ++++++ mlx_video/models/wan/scheduler.py | 6 +- mlx_video/models/wan/transformer.py | 26 +- mlx_video/models/wan/vae.py | 279 +++++++++++++++++--- tests/test_wan_i2v.py | 293 +++++++++++++++++++++ 14 files changed, 1332 insertions(+), 152 deletions(-) create mode 100644 docs/DIAGNOSTICS.md create mode 100644 tests/test_wan_i2v.py diff --git a/README.md b/README.md index 2955ffc..751f48a 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Supported models: - [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks - [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline) -- [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — 14B parameter T2V model (dual-model pipeline) +- [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — T2V-14B, TI2V-5B, and I2V-14B models (dual-model pipeline) ## Features @@ -82,13 +82,15 @@ python -m mlx_video.generate \ Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. 
They share the same model architecture — the difference is in the inference pipeline: -| | Wan2.1 | Wan2.2 | -|---|--------|--------| -| **Pipeline** | Single model | Dual model (high-noise + low-noise) | -| **Sizes** | 1.3B, 14B | 14B | -| **Steps** | 50 | 40 | -| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 (low/high noise) | -| **Shift** | 5.0 | 12.0 | +| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | +|---|--------|--------|--------| +| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | +| **Pipeline** | Single model | Dual model | Dual model | +| **Sizes** | 1.3B, 14B | 14B | 14B | +| **Steps** | 50 | 40 | 40 | +| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | +| **Shift** | 5.0 | 12.0 | 5.0 | +| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | ### Step 1: Download Weights @@ -117,9 +119,11 @@ Download the original PyTorch checkpoints: # └── high_noise_model/ # safetensors ``` +**Wan2.2 I2V-14B** — same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V-14B from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`). + ### Step 2: Convert to MLX Format -The conversion script auto-detects whether the checkpoint is Wan2.1 or Wan2.2 based on the directory structure (presence of `low_noise_model/` subdirectory). +The conversion script auto-detects the model version based on the directory structure (presence of `low_noise_model/` subdirectory) and model type (`model_type` in source config.json for I2V vs T2V). 
```bash # Auto-detect version @@ -157,6 +161,7 @@ wan_mlx/ ├── config.json # Model configuration ├── t5_encoder.safetensors # T5 UMT5-XXL text encoder ├── vae.safetensors # 3D VAE decoder +├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only) ├── model.safetensors # (Wan2.1) Single transformer ├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer └── high_noise_model.safetensors # (Wan2.2) High-noise transformer @@ -195,12 +200,27 @@ python -m mlx_video.generate_wan \ The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model). You can also override any parameter via CLI flags. +#### Image-to-Video (I2V-14B) + +```bash +# Generate video from an input image +python -m mlx_video.generate_wan \ + --model-dir wan22_i2v_mlx \ + --prompt "The camera slowly zooms in as the subject begins to move" \ + --image start.png \ + --num-frames 81 \ + --output-path my_video.mp4 +``` + +The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and uses channel concatenation (`y` tensor with 4 mask + 16 image latent channels) to condition generation on the first frame. + #### Generation CLI Options | Option | Default | Description | |--------|---------|-------------| | `--model-dir` | (required) | Path to converted MLX model directory | | `--prompt` | (required) | Text description of the video | +| `--image` | `None` | Input image path (for I2V models) | | `--negative-prompt` | `""` | Negative prompt for guidance | | `--width` | 1280 | Video width | | `--height` | 720 | Video height | @@ -237,6 +257,7 @@ python -m mlx_video.generate_wan \ > **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed. 
+ ### Wan Model Specifications **Transformer (14B)** diff --git a/docs/DIAGNOSTICS.md b/docs/DIAGNOSTICS.md new file mode 100644 index 0000000..18d112b --- /dev/null +++ b/docs/DIAGNOSTICS.md @@ -0,0 +1,383 @@ +# Wan2.2 I2V-14B Diagnostic Report + +This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix. + +## Table of Contents + +- [Overview](#overview) +- [Architecture Summary](#architecture-summary) +- [Diagnostic Methodology](#diagnostic-methodology) +- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination) +- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion) +- [Bug 3: RoPE Frequency Computation](#bug-3-rope-frequency-computation) +- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order) +- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding) +- [Verified Correct Components](#verified-correct-components) +- [Performance Optimizations](#performance-optimizations) +- [Open Investigation: CFG Effectiveness](#open-investigation-cfg-effectiveness) +- [Reference Implementation](#reference-implementation) +- [Useful Diagnostic Commands](#useful-diagnostic-commands) + +--- + +## Overview + +The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output — first frame showed the image, subsequent frames degraded to noise, checkerboard artifacts, or flat grey. + +Through a systematic component-by-component comparison against the reference PyTorch implementation, **five bugs** were found and fixed. The approach was to verify each component in isolation numerically, then narrow down failures to the subsystem level. 
+ +### Timeline of Symptoms + +| Stage | Symptom | Root Cause | +|-------|---------|------------| +| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) | +| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) | +| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) | +| After VAE fix | Image in frames 0-3, grey frames 4+ | CFG effectiveness issue (open investigation) | + +--- + +## Architecture Summary + +``` +I2V-14B Pipeline: + Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat] + ↓ + Mask Construction → [4, T_lat, H_lat, W_lat] + ↓ + y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat] + ↓ + Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat] + ↓ + Dual DiT (40 layers, 5120 dim) × 40 denoising steps + ↓ + Denoised Latent [16, T_lat, H_lat, W_lat] + ↓ + VAE Decoder → Video [3, F, H, W] +``` + +**Key parameters:** +- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16` +- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900) +- 40 steps, shift=5.0, guide_scale=(3.5, 3.5) +- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8) + +--- + +## Diagnostic Methodology + +### 1. Component-Level Numerical Verification + +Each component was tested in isolation against the reference PyTorch implementation: + +1. **Load identical inputs** (same random seed, same image, same prompt) +2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy` +3. **Run through MLX** with the same inputs +4. 
**Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics + +Components tested this way: +- RoPE frequency parameters and rotation output +- Time embedding (sinusoidal → MLP → projection) +- Patchify (reshape+Linear vs Conv3d) +- Unpatchify (transpose-based vs einsum) +- Scheduler (UniPC) timesteps and step formulas +- VAE encoder output (frame-by-frame comparison) +- Text embeddings (per-model MLP output) +- Cross-attention K/V cache shapes +- Mask construction values + +### 2. Artifact Analysis + +When visual artifacts appeared, quantitative metrics were used to characterize them: + +- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard. +- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts. +- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation. +- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content. + +### 3. Isolation Testing + +- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source. +- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness. +- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering. +- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence. +- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs. + +### 4. 
Weight Comparison + +Direct comparison of converted MLX weights against original PyTorch weights: +```python +# Load both weight sets +pt_weights = torch.load("model.safetensors") +mlx_weights = mx.load("model.safetensors") +# Compare each key +for key in pt_weights: + diff = np.abs(np.array(pt_weights[key]) - np.array(mlx_weights[key])).max() +``` +Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys. + +--- + +## Bug 1: Text Embedding Cross-Contamination + +**Symptom:** Model ignores text prompt, generated frames lack semantic content. + +**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights — 42% relative mean difference in output. + +**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices. + +**Fix:** Compute separate text embeddings per model: +```python +# Before (broken): +context_emb = low_noise_model.embed_text([context, context_null]) +cross_kv = low_noise_model.prepare_cross_kv(context_emb) # used for BOTH models + +# After (correct): +context_emb_low = low_noise_model.embed_text([context, context_null]) +context_emb_high = high_noise_model.embed_text([context, context_null]) +cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low) +cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high) +``` + +**File:** `mlx_video/generate_wan.py` (lines 333–349) +**Commit:** `a85b1c21` + +--- + +## Bug 2: VAE Encoder Weights Excluded from Conversion + +**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion). 
+ +**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, the missing encoder weights were silently ignored, leaving the encoder with its untrained initial parameters instead of the checkpoint's weights. + +**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning). + +**Fix:** +```python +# Before: +include_encoder = config.model_type == "ti2v" +# After: +include_encoder = config.model_type in ("ti2v", "i2v") +``` + +**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions. + +**File:** `mlx_video/convert_wan.py` (line 424) + +--- + +## Bug 3: RoPE Frequency Computation + +**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame. + +**Root Cause:** The reference creates **one** frequency table via `rope_params(1024, head_dim=128)` producing 64 frequency exponents, which `rope_apply` then splits into temporal (22), height (21), and width (21) portions. This gives temporal axes LOW frequencies and spatial axes progressively HIGHER frequencies. + +Our code called `rope_params` **three times** with different normalizations: +```python +# WRONG: each axis gets full frequency range [0, 1) +freqs_t = rope_params(1024, d_t=44) # 22 exponents normalized by 44 +freqs_h = rope_params(1024, d_h=42) # 21 exponents normalized by 42 +freqs_w = rope_params(1024, d_w=42) # 21 exponents normalized by 42 +``` + +The max frequency difference was ~1.0 (not a precision issue — a fundamental design bug). This affected **all** Wan models (T2V, I2V, TI2V). + +**How Found:** Line-by-line comparison of `rope_params` usage between reference `model.py` (single call) and our `model.py` (three calls). 
Printed the actual frequency exponents to confirm the numerical divergence. + +**Fix:** +```python +# Single unified frequency table, split by rope_apply +self.freqs = rope_params(1024, dim // config.num_heads) +``` + +**Impact:** ~35% reduction in checkerboard metric, 55% reduction in FFT 8px-frequency power. + +**File:** `mlx_video/models/wan/model.py` (lines 154–156) +**Commit:** `3da4a637` + +--- + +## Bug 4: VAE Encoder Temporal Downsample Order + +**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 1–4 showed lower, less stable std (0.37→1.19) while the reference showed stable std (0.95→1.34). + +**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters: + +``` +Reference: [False, True, True] → stage 0: 2d, stage 1: 3d, stage 2: 3d +Ours: [True, True, False] → stage 0: 3d, stage 1: 3d, stage 2: 2d ← WRONG +``` + +This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation. + +**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output: +- Frame 0 matched exactly (diff=0.0000) — spatial-only processing was correct +- Frames 1–4 had massive differences — proved temporal processing was broken + +Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]` (sic — this is the reference's own spelling), while our `Encoder3d` class used the wrong default `[True, True, False]`. + +**Fix:** +```python +# In Encoder3d.__init__, change default: +temporal_downsample = [False, True, True] # was [True, True, False] +``` + +**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 60–80 to 0.1–7.7. 
+ +**File:** `mlx_video/models/wan/vae.py` (line 370) +**Commit:** `3da4a637` + +--- + +## Bug 5: Non-Chunked VAE Encoding + +**Symptom:** First 4–5 frames grey, then blurred version of image appears. + +**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`): +1. Encode first frame alone (1 frame) +2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks +3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input + +Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because: +- Chunked: real features from previous chunks serve as causal context +- Non-chunked: zeros serve as causal context for the start + +**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values. 
+ +**Fix:** Implemented full chunked encoding with temporal caching: +- Added `cache_x` parameter to `CausalConv3d.__call__` +- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d` +- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks) +- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head) + +**File:** `mlx_video/models/wan/vae.py` (multiple methods) +**Commit:** `b6a94c4c` + +--- + +## Verified Correct Components + +These components were numerically verified against the reference and are **not** sources of bugs: + +| Component | Method | Max Diff | Notes | +|-----------|--------|----------|-------| +| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only | +| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent | +| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative | +| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative | +| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation | +| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match | +| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 | +| CFG formula | Code audit | — | `uncond + gs * (cond - uncond)` correct order | +| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output | +| Cross-attention K/V | Shape and value audit | — | Batch dimension preserved correctly | + +--- + +## Performance Optimizations + +Applied alongside bug fixes to improve inference speed: + +### Pre-Computation (Before Diffusion Loop) +- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once +- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat +- **Attention mask 
precomputation**: Build padding mask once, pass via kwargs +- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing +- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync + +### Per-Step Optimizations +- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection +- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element +- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks) +- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync + +### TeaCache Integration +- Polynomial rescaling stays in MLX lazy graph (Horner's method) +- Single `.item()` call on the accumulated distance for the skip/compute decision +- Configurable threshold, retention steps, and cutoff steps + +--- + +## Open Investigation: CFG Effectiveness + +**Current symptom:** After all bug fixes, generated video shows the input image in frames 0–3 (latent frame 0), then grey/flat frames for the rest. + +**Finding:** A single forward pass at t=1000 shows cond and uncond predictions are nearly identical (|diff| mean = 0.01–0.035). With `guide_scale=3.5`, the CFG guidance term barely changes anything. + +**Possible causes under investigation:** +1. Cross-attention context flow — both cond and uncond may be receiving equivalent context +2. The model may genuinely produce small cond/uncond differences for I2V (since both share the same y conditioning) +3. The `embed_text` method or `prepare_cross_kv` may not properly separate B=2 batch elements +4. 
There may be an issue with how cross-attention K/V caches index into batch elements + +**Diagnostic approach:** Compare cross-attention K/V cache values between cond (index 0) and uncond (index 1) to confirm they contain different embeddings. + +--- + +## Reference Implementation + +The reference PyTorch implementation is the upstream [Wan2.2 repository](https://github.com/Wan-Video/Wan2.2), checked out locally at `/Users/daniel/Projects/Wan2.2/` during this investigation (paths below are relative to the repository root): + +| File | Contents | +|------|----------| +| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) | +| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) | +| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding | +| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler | +| `wan/configs/wan_i2v_A14B.py` | Model configuration | + +Key structural differences between reference and our implementation: +- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2 +- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype +- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear` +- Reference casts timesteps to `int64`; we keep as float (diff < 1.0) + +--- + +## Useful Diagnostic Commands + +### Run I2V-14B generation +```bash +python -m mlx_video.generate_wan \ + --prompt "A woman smiles at camera" \ + --image start.png \ + --model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \ + --num-frames 17 --steps 40 \ + --height 384 --width 640 \ + --output output_i2v.mp4 +``` + +### Check VAE encoder output +```python +import mlx.core as mx, numpy as np +from mlx_video.models.wan.vae import WanVAE +# Load VAE and encode an image +latents = vae.encode(video_tensor) # [1, 16, T_lat, H_lat, W_lat] +for t in range(latents.shape[2]): + frame = np.array(latents[0, :, t]) + print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}") +``` + +### Analyze video frame quality +```python +import cv2, numpy as np +cap = cv2.VideoCapture("output.mp4") +while True: + ret, frame = 
cap.read() + if not ret: break + # Checkerboard metric: high values indicate patch-boundary artifacts + checker = np.abs(frame[::2, ::2].astype(float) - frame[1::2, 1::2].astype(float)).mean() + print(f"std={frame.std():.1f} checker={checker:.1f}") +``` + +### Compare weights between PyTorch and MLX +```python +import torch, mlx.core as mx, numpy as np +pt = torch.load("model.pt", map_location="cpu") +mlx_w = mx.load("model.safetensors") +for key in sorted(pt.keys()): + if key in mlx_w: + diff = np.abs(pt[key].float().numpy() - np.array(mlx_w[key])).max() + if diff > 0.01: + print(f"LARGE DIFF {key}: {diff:.6f}") +``` diff --git a/docs/wan22-implementation-notes.md b/docs/wan22-implementation-notes.md index a46f6bc..186aabb 100644 --- a/docs/wan22-implementation-notes.md +++ b/docs/wan22-implementation-notes.md @@ -1,6 +1,6 @@ # Wan2.2 MLX Implementation Notes -> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / T2V-1.3B) to Apple MLX. +> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX. ## Architecture Overview @@ -8,11 +8,12 @@ Wan2.2 is a Diffusion Transformer (DiT) for video generation. 
Despite early repo ### Key Parameters -| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | -|-------|-----|-------|--------|----------|-----------|------------| -| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | -| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | -| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | +| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim | +|-------|-----|-------|--------|----------|-----------|------------|--------| +| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 | +| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 | +| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 | +| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 | ### Codebase Structure (~3900 lines of Wan2.2 code) @@ -139,9 +140,11 @@ Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0. --- -## Image-to-Video (I2V) Pipeline +## Image-to-Video (I2V) Pipelines -### Per-Token Timesteps +Wan2.2 supports two distinct I2V approaches: + +### TI2V-5B: Per-Token Timestep Masking I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep: @@ -152,7 +155,7 @@ t_tokens = mask_tokens * current_timestep # first-frame → t=0 The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels. 
-### Mask Re-application +#### Mask Re-application After each scheduler step, the first-frame latent is re-injected to prevent drift: @@ -160,7 +163,7 @@ After each scheduler step, the first-frame latent is re-injected to prevent drif latents = (1.0 - mask) * z_img + mask * latents ``` -### VAE Encoder Temporal Downsample Order +#### VAE Encoder Temporal Downsample Order The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`: - Stage 0: Spatial-only downsampling @@ -168,6 +171,22 @@ The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`: This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths. +### I2V-14B: Channel Concatenation + +The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor: + +1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]` +2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]` +3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])` → `[20, T_lat, H_lat, W_lat]` +4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36` + +Key differences from TI2V-5B: +- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE +- Requires the **VAE encoder** (for encoding the reference image) +- Uses **scalar timesteps** (same as T2V) — no per-token masking +- **Dual model** pipeline with boundary=0.900 +- Both conditional and unconditional predictions receive the same `y` tensor + --- ## Dimension Constraints @@ -233,7 +252,7 @@ The T2V-14B uses dual models (high-noise and low-noise). 
The conversion script s ## Testing Strategy -260 tests across 9 files, all running in ~4 seconds: +332 tests across 10 files, all running in ~5 seconds: | File | Focus | |------|-------| @@ -246,6 +265,7 @@ The T2V-14B uses dual models (high-noise and low-noise). The conversion script s | test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence | | test_wan_convert.py | Weight sanitization and conversion | | test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment | +| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction | Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise. diff --git a/mlx_video/convert_wan.py b/mlx_video/convert_wan.py index e9db2aa..f3f9037 100644 --- a/mlx_video/convert_wan.py +++ b/mlx_video/convert_wan.py @@ -316,6 +316,14 @@ def convert_wan_checkpoint( def _detect_config(): """Detect config from source config.json or transformer weight shapes.""" if is_dual: + # Check source config.json for model_type (I2V vs T2V) + src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json" + if src_cfg_path.exists(): + with open(src_cfg_path) as f: + src_config = json.load(f) + src_model_type = src_config.get("model_type", "t2v") + if src_model_type == "i2v" or src_config.get("in_dim") == 36: + return WanModelConfig.wan22_i2v_14b() return WanModelConfig.wan22_t2v_14b() # Try reading source config.json first (most reliable) @@ -413,7 +421,7 @@ def convert_wan_checkpoint( weights = load_torch_weights(str(vae_path)) if is_wan22_vae: from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights - include_encoder = config.model_type == "ti2v" + include_encoder = config.model_type in ("ti2v", "i2v") weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder) else: weights = sanitize_wan_vae_weights(weights) diff --git a/mlx_video/generate_wan.py 
b/mlx_video/generate_wan.py index 37501fe..1bd4fe7 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -245,24 +245,71 @@ def generate_video( z_img = None i2v_mask = None i2v_mask_tokens = None + y_i2v = None + is_i2v_channel_concat = is_i2v and config.model_type == "i2v" + is_i2v_mask_blend = is_i2v and config.model_type != "i2v" if is_i2v: print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}") t_img = time.time() - img_tensor = preprocess_image(image, width, height) - mx.eval(img_tensor) vae_path = model_dir / "vae.safetensors" - vae_enc = load_vae_encoder(vae_path, config) - z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim] - mx.eval(z_img) - # Convert to channels-first: [z_dim, 1, H_lat, W_lat] - z_img = z_img[0].transpose(3, 0, 1, 2) + if is_i2v_channel_concat: + # I2V-14B: encode full video (first frame = image, rest = zeros) + # and construct y tensor with mask + encoded latents + from PIL import Image - # Build I2V mask - i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size) + img = Image.open(image).convert("RGB") + scale = max(width / img.width, height / img.height) + img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS) + x1, y1 = (img.width - width) // 2, (img.height - height) // 2 + img = img.crop((x1, y1, x1 + width, y1 + height)) + img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3] + img_chw = img_arr.transpose(2, 0, 1) # [3, H, W] + + # Build video: first frame = image, rest = zeros -> [3, F, H, W] + # Chunked encoding processes 1-frame + 4-frame chunks with temporal caching + video = mx.concatenate([ + img_chw[:, None, :, :], + mx.zeros((3, num_frames - 1, height, width)), + ], axis=1) + + # Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat] + vae_enc = load_vae_encoder(vae_path, config) + z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat] + mx.eval(z_video) + z_video = z_video[0] # [16, T_lat, 
H_lat, W_lat] + + # Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W] + msk = mx.ones((1, num_frames, h_latent, w_latent)) + msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1) + # Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat] + msk = mx.concatenate([ + mx.repeat(msk[:, :1], 4, axis=1), + msk[:, 1:], + ], axis=1) + # Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat] + msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent) + msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat] + + # y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat] + y_i2v = mx.concatenate([msk, z_video], axis=0) + mx.eval(y_i2v) + + del vae_enc, img_arr, img_chw, video, z_video, msk + else: + # TI2V-5B: encode single image, blend with noise via mask + img_tensor = preprocess_image(image, width, height) + mx.eval(img_tensor) + + vae_enc = load_vae_encoder(vae_path, config) + z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim] + mx.eval(z_img) + z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat] + i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size) + + del vae_enc, img_tensor - del vae_enc, img_tensor gc.collect(); mx.clear_cache() print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}") @@ -282,23 +329,40 @@ def generate_video( print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}") # Precompute text embeddings once (avoids redundant MLP in every step) - ref_model = single_model if not is_dual else low_noise_model - context_emb = ref_model.embed_text([context, context_null]) - mx.eval(context_emb) - context_cond = context_emb[0:1] # [1, text_len, dim] - context_uncond = context_emb[1:2] # [1, text_len, dim] - # Stack for batched CFG: [2, text_len, dim] - context_cfg = mx.concatenate([context_cond, context_uncond], axis=0) + # Each model has its own text_embedding 
weights, so dual models need separate embeddings + if is_dual: + context_emb_low = low_noise_model.embed_text([context, context_null]) + context_emb_high = high_noise_model.embed_text([context, context_null]) + mx.eval(context_emb_low, context_emb_high) + context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0) + context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0) + else: + context_emb = single_model.embed_text([context, context_null]) + mx.eval(context_emb) + context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0) # Precompute cross-attention K/V caches (constant across all steps) if is_dual: - cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg) - cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg) + cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low) + cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high) mx.eval(cross_kv_low, cross_kv_high) else: cross_kv = single_model.prepare_cross_kv(context_cfg) mx.eval(cross_kv) + # Precompute RoPE frequencies (grid sizes are constant across all steps) + f_grid = t_latent // patch_size[0] + h_grid = h_latent // patch_size[1] + w_grid = w_latent // patch_size[2] + cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)] + if is_dual: + rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes) + rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes) + mx.eval(rope_cos_sin_low, rope_cos_sin_high) + else: + rope_cos_sin = single_model.prepare_rope(cfg_grid_sizes) + mx.eval(rope_cos_sin) + # Setup scheduler _schedulers = { "euler": FlowMatchEulerScheduler, @@ -312,9 +376,8 @@ def generate_video( # Generate initial noise noise = mx.random.normal(target_shape) - # I2V: blend first-frame latent into noise - if is_i2v: - # Broadcast z_img [z_dim, 1, H, W] across T for first-frame conditioning + # I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise +
if is_i2v_mask_blend: latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise else: latents = noise @@ -326,26 +389,32 @@ def generate_video( print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}") t3 = time.time() - for i, t in enumerate(tqdm(range(steps), desc="Diffusion")): - timestep_val = sched.timesteps[i].item() + # Pre-convert timesteps to Python list to avoid .item() sync each step + timestep_list = sched.timesteps.tolist() - # Select model, guide scale, and cached K/V + for i, t in enumerate(tqdm(range(steps), desc="Diffusion")): + timestep_val = timestep_list[i] + + # Select model, guide scale, cached K/V, and precomputed RoPE if is_dual: if timestep_val >= boundary: model = high_noise_model gs = guide_scale[1] kv = cross_kv_high + rcs = rope_cos_sin_high else: model = low_noise_model gs = guide_scale[0] kv = cross_kv_low + rcs = rope_cos_sin_low else: model = single_model gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0] kv = cross_kv + rcs = rope_cos_sin - # Build per-token timesteps for I2V (first-frame patches get t=0) - if is_i2v: + # Build per-token timesteps for TI2V-5B (first-frame patches get t=0) + if is_i2v_mask_blend: t_tokens = i2v_mask_tokens * timestep_val # [1, L] # Pad to seq_len if needed pad_len = seq_len - t_tokens.shape[1] @@ -358,22 +427,31 @@ def generate_video( else: t_batch = mx.array([timestep_val, timestep_val]) + # I2V-14B: pass y conditioning to model (same y for cond and uncond) + y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None + # CFG: batch cond + uncond into single B=2 forward pass + ctx = context_cfg if not is_dual else ( + context_cfg_high if timestep_val >= boundary else context_cfg_low + ) preds = model( [latents, latents], t=t_batch, - context=context_cfg, + context=ctx, seq_len=seq_len, cross_kv_caches=kv, + y=y_arg, + rope_cos_sin=rcs, ) noise_pred_cond, noise_pred_uncond = preds[0], preds[1] # Classifier-free guidance + scheduler step noise_pred = noise_pred_uncond 
+ gs * (noise_pred_cond - noise_pred_uncond) + latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0) - # I2V: re-apply mask to keep first frame frozen - if is_i2v: + # TI2V-5B: re-apply mask to keep first frame frozen + if is_i2v_mask_blend: latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents # Release temporaries before eval to free memory for graph execution @@ -385,9 +463,11 @@ def generate_video( # Free transformer models and text embeddings if is_dual: del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high + del context_cfg_low, context_cfg_high else: del single_model, cross_kv - del model, kv, context, context_null, context_cfg + del context_cfg + del model, kv, context, context_null gc.collect(); mx.clear_cache() # Load VAE and decode diff --git a/mlx_video/models/wan/attention.py b/mlx_video/models/wan/attention.py index 0cab5bb..e3fe24a 100644 --- a/mlx_video/models/wan/attention.py +++ b/mlx_video/models/wan/attention.py @@ -67,6 +67,8 @@ class WanSelfAttention(nn.Module): seq_lens: list, grid_sizes: list, freqs: mx.array, + rope_cos_sin: tuple | None = None, + attn_mask: mx.array | None = None, ) -> mx.array: b, s, _ = x.shape n, d = self.num_heads, self.head_dim @@ -87,19 +89,18 @@ class WanSelfAttention(nn.Module): v = self.v(x_w).reshape(b, s, n, d) # RoPE in float32 for precision (official uses float64) - q = rope_apply(q.astype(mx.float32), grid_sizes, freqs) - k = rope_apply(k.astype(mx.float32), grid_sizes, freqs) + q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin) + k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin) # Cast back to weight dtype for efficient attention (matching official q.to(v.dtype)) q = q.astype(w_dtype).transpose(0, 2, 1, 3) k = k.astype(w_dtype).transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3) - # Build attention mask from seq_lens - max_len = s - mask = None - if any(sl < max_len for sl in seq_lens): - mask 
= mx.zeros((b, 1, 1, max_len), dtype=q.dtype) + # Use precomputed mask or build from seq_lens + mask = attn_mask + if mask is None and any(sl < s for sl in seq_lens): + mask = mx.zeros((b, 1, 1, s), dtype=q.dtype) for i, sl in enumerate(seq_lens): mask[i, :, :, sl:] = -1e9 diff --git a/mlx_video/models/wan/config.py b/mlx_video/models/wan/config.py index cae72d2..08370d4 100644 --- a/mlx_video/models/wan/config.py +++ b/mlx_video/models/wan/config.py @@ -91,6 +91,19 @@ class WanModelConfig(BaseModelConfig): """Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default).""" return cls() + @classmethod + def wan22_i2v_14b(cls) -> "WanModelConfig": + """Wan2.2 I2V 14B: dual model, image-to-video, 40 layers, dim=5120.""" + return cls( + model_type="i2v", + in_dim=36, + out_dim=16, + dual_model=True, + boundary=0.900, + sample_shift=5.0, + sample_guide_scale=(3.5, 3.5), + ) + @classmethod def wan22_ti2v_5b(cls) -> "WanModelConfig": """Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072.""" diff --git a/mlx_video/models/wan/loading.py b/mlx_video/models/wan/loading.py index 8acc770..4ef795b 100644 --- a/mlx_video/models/wan/loading.py +++ b/mlx_video/models/wan/loading.py @@ -87,16 +87,23 @@ def load_vae_decoder(model_path: Path, config=None): def load_vae_encoder(model_path: Path, config=None): """Load VAE encoder for I2V image encoding. - Only supports Wan2.2 (vae_z_dim=48). + For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder. + For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True. 
""" - from mlx_video.models.wan.vae22 import Wan22VAEEncoder + if config is not None and config.vae_z_dim == 16: + from mlx_video.models.wan.vae import WanVAE + + vae = WanVAE(z_dim=16, encoder=True) + else: + from mlx_video.models.wan.vae22 import Wan22VAEEncoder + + vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48) - encoder = Wan22VAEEncoder(z_dim=config.vae_z_dim) weights = mx.load(str(model_path)) weights = {k: v.astype(mx.float32) for k, v in weights.items()} - encoder.load_weights(list(weights.items()), strict=False) - mx.eval(encoder.parameters()) - return encoder + vae.load_weights(list(weights.items()), strict=False) + mx.eval(vae.parameters()) + return vae def _clean_text(text: str) -> str: diff --git a/mlx_video/models/wan/model.py b/mlx_video/models/wan/model.py index f3689ec..b253fb3 100644 --- a/mlx_video/models/wan/model.py +++ b/mlx_video/models/wan/model.py @@ -6,7 +6,7 @@ import numpy as np from .attention import WanLayerNorm from .config import WanModelConfig -from .rope import rope_params +from .rope import rope_params, rope_precompute_cos_sin from .transformer import WanAttentionBlock @@ -38,7 +38,7 @@ class Head(nn.Module): proj_dim = math.prod(patch_size) * out_dim self.norm = WanLayerNorm(dim, eps) self.head = nn.Linear(dim, proj_dim) - self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5) + self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32) def __call__(self, x: mx.array, e: mx.array) -> mx.array: """ @@ -48,14 +48,13 @@ class Head(nn.Module): """ if e.ndim == 2: e = e[:, None, :] # [B, 1, dim] - e_f32 = e.astype(mx.float32) - # modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze - mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim] + # modulation already float32; e already float32 from model forward + mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim] e0 = mod[:, :, 0, :] # [B, L_e, dim] shift e1 = 
mod[:, :, 1, :] # [B, L_e, dim] scale - x_norm = self.norm(x).astype(mx.float32) - x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1 - return self.head(x_mod.astype(x.dtype)) + x_norm = self.norm(x) + x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32 + return self.head(x_mod.astype(self.head.weight.dtype)) class WanModel(nn.Module): @@ -109,17 +108,16 @@ class WanModel(nn.Module): # Output head self.head = Head(dim, config.out_dim, config.patch_size, config.eps) - # Precompute RoPE frequencies - d = dim // config.num_heads - d_t = d - 4 * (d // 6) - d_h = 2 * (d // 6) - d_w = 2 * (d // 6) - # Each rope_params returns [1024, d_x//2, 2] - freqs_t = rope_params(1024, d_t) - freqs_h = rope_params(1024, d_h) - freqs_w = rope_params(1024, d_w) - # Concatenate along the frequency dimension: [1024, d//2, 2] - self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1) + # Precompute RoPE frequencies — single table, split by rope_apply + # Reference computes one rope_params(head_dim) and splits into t/h/w. + self.freqs = rope_params(1024, dim // config.num_heads) + + # Precompute sinusoidal inv_freq for time embedding + half = config.freq_dim // 2 + self._inv_freq = mx.power( + 10000.0, -mx.arange(half).astype(mx.float32) / half + ) + def _patchify(self, x: mx.array) -> tuple: """Convert video tensor to patch embeddings. @@ -215,6 +213,21 @@ class WanModel(nn.Module): kv_caches.append(block.cross_attn.prepare_kv(context)) return kv_caches + def prepare_rope(self, grid_sizes: list) -> tuple: + """Pre-compute RoPE cos/sin for constant grid sizes. + + Call once before the diffusion loop when grid sizes don't change + across steps. Eliminates per-step broadcast/concat overhead. 
+ + Args: + grid_sizes: List of (F, H, W) tuples per batch element + + Returns: + (cos_f, sin_f) precomputed frequency tensors + """ + w_dtype = self.patch_embedding_proj.weight.dtype + return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype) + def __call__( self, x_list: list, @@ -222,6 +235,8 @@ class WanModel(nn.Module): context: list | mx.array, seq_len: int, cross_kv_caches: list | None = None, + y: list | None = None, + rope_cos_sin: tuple | None = None, ) -> list: """Forward pass. @@ -233,42 +248,70 @@ class WanModel(nn.Module): seq_len: Maximum sequence length for padding cross_kv_caches: Optional list of (k, v) tuples from prepare_cross_kv(), one per block. + y: Optional list of conditioning tensors for I2V [C_y, F, H, W]. + Channel-concatenated with x before patchify. + rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope(). Returns: List of denoised tensors [C, F, H, W] """ - # Patchify each video - patches = [] - grid_sizes = [] - seq_lens_list = [] - for vid in x_list: - p, gs = self._patchify(vid) # [1, L, dim] - patches.append(p) - grid_sizes.append(gs) - seq_lens_list.append(p.shape[1]) + # Detect identical inputs (CFG B=2) to avoid duplicate patchify work. + # Check BEFORE I2V concat since concat creates new array objects. 
+ batch_size = len(x_list) + all_same = batch_size > 1 and all( + x_list[i] is x_list[0] for i in range(1, batch_size) + ) + if all_same and y is not None: + all_same = all(y[i] is y[0] for i in range(1, len(y))) - # Pad and batch - batch_size = len(patches) - x = mx.concatenate( - [ - mx.concatenate( + # I2V: channel-concatenate conditioning y with noise x + if y is not None: + x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)] + + if all_same: + # Patchify once and broadcast — saves a Linear projection per step + p, gs = self._patchify(x_list[0]) # [1, L, dim] + grid_sizes = [gs] * batch_size + seq_lens_list = [p.shape[1]] * batch_size + # Pad and broadcast + if p.shape[1] < seq_len: + p = mx.concatenate( [p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)], axis=1, ) - if p.shape[1] < seq_len - else p - for p in patches - ], - axis=0, - ) # [B, seq_len, dim] + x = mx.broadcast_to(p, (batch_size,) + p.shape[1:]) + else: + patches = [] + grid_sizes = [] + seq_lens_list = [] + for vid in x_list: + p, gs = self._patchify(vid) # [1, L, dim] + patches.append(p) + grid_sizes.append(gs) + seq_lens_list.append(p.shape[1]) + x = mx.concatenate( + [ + mx.concatenate( + [p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)], + axis=1, + ) + if p.shape[1] < seq_len + else p + for p in patches + ], + axis=0, + ) # [B, seq_len, dim] - # Time embedding + # Time embedding (use cached inv_freq to avoid recomputing each step) if t.ndim == 0: t = t[None] + pos = t.astype(mx.float32) + sinusoid = pos[..., None] * self._inv_freq + sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1) + if t.ndim == 1: # Standard T2V: scalar timestep per batch element [B] - sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim] e = self.time_embedding_1( self.time_embedding_act(self.time_embedding_0(sin_emb)) ) # [B, dim] @@ -278,7 +321,6 @@ class WanModel(nn.Module): e = e.astype(mx.float32) else: # I2V: per-token timesteps [B, 
L] - sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim] e = self.time_embedding_1( self.time_embedding_act(self.time_embedding_0(sin_emb)) ) # [B, L, dim] @@ -298,7 +340,15 @@ class WanModel(nn.Module): else: context_batch = self.embed_text(context) - # Run transformer blocks + # Pre-compute attention mask from seq_lens (constant across all blocks) + attn_mask = None + w_dtype = self.patch_embedding_proj.weight.dtype + if any(sl < seq_len for sl in seq_lens_list): + attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype) + for i, sl in enumerate(seq_lens_list): + attn_mask[i, :, :, sl:] = -1e9 + + kwargs = dict( e=e0, seq_lens=seq_lens_list, @@ -306,8 +356,11 @@ class WanModel(nn.Module): freqs=self.freqs, context=context_batch, context_lens=None, + rope_cos_sin=rope_cos_sin, + attn_mask=attn_mask, ) + # Run transformer blocks for i, block in enumerate(self.blocks): kv = cross_kv_caches[i] if cross_kv_caches is not None else None x = block(x, cross_kv_cache=kv, **kwargs) diff --git a/mlx_video/models/wan/rope.py b/mlx_video/models/wan/rope.py index 0983031..d992607 100644 --- a/mlx_video/models/wan/rope.py +++ b/mlx_video/models/wan/rope.py @@ -28,6 +28,7 @@ def rope_apply( x: mx.array, grid_sizes: list, freqs: mx.array, + precomputed_cos_sin: tuple | None = None, ) -> mx.array: """Apply 3-way factorized RoPE to Q or K tensor. 
@@ -35,10 +36,48 @@ def rope_apply( x: Shape [B, L, num_heads, head_dim] grid_sizes: List of (F, H, W) tuples per batch element freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts + precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin() """ b, s, n, d = x.shape half_d = d // 2 + if precomputed_cos_sin is not None: + cos_f, sin_f = precomputed_cos_sin + # Check if all batch elements have the same grid (common for CFG B=2) + f0, h0, w0 = grid_sizes[0] + seq_len = f0 * h0 * w0 + all_same_grid = all( + grid_sizes[i] == grid_sizes[0] for i in range(1, b) + ) if b > 1 else True + + if all_same_grid: + # Vectorized path: apply RoPE to all batch elements at once + x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2) + x_real = x_seq[..., 0] + x_imag = x_seq[..., 1] + out_real = x_real * cos_f - x_imag * sin_f + out_imag = x_real * sin_f + x_imag * cos_f + x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d) + if seq_len < s: + x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1) + return x_rotated + else: + # Per-element path for mixed grid sizes + outputs = [] + for i in range(b): + f, h, w = grid_sizes[i] + sl = f * h * w + x_i = x[i, :sl].reshape(sl, n, half_d, 2) + x_real = x_i[..., 0] + x_imag = x_i[..., 1] + out_real = x_real * cos_f - x_imag * sin_f + out_imag = x_real * sin_f + x_imag * cos_f + x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d) + if sl < s: + x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0) + outputs.append(x_rotated) + return mx.stack(outputs) + # Cast freqs to input dtype to prevent float32 promotion cascade if freqs.dtype != x.dtype: freqs = freqs.astype(x.dtype) @@ -98,3 +137,42 @@ def rope_apply( outputs.append(x_rotated) return mx.stack(outputs) + + +def rope_precompute_cos_sin( + grid_sizes: list, freqs: mx.array, dtype: type = mx.float32 +) -> tuple: + """Precompute cos/sin frequency tensors for constant grid sizes. 
+ + Call once before the diffusion loop. Pass result as precomputed_cos_sin + to rope_apply to skip per-step broadcast/concat. + + Args: + grid_sizes: List of (F, H, W) tuples (must be same for all batch elements) + freqs: Precomputed frequencies [1024, d//2, 2] + dtype: Target dtype for the output tensors + + Returns: + (cos_f, sin_f) each [seq_len, 1, half_d] + """ + if freqs.dtype != dtype: + freqs = freqs.astype(dtype) + + f, h, w = grid_sizes[0] + seq_len = f * h * w + half_d = freqs.shape[1] + + d_t = half_d - 2 * (half_d // 3) + d_h = half_d // 3 + d_w = half_d // 3 + + freqs_t = freqs[:, :d_t] + freqs_h = freqs[:, d_t : d_t + d_h] + freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] + + ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)) + fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)) + fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)) + + freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2) + return freqs_i[..., 0], freqs_i[..., 1] diff --git a/mlx_video/models/wan/scheduler.py b/mlx_video/models/wan/scheduler.py index 946707c..1ea6b98 100644 --- a/mlx_video/models/wan/scheduler.py +++ b/mlx_video/models/wan/scheduler.py @@ -34,6 +34,8 @@ class FlowMatchEulerScheduler: sigmas = _compute_sigmas(num_steps, shift) self.sigmas = mx.array(sigmas) self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps) + # Store as Python floats to avoid .item() sync in step() + self._sigmas_float = sigmas.tolist() self._step_index = 0 def step( @@ -43,9 +45,7 @@ class FlowMatchEulerScheduler: sample: mx.array, ) -> mx.array: """Euler step: x_next = x + (sigma_next - sigma_cur) * v.""" - dt = float(self.sigmas[self._step_index + 1].item()) - float( - self.sigmas[self._step_index].item() - ) + dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index] x_next = sample + dt * model_output self._step_index += 1 return x_next diff --git 
a/mlx_video/models/wan/transformer.py b/mlx_video/models/wan/transformer.py index 7611638..c85c90e 100644 --- a/mlx_video/models/wan/transformer.py +++ b/mlx_video/models/wan/transformer.py @@ -35,8 +35,8 @@ class WanAttentionBlock(nn.Module): self.norm2 = WanLayerNorm(dim, eps) self.ffn = WanFFN(dim, ffn_dim) - # Learned modulation: 6 vectors for scale/shift/gate - self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5) + # Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision) + self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32) def __call__( self, @@ -48,10 +48,11 @@ class WanAttentionBlock(nn.Module): context: mx.array, context_lens: list | None = None, cross_kv_cache: tuple | None = None, + rope_cos_sin: tuple | None = None, + attn_mask: mx.array | None = None, ) -> mx.array: - # Modulation in float32 (matching official torch.amp.autocast float32) - e_f32 = e.astype(mx.float32) - mod = self.modulation.astype(mx.float32) + e_f32 + # Modulation in float32 (e is already float32 from model forward) + mod = self.modulation + e e0 = mod[:, :, 0, :] # shift for self-attn e1 = mod[:, :, 1, :] # scale for self-attn e2 = mod[:, :, 2, :] # gate for self-attn @@ -59,19 +60,20 @@ class WanAttentionBlock(nn.Module): e4 = mod[:, :, 4, :] # scale for ffn e5 = mod[:, :, 5, :] # gate for ffn - # Self-attention with modulation (norm output in float32) - x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0 - y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs) - x = x.astype(mx.float32) + y.astype(mx.float32) * e2 + # Self-attention with modulation + # Type promotion handles bf16→f32 automatically when multiplied with f32 modulation + x_mod = self.norm1(x) * (1 + e1) + e0 + y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask) + x = x + y * e2 # Cross-attention (no modulation, just norm) x_cross = self.norm3(x) if self.norm3 is not None else x x = x + 
self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache) - # FFN with modulation (norm output in float32) - x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3 + # FFN with modulation + x_mod = self.norm2(x) * (1 + e4) + e3 y = self.ffn(x_mod) - x = x + y.astype(mx.float32) * e5 + x = x + y * e5 return x diff --git a/mlx_video/models/wan/vae.py b/mlx_video/models/wan/vae.py index aeac5a1..fe8ccaf 100644 --- a/mlx_video/models/wan/vae.py +++ b/mlx_video/models/wan/vae.py @@ -43,7 +43,9 @@ class CausalConv3d(nn.Module): self.kernel_size = kernel_size self.stride = stride - self._causal_pad_t = 2 * padding[0] + # Causal padding: match reference formula dilation*(k-1) + (1-stride) + # With dilation=1: k-stride (pads left only, no future context) + self._causal_pad_t = kernel_size[0] - stride[0] self._pad_h = padding[1] self._pad_w = padding[2] @@ -51,12 +53,17 @@ class CausalConv3d(nn.Module): self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)) self.bias = mx.zeros((out_channels,)) - def __call__(self, x: mx.array) -> mx.array: + def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array: """x: [B, C, T, H, W] (channel-first)""" b, c, t, h, w = x.shape - if self._causal_pad_t > 0: - pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype) + causal_pad = self._causal_pad_t + if cache_x is not None and causal_pad > 0: + x = mx.concatenate([cache_x, x], axis=2) + causal_pad = max(0, causal_pad - cache_x.shape[2]) + + if causal_pad > 0: + pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype) x = mx.concatenate([pad_t, x], axis=2) if self._pad_h > 0 or self._pad_w > 0: @@ -136,12 +143,35 @@ class ResidualBlock(nn.Module): ] self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None - def __call__(self, x: mx.array) -> mx.array: + def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array: h = x if self.shortcut is None else self.shortcut(x) - x 
= nn.silu(self.residual[0](x)) - x = self.residual[2](x) - x = nn.silu(self.residual[3](x)) - x = self.residual[6](x) + + if feat_cache is not None: + # First conv: norm -> silu -> [cache] -> conv + x = nn.silu(self.residual[0](x)) + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:] + if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: + cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2) + x = self.residual[2](x, cache_x=feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + # Second conv: norm -> silu -> [cache] -> conv + x = nn.silu(self.residual[3](x)) + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:] + if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: + cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2) + x = self.residual[6](x, cache_x=feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = nn.silu(self.residual[0](x)) + x = self.residual[2](x) + x = nn.silu(self.residual[3](x)) + x = self.residual[6](x) + return x + h @@ -180,23 +210,31 @@ class AttentionBlock(nn.Module): class Resample(nn.Module): - """Upsample block matching original Wan VAE structure. + """Resample block matching original Wan VAE structure. - Uses `resample` list with [None, Conv2d] to match original - nn.Sequential(Upsample, Conv2d) where index 1 has the conv params. + Supports both upsampling (decoder) and downsampling (encoder). + Uses list-based param storage to match original nn.Sequential key hierarchy. 
""" def __init__(self, dim: int, mode: str): super().__init__() - assert mode in ("upsample2d", "upsample3d") + assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d") self.mode = mode self.dim = dim - # resample.0 = Upsample (no params), resample.1 = Conv2d - self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)] - if mode == "upsample3d": - self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - def __call__(self, x: mx.array) -> mx.array: + if mode.startswith("upsample"): + # resample.0 = Upsample (no params), resample.1 = Conv2d + self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)] + if mode == "upsample3d": + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + else: + # resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2) + self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)] + if mode == "downsample3d": + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array: """x: [B, C, T, H, W]""" b, c, t, h, w = x.shape @@ -204,17 +242,43 @@ class Resample(nn.Module): # Temporal upsample via learned conv x_t = self.time_conv(x) # [B, 2C, T, H, W] x_t = x_t.reshape(b, 2, c, t, h, w) - # Interleave along time: [B, C, 2T, H, W] x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w) t = t * 2 - # Per-frame spatial upsample: nearest 2x + Conv2d - x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C] - x = mx.repeat(x, 2, axis=1) - x = mx.repeat(x, 2, axis=2) - x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2] - c_out = x.shape[-1] - return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3) + if self.mode.startswith("upsample"): + # Per-frame spatial upsample: nearest 2x + Conv2d + x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C] + x = mx.repeat(x, 2, axis=1) + x = mx.repeat(x, 2, axis=2) + x = 
self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2] + c_out = x.shape[-1] + return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3) + else: + # Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2) + x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C] + x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1) + x = self.resample[1](x) # Conv2d stride=2 + c_out = x.shape[-1] + h_out, w_out = x.shape[1], x.shape[2] + x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + # First chunk: save x, skip time_conv + feat_cache[idx] = x + feat_idx[0] += 1 + else: + # Subsequent chunks: use cached frame as temporal context + cache_x = x[:, :, -1:] + x = self.time_conv( + x, cache_x=feat_cache[idx][:, :, -1:]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.time_conv(x) + return x class Decoder3d(nn.Module): @@ -284,10 +348,108 @@ class Decoder3d(nn.Module): return x -class WanVAE(nn.Module): - """Wan2.1 VAE wrapper with per-channel normalization.""" +class Encoder3d(nn.Module): + """3D VAE Encoder matching Wan2.1 architecture. - def __init__(self, z_dim: int = 16): + Mirror of Decoder3d with downsampling instead of upsampling. + Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy. 
+    """
+
+    def __init__(
+        self,
+        dim: int = 96,
+        z_dim: int = 16,
+        dim_mult: list = None,
+        num_res_blocks: int = 2,
+        temporal_downsample: list = None,
+    ):
+        super().__init__()
+        if dim_mult is None:
+            dim_mult = [1, 2, 4, 4]
+        if temporal_downsample is None:
+            temporal_downsample = [False, True, True]
+
+        dims = [dim * u for u in [1] + dim_mult]  # per-stage widths; defaults give [96, 96, 192, 384, 384]
+
+        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+        # Flat downsample list matching original nn.Sequential indexing
+        downsamples = []
+        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+            for _ in range(num_res_blocks):
+                downsamples.append(ResidualBlock(in_dim, out_dim))
+                in_dim = out_dim  # only the first block of a stage changes the width
+            if i != len(dim_mult) - 1:  # no Resample after the last stage
+                mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
+                downsamples.append(Resample(out_dim, mode=mode))
+        self.downsamples = downsamples
+
+        # Middle: [ResBlock, AttentionBlock, ResBlock]
+        self.middle = [
+            ResidualBlock(dims[-1], dims[-1]),
+            AttentionBlock(dims[-1]),
+            ResidualBlock(dims[-1], dims[-1]),
+        ]
+
+        # Output head: [RMS_norm, SiLU (no params), CausalConv3d]
+        self.head = [
+            RMS_norm(dims[-1], images=False),
+            None,  # SiLU placeholder: __call__ applies nn.silu itself and then uses head[2]
+            CausalConv3d(dims[-1], z_dim, 3, padding=1),
+        ]
+
+    def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
+        """Encode x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]; pass feat_cache/feat_idx for chunked mode."""
+        if feat_cache is not None:
+            # conv1 with caching
+            idx = feat_idx[0]  # feat_idx is a one-element list used as a shared slot counter
+            cache_x = x[:, :, -CACHE_T:]  # trailing frames of this chunk, stored below for the next call
+            if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
+                cache_x = mx.concatenate(
+                    [feat_cache[idx][:, :, -1:], cache_x], axis=2)
+            x = self.conv1(x, cache_x=feat_cache[idx])  # reads the previous call's cache (None on the first)
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv1(x)
+
+        for layer in self.downsamples:
+            if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
+            else:
+                x = layer(x)
+
+        for layer in self.middle:
+            if feat_cache is not None and isinstance(layer, ResidualBlock):
+                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
+            else:  # AttentionBlock, or any layer when no cache is in use
+                x = layer(x)
+
+        if feat_cache is not None:
+            # Head: norm -> silu -> [cache] -> conv
+            x = nn.silu(self.head[0](x))
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:]
+            if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
+                cache_x = mx.concatenate(
+                    [feat_cache[idx][:, :, -1:], cache_x], axis=2)
+            x = self.head[2](x, cache_x=feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = nn.silu(self.head[0](x))
+            x = self.head[2](x)
+
+        return x
+
+
+class WanVAE(nn.Module):
+    """Wan2.1 VAE wrapper with per-channel normalization.
+
+    Supports both encode (for I2V) and decode (for all models).
+    """
+
+    def __init__(self, z_dim: int = 16, encoder: bool = False):
         super().__init__()
         self.z_dim = z_dim
         self.mean = mx.array(VAE_MEAN)
@@ -297,6 +459,65 @@ class WanVAE(nn.Module):
 
         self.conv2 = CausalConv3d(z_dim, z_dim, 1)
         self.decoder = Decoder3d(dim=96, z_dim=z_dim)
+        if encoder:  # encoder and conv1 exist only on request; encode() needs both
+            self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
+            self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+
+    def encode(self, x: mx.array) -> mx.array:
+        """Encode video to normalized latent using chunked encoding.
+
+        Uses chunked encoding with temporal caching to match reference behavior.
+        First frame encoded alone, then 4-frame chunks with cached context.
+
+        Args:
+            x: Video [B, 3, T, H, W] in [-1, 1]
+
+        Returns:
+            Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
+        """
+        # Count cacheable CausalConv3d slots in encoder
+        num_slots = self._count_encoder_cache_slots()
+        feat_cache = [None] * num_slots
+
+        t = x.shape[2]
+        num_chunks = 1 + (t - 1) // 4  # frames beyond 1 + 4*k are never sliced below, i.e. not encoded
+
+        out = None
+        for i in range(num_chunks):
+            feat_idx = [0]  # restart the slot counter per chunk; feat_cache carries over
+            if i == 0:
+                chunk = x[:, :, :1]
+            else:
+                chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i]
+
+            chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
+
+            if out is None:
+                out = chunk_out
+            else:
+                out = mx.concatenate([out, chunk_out], axis=2)
+
+        mu, _ = mx.split(self.conv1(out), 2, axis=1)  # split channels in two; only the first half (mu) is used
+
+        # Normalize: (mu - mean) * inv_std
+        mean = self.mean.reshape(1, -1, 1, 1, 1)
+        inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
+        return (mu - mean) * inv_std
+
+    def _count_encoder_cache_slots(self) -> int:
+        """Count CausalConv3d that participate in chunked encoding cache."""
+        count = 1  # encoder.conv1
+        for layer in self.encoder.downsamples:
+            if isinstance(layer, ResidualBlock):
+                count += 2  # two convs in residual path
+            elif isinstance(layer, Resample) and layer.mode == "downsample3d":
+                count += 1  # time_conv
+        for layer in self.encoder.middle:
+            if isinstance(layer, ResidualBlock):
+                count += 2  # two per ResidualBlock, same as in the downsample stack
+        count += 1  # encoder.head CausalConv3d
+        return count
+
     def decode(self, z: mx.array) -> mx.array:
         """Decode latent to video.
 
diff --git a/tests/test_wan_i2v.py b/tests/test_wan_i2v.py
new file mode 100644
index 0000000..53077a0
--- /dev/null
+++ b/tests/test_wan_i2v.py
@@ -0,0 +1,293 @@
+"""Tests for Wan2.2 I2V-14B support."""
+
+import mlx.core as mx
+import numpy as np
+import pytest
+
+from wan_test_helpers import _make_tiny_config
+
+
+def _make_tiny_i2v_config():
+    """Create a tiny I2V-14B config for testing."""
+    config = _make_tiny_config()
+    config.model_type = "i2v"
+    config.in_dim = 9  # 4 noise + 4 image + 1 mask (scaled down from 16+16+4=36)
+    config.out_dim = 4
+    config.vae_z_dim = 4
+    config.vae_stride = (4, 8, 8)
+    config.dual_model = True
+    config.boundary = 0.900
+    config.sample_shift = 5.0
+    config.sample_guide_scale = (3.5, 3.5)
+    config.teacache_coefficients = None
+    return config
+
+
+class TestI2VConfig:
+    """Test I2V-14B config preset."""
+
+    def test_wan22_i2v_14b_preset(self):
+        from mlx_video.models.wan.config import WanModelConfig
+
+        config = WanModelConfig.wan22_i2v_14b()
+        assert config.model_type == "i2v"
+        assert config.in_dim == 36
+        assert config.out_dim == 16
+        assert config.dim == 5120
+        assert config.num_layers == 40
+        assert config.dual_model is True
+        assert config.boundary == 0.900
+        assert config.sample_shift == 5.0
+        assert config.sample_guide_scale == (3.5, 3.5)
+        assert config.vae_stride == (4, 8, 8)
+        assert config.vae_z_dim == 16
+        assert config.teacache_coefficients is None
+
+    def test_i2v_vs_t2v_differences(self):
+        from mlx_video.models.wan.config import WanModelConfig
+
+        i2v = WanModelConfig.wan22_i2v_14b()
+        t2v = WanModelConfig.wan22_t2v_14b()
+
+        assert i2v.model_type == "i2v"
+        assert t2v.model_type == "t2v"
+        assert i2v.in_dim == 36 and t2v.in_dim == 16
+        assert i2v.boundary == 0.900 and t2v.boundary == 0.875
+        assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
+
+    def test_i2v_serialization_roundtrip(self):
+        from mlx_video.models.wan.config import WanModelConfig
+
+        config = WanModelConfig.wan22_i2v_14b()
+        d = config.to_dict()
+        restored = WanModelConfig.from_dict(d)
+        assert restored.model_type == "i2v"
+        assert restored.in_dim == 36
+        assert restored.boundary == 0.900
+
+
+class TestModelYParameter:
+    """Test y parameter channel concatenation in WanModel."""
+
+    def test_forward_without_y(self):
+        """Standard T2V forward pass (no y) still works."""
+        from mlx_video.models.wan.model import WanModel
+
+        config = _make_tiny_config()
+        model = WanModel(config)
+
+        C, F, H, W = config.in_dim, 1, 4, 4
+        pt, ph, pw = config.patch_size
+        seq_len = (F // pt) * (H // ph) * (W // pw)
+
+        x_list = [mx.random.normal((C, F, H, W))]
+        t = mx.array([500.0])
+        context = [mx.random.normal((6, config.text_dim))]
+
+        out = model(x_list, t, context, seq_len)
+        mx.eval(out[0])
+        assert out[0].shape == (C, F, H, W)
+
+    def test_forward_with_y(self):
+        """I2V forward pass with y channel concatenation."""
+        from mlx_video.models.wan.model import WanModel
+
+        config = _make_tiny_i2v_config()
+        model = WanModel(config)
+
+        C_noise = 4  # noise channels
+        C_y = 5  # mask (1) + image (4)
+        F, H, W = 1, 4, 4
+        pt, ph, pw = config.patch_size
+        seq_len = (F // pt) * (H // ph) * (W // pw)
+
+        x_list = [mx.random.normal((C_noise, F, H, W))]
+        y_list = [mx.random.normal((C_y, F, H, W))]
+        t = mx.array([500.0])
+        context = [mx.random.normal((6, config.text_dim))]
+
+        out = model(x_list, t, context, seq_len, y=y_list)
+        mx.eval(out[0])
+        # Output should match noise channels (out_dim), not concatenated in_dim
+        assert out[0].shape == (config.out_dim, F, H, W)
+
+    def test_y_none_is_noop(self):
+        """Passing y=None should be identical to not passing y."""
+        from mlx_video.models.wan.model import WanModel
+
+        config = _make_tiny_config()
+        model = WanModel(config)
+
+        C, F, H, W = config.in_dim, 1, 4, 4
+        pt, ph, pw = config.patch_size
+        seq_len = (F // pt) * (H // ph) * (W // pw)
+
+        mx.random.seed(42)
+        x = mx.random.normal((C, F, H, W))
+        t = mx.array([500.0])
+        ctx = [mx.random.normal((6, config.text_dim))]
+
+        out1 = model([x], t, ctx, seq_len)[0]
+        out2 = model([x], t, ctx, seq_len, y=None)[0]
+        mx.eval(out1, out2)
+        assert mx.allclose(out1, out2, atol=1e-5).item()
+
+    def test_batched_cfg_with_y(self):
+        """Batched CFG (B=2) with y should work."""
+        from mlx_video.models.wan.model import WanModel
+
+        config = _make_tiny_i2v_config()
+        model = WanModel(config)
+
+        C_noise, C_y = 4, 5
+        F, H, W = 1, 4, 4
+        pt, ph, pw = config.patch_size
+        seq_len = (F // pt) * (H // ph) * (W // pw)
+
+        latents = mx.random.normal((C_noise, F, H, W))
+        y = mx.random.normal((C_y, F, H, W))
+        t = mx.array([500.0, 500.0])
+        ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
+
+        out = model([latents, latents], t, ctx, seq_len, y=[y, y])
+        mx.eval(out[0], out[1])
+        assert len(out) == 2
+        assert out[0].shape == (config.out_dim, F, H, W)
+        assert out[1].shape == (config.out_dim, F, H, W)
+
+
+class TestVAEEncoder:
+    """Test Wan2.1 VAE encoder."""
+
+    def test_encoder3d_instantiation(self):
+        from mlx_video.models.wan.vae import Encoder3d
+
+        enc = Encoder3d(dim=32, z_dim=8)  # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
+        assert enc.conv1 is not None
+        assert len(enc.downsamples) > 0
+        assert len(enc.middle) == 3
+
+    def test_encoder3d_output_shape(self):
+        """Encoder should downsample spatially by 8x and temporally by 4x."""
+        from mlx_video.models.wan.vae import Encoder3d
+
+        enc = Encoder3d(dim=32, z_dim=8)
+        # Random input: [B=1, 3, T=5, H=32, W=32]
+        x = mx.random.normal((1, 3, 5, 32, 32))
+        out = enc(x)
+        mx.eval(out)
+        # With default dim_mult=[1,2,4,4] and temporal_downsample=[False,True,True]:
+        # Spatial: 32 -> 16 -> 8 -> 4 (3 spatial downsamples)
+        # Temporal: stages 1 and 2 use downsample3d; the temporal size is not asserted below
+        assert out.shape[0] == 1
+        assert out.shape[1] == 8  # z_dim
+        assert out.shape[3] == 32 // 8  # spatial /8
+        assert out.shape[4] == 32 // 8
+
+    def test_wan_vae_encode(self):
+        """WanVAE with encoder=True should produce normalized latents."""
+        from mlx_video.models.wan.vae import WanVAE
+
+        vae = WanVAE(z_dim=16, encoder=True)
+        # Input: [B=1, 3, T=5, H=32, W=32]
+        x = mx.random.normal((1, 3, 5, 32, 32))
+        z = vae.encode(x)
+        mx.eval(z)
+        assert z.shape[0] == 1
+        assert z.shape[1] == 16  # z_dim
+
+    def test_wan_vae_encoder_flag(self):
+        """WanVAE without encoder flag should not have encoder attribute."""
+        from mlx_video.models.wan.vae import WanVAE
+
+        vae_no_enc = WanVAE(z_dim=4, encoder=False)
+        assert not hasattr(vae_no_enc, 'encoder')
+
+        vae_enc = WanVAE(z_dim=4, encoder=True)
+        assert hasattr(vae_enc, 'encoder')
+
+
+class TestResampleDownsample:
+    """Test downsample modes in Resample."""
+
+    def test_downsample2d(self):
+        from mlx_video.models.wan.vae import Resample
+
+        r = Resample(dim=16, mode="downsample2d")
+        x = mx.random.normal((1, 16, 2, 8, 8))
+        out = r(x)
+        mx.eval(out)
+        # Spatial /2, temporal unchanged, channels same
+        assert out.shape == (1, 16, 2, 4, 4)
+
+    def test_downsample3d(self):
+        from mlx_video.models.wan.vae import Resample
+
+        r = Resample(dim=16, mode="downsample3d")
+        x = mx.random.normal((1, 16, 4, 8, 8))
+        out = r(x)
+        mx.eval(out)
+        # Spatial /2, temporal /2, channels same
+        assert out.shape == (1, 16, 2, 4, 4)
+
+    def test_upsample2d_still_works(self):
+        from mlx_video.models.wan.vae import Resample
+
+        r = Resample(dim=16, mode="upsample2d")
+        x = mx.random.normal((1, 16, 2, 4, 4))
+        out = r(x)
+        mx.eval(out)
+        assert out.shape == (1, 8, 2, 8, 8)
+
+    def test_upsample3d_still_works(self):
+        from mlx_video.models.wan.vae import Resample
+
+        r = Resample(dim=16, mode="upsample3d")
+        x = mx.random.normal((1, 16, 2, 4, 4))
+        out = r(x)
+        mx.eval(out)
+        assert out.shape == (1, 8, 4, 8, 8)
+
+
+class TestI2VMaskConstruction:
+    """Test mask construction for I2V-14B."""
+
+    def test_mask_shape(self):
+        """I2V-14B mask should have 4 channels with correct temporal structure."""
+        num_frames = 81
+        h_latent, w_latent = 10, 18  # example latent dims
+        t_latent = (num_frames - 1) // 4 + 1  # = 21
+
+        # Build mask following reference logic
+        msk = mx.ones((1, num_frames, h_latent, w_latent))
+        msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
+        msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
+        msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
+        msk = msk.transpose(0, 2, 1, 3, 4)[0]  # [4, T_lat, H_lat, W_lat]
+
+        assert msk.shape == (4, t_latent, h_latent, w_latent)
+
+    def test_mask_values(self):
+        """First temporal position should be 1, rest 0."""
+        num_frames = 9
+        h_latent, w_latent = 4, 4
+        t_latent = (num_frames - 1) // 4 + 1  # = 3
+
+        msk = mx.ones((1, num_frames, h_latent, w_latent))
+        msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
+        msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
+        msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
+        msk = msk.transpose(0, 2, 1, 3, 4)[0]
+
+        mx.eval(msk)
+        # First temporal position: all 4 channels should be 1
+        assert mx.all(msk[:, 0] == 1.0).item()
+        # Rest: all should be 0
+        assert mx.all(msk[:, 1:] == 0.0).item()
+
+    def test_y_tensor_shape(self):
+        """y = concat([mask_4ch, encoded_video_16ch]) should be 20 channels."""
+        mask = mx.zeros((4, 5, 10, 18))
+        encoded = mx.zeros((16, 5, 10, 18))
+        y = mx.concatenate([mask, encoded], axis=0)
+        assert y.shape == (20, 5, 10, 18)