Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.wan import WanModel, WanModelConfig
|
||||
from mlx_video.models.wan2 import WanModel, WanModelConfig
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
from mlx_video.models.ltx.config import BaseModelConfig
|
||||
from mlx_video.models.ltx_2.config import BaseModelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -247,7 +247,7 @@ def _load_lora_configs(
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.models.wan2.generate import Colors
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
@@ -282,7 +282,7 @@ def load_and_apply_loras(
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.models.wan2.generate import Colors
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
|
||||
if not lora_configs:
|
||||
@@ -411,7 +411,7 @@ def convert_wan_checkpoint(
|
||||
print(" Warning: No transformer weights found!")
|
||||
|
||||
# Save config — detect model size from source config.json or transformer weights
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
def _detect_config():
|
||||
"""Detect config from source config.json or transformer weight shapes."""
|
||||
@@ -522,7 +522,7 @@ def convert_wan_checkpoint(
|
||||
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(
|
||||
@@ -594,7 +594,7 @@ def _quantize_saved_model(
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
if source_dir is None:
|
||||
source_dir = output_dir
|
||||
@@ -704,7 +704,7 @@ def quantize_mlx_model(
|
||||
).exists()
|
||||
|
||||
# Build model config
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config_dict = {
|
||||
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||
|
||||
@@ -1,394 +0,0 @@
|
||||
# Wan2.2 I2V-14B Diagnostic Report
|
||||
|
||||
This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Architecture Summary](#architecture-summary)
|
||||
- [Diagnostic Methodology](#diagnostic-methodology)
|
||||
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
|
||||
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
|
||||
- [Bug 3: RoPE Frequency Computation (original)](#bug-3-rope-frequency-computation-original)
|
||||
- [Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)](#bug-6-rope-frequency-distribution-bug-3-fix-was-wrong)
|
||||
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
|
||||
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
|
||||
- [Verified Correct Components](#verified-correct-components)
|
||||
- [Performance Optimizations](#performance-optimizations)
|
||||
- [Resolved: CFG Effectiveness](#resolved-cfg-effectiveness-was-open-investigation)
|
||||
- [Reference Implementation](#reference-implementation)
|
||||
- [Useful Diagnostic Commands](#useful-diagnostic-commands)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output — first frame showed the image, subsequent frames degraded to noise, checkerboard artifacts, or flat grey.
|
||||
|
||||
Through a systematic component-by-component comparison against the reference PyTorch implementation, **five bugs** were found and fixed. The approach was to verify each component in isolation numerically, then narrow down failures to the subsystem level.
|
||||
|
||||
### Timeline of Symptoms
|
||||
|
||||
| Stage | Symptom | Root Cause |
|
||||
|-------|---------|------------|
|
||||
| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) |
|
||||
| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) |
|
||||
| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) |
|
||||
| After VAE fix | Image in frames 0-3, grey frames 4+ | CFG effectiveness issue (open investigation) |
|
||||
|
||||
---
|
||||
|
||||
## Architecture Summary
|
||||
|
||||
```
|
||||
I2V-14B Pipeline:
|
||||
Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat]
|
||||
↓
|
||||
Mask Construction → [4, T_lat, H_lat, W_lat]
|
||||
↓
|
||||
y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat]
|
||||
↓
|
||||
Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat]
|
||||
↓
|
||||
Dual DiT (40 layers, 5120 dim) × 40 denoising steps
|
||||
↓
|
||||
Denoised Latent [16, T_lat, H_lat, W_lat]
|
||||
↓
|
||||
VAE Decoder → Video [3, F, H, W]
|
||||
```
|
||||
|
||||
**Key parameters:**
|
||||
- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16`
|
||||
- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900)
|
||||
- 40 steps, shift=5.0, guide_scale=(3.5, 3.5)
|
||||
- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8)
|
||||
|
||||
---
|
||||
|
||||
## Diagnostic Methodology
|
||||
|
||||
### 1. Component-Level Numerical Verification
|
||||
|
||||
Each component was tested in isolation against the reference PyTorch implementation:
|
||||
|
||||
1. **Load identical inputs** (same random seed, same image, same prompt)
|
||||
2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy`
|
||||
3. **Run through MLX** with the same inputs
|
||||
4. **Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics
|
||||
|
||||
Components tested this way:
|
||||
- RoPE frequency parameters and rotation output
|
||||
- Time embedding (sinusoidal → MLP → projection)
|
||||
- Patchify (reshape+Linear vs Conv3d)
|
||||
- Unpatchify (transpose-based vs einsum)
|
||||
- Scheduler (UniPC) timesteps and step formulas
|
||||
- VAE encoder output (frame-by-frame comparison)
|
||||
- Text embeddings (per-model MLP output)
|
||||
- Cross-attention K/V cache shapes
|
||||
- Mask construction values
|
||||
|
||||
### 2. Artifact Analysis
|
||||
|
||||
When visual artifacts appeared, quantitative metrics were used to characterize them:
|
||||
|
||||
- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard.
|
||||
- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts.
|
||||
- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation.
|
||||
- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content.
|
||||
|
||||
### 3. Isolation Testing
|
||||
|
||||
- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source.
|
||||
- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness.
|
||||
- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering.
|
||||
- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence.
|
||||
- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs.
|
||||
|
||||
### 4. Weight Comparison
|
||||
|
||||
Direct comparison of converted MLX weights against original PyTorch weights:
|
||||
```python
|
||||
# Load both weight sets
|
||||
pt_weights = torch.load("model.safetensors")
|
||||
mlx_weights = mx.load("model.safetensors")
|
||||
# Compare each key
|
||||
for key in pt_weights:
|
||||
diff = np.abs(np.array(pt_weights[key]) - np.array(mlx_weights[key])).max()
|
||||
```
|
||||
Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys.
|
||||
|
||||
---
|
||||
|
||||
## Bug 1: Text Embedding Cross-Contamination
|
||||
|
||||
**Symptom:** Model ignores text prompt, generated frames lack semantic content.
|
||||
|
||||
**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights — 42% relative mean difference in output.
|
||||
|
||||
**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices.
|
||||
|
||||
**Fix:** Compute separate text embeddings per model:
|
||||
```python
|
||||
# Before (broken):
|
||||
context_emb = low_noise_model.embed_text([context, context_null])
|
||||
cross_kv = low_noise_model.prepare_cross_kv(context_emb) # used for BOTH models
|
||||
|
||||
# After (correct):
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high)
|
||||
```
|
||||
|
||||
**File:** `mlx_video/generate_wan.py` (lines 333–349)
|
||||
**Commit:** `a85b1c21`
|
||||
|
||||
---
|
||||
|
||||
## Bug 2: VAE Encoder Weights Excluded from Conversion
|
||||
|
||||
**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion).
|
||||
|
||||
**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, missing encoder weights were silently ignored, resulting in random initialization.
|
||||
|
||||
**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning).
|
||||
|
||||
**Fix:**
|
||||
```python
|
||||
# Before:
|
||||
include_encoder = config.model_type == "ti2v"
|
||||
# After:
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
```
|
||||
|
||||
**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions.
|
||||
|
||||
**File:** `mlx_video/convert_wan.py` (line 424)
|
||||
|
||||
---
|
||||
|
||||
## Bug 3: RoPE Frequency Computation (original)
|
||||
|
||||
**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.
|
||||
|
||||
**Root Cause (original):** Our original code called `rope_params` three times but applied them incorrectly (per-axis in the model init, then rope_apply did NOT split). This was initially "fixed" by switching to a single `rope_params(1024, head_dim=128)` call, which reduced checkerboard but introduced Bug 6 (see below).
|
||||
|
||||
**File:** `mlx_video/models/wan/model.py`
|
||||
**Commit:** `3da4a637`
|
||||
|
||||
---
|
||||
|
||||
## Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)
|
||||
|
||||
**Symptom:** I2V generates input image in frames 0–3, colorful checkerboard on frame 4, then grey frames. CFG cond/uncond predictions nearly identical. Model cannot produce coherent motion.
|
||||
|
||||
**Root Cause:** The Bug 3 "fix" replaced three separate `rope_params` calls with a single `rope_params(1024, 128)`. But the reference (`wan/modules/model.py` lines 400–405) actually uses **three separate calls with different dimension normalizations**, concatenated:
|
||||
|
||||
```python
|
||||
# Reference (CORRECT):
|
||||
d = dim // num_heads # 128
|
||||
self.freqs = torch.cat([
|
||||
rope_params(1024, d - 4 * (d // 6)), # rope_params(1024, 44)
|
||||
rope_params(1024, 2 * (d // 6)), # rope_params(1024, 42)
|
||||
rope_params(1024, 2 * (d // 6)) # rope_params(1024, 42)
|
||||
], dim=1)
|
||||
```
|
||||
|
||||
Each axis gets its own full frequency range [θ^0, θ^(-~0.95)]. The single-call approach gave:
|
||||
- Temporal: low frequencies only [1.0 → 0.049]
|
||||
- Height: medium frequencies only [0.042 → 0.002] (should start at 1.0!)
|
||||
- Width: high frequencies only [0.002 → 0.0001] (should start at 1.0!)
|
||||
|
||||
The height/width position encoding was essentially destroyed — nearby spatial positions were indistinguishable (max diff 0.958 for height, 0.998 for width vs reference).
|
||||
|
||||
**How Found:** Direct line-by-line comparison of `WanModel.__init__` freq construction between reference `wan/modules/model.py` and our `models/wan/model.py`. Numerical verification confirmed the three-call approach gives each axis a full [0, ~1) exponent range, while the single-call monotonically assigns low→high across axes.
|
||||
|
||||
**Fix:**
|
||||
```python
|
||||
d = dim // config.num_heads
|
||||
self.freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
```
|
||||
|
||||
**Verification:** Max diff vs reference cos/sin: 0.00000000 (exact float32 match).
|
||||
|
||||
**Impact:** Affects ALL Wan models (T2V, I2V, TI2V). Resolves the "Open Investigation: CFG Effectiveness" issue — the model could not produce meaningful cond/uncond differences because it couldn't encode spatial positions.
|
||||
|
||||
**File:** `mlx_video/models/wan/model.py` (line 155)
|
||||
|
||||
---
|
||||
|
||||
## Bug 4: VAE Encoder Temporal Downsample Order
|
||||
|
||||
**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 1–4 showed decreasing std (0.37→1.19) while reference showed stable std (0.95→1.34).
|
||||
|
||||
**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters:
|
||||
|
||||
```
|
||||
Reference: [False, True, True] → stage 0: 2d, stage 1: 3d, stage 2: 3d
|
||||
Ours: [True, True, False] → stage 0: 3d, stage 1: 3d, stage 2: 2d ← WRONG
|
||||
```
|
||||
|
||||
This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation.
|
||||
|
||||
**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output:
|
||||
- Frame 0 matched exactly (diff=0.0000) — spatial-only processing was correct
|
||||
- Frames 1–4 had massive differences — proved temporal processing was broken
|
||||
|
||||
Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]`, while our `Encoder3d` class used the wrong default `[True, True, False]`.
|
||||
|
||||
**Fix:**
|
||||
```python
|
||||
# In Encoder3d.__init__, change default:
|
||||
temporal_downsample = [False, True, True] # was [True, True, False]
|
||||
```
|
||||
|
||||
**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 60–80 to 0.1–7.7.
|
||||
|
||||
**File:** `mlx_video/models/wan/vae.py` (line 370)
|
||||
**Commit:** `3da4a637`
|
||||
|
||||
---
|
||||
|
||||
## Bug 5: Non-Chunked VAE Encoding
|
||||
|
||||
**Symptom:** First 4–5 frames grey, then blurred version of image appears.
|
||||
|
||||
**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`):
|
||||
1. Encode first frame alone (1 frame)
|
||||
2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks
|
||||
3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input
|
||||
|
||||
Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because:
|
||||
- Chunked: real features from previous chunks serve as causal context
|
||||
- Non-chunked: zeros serve as causal context for the start
|
||||
|
||||
**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values.
|
||||
|
||||
**Fix:** Implemented full chunked encoding with temporal caching:
|
||||
- Added `cache_x` parameter to `CausalConv3d.__call__`
|
||||
- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d`
|
||||
- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks)
|
||||
- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head)
|
||||
|
||||
**File:** `mlx_video/models/wan/vae.py` (multiple methods)
|
||||
**Commit:** `b6a94c4c`
|
||||
|
||||
---
|
||||
|
||||
## Verified Correct Components
|
||||
|
||||
These components were numerically verified against the reference and are **not** sources of bugs:
|
||||
|
||||
| Component | Method | Max Diff | Notes |
|
||||
|-----------|--------|----------|-------|
|
||||
| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only |
|
||||
| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent |
|
||||
| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative |
|
||||
| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative |
|
||||
| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation |
|
||||
| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match |
|
||||
| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 |
|
||||
| CFG formula | Code audit | — | `uncond + gs * (cond - uncond)` correct order |
|
||||
| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output |
|
||||
| Cross-attention K/V | Shape and value audit | — | Batch dimension preserved correctly |
|
||||
|
||||
---
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
Applied alongside bug fixes to improve inference speed:
|
||||
|
||||
### Pre-Computation (Before Diffusion Loop)
|
||||
- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once
|
||||
- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat
|
||||
- **Attention mask precomputation**: Build padding mask once, pass via kwargs
|
||||
- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing
|
||||
- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync
|
||||
|
||||
### Per-Step Optimizations
|
||||
- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection
|
||||
- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element
|
||||
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
|
||||
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
|
||||
|
||||
---
|
||||
|
||||
## Resolved: CFG Effectiveness (was Open Investigation)
|
||||
|
||||
**Symptom:** Generated video shows the input image in frames 0–3 (latent frame 0), then grey/flat frames for the rest. Cond and uncond predictions were nearly identical.
|
||||
|
||||
**Resolution:** This was caused by Bug 6 (incorrect RoPE frequency distribution). The single `rope_params(1024, 128)` call gave height frequencies starting at 0.042 and width at 0.002 (instead of 1.0 for both), making the model unable to encode spatial positions. This caused the transformer to produce nearly identical outputs regardless of text conditioning, explaining the tiny cond/uncond differences.
|
||||
|
||||
---
|
||||
|
||||
## Reference Implementation
|
||||
|
||||
The reference PyTorch implementation is at `/Users/daniel/Projects/Wan2.2/`:
|
||||
|
||||
| File | Contents |
|
||||
|------|----------|
|
||||
| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) |
|
||||
| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) |
|
||||
| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding |
|
||||
| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler |
|
||||
| `wan/configs/wan_i2v_A14B.py` | Model configuration |
|
||||
|
||||
Key structural differences between reference and our implementation:
|
||||
- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2
|
||||
- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype
|
||||
- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear`
|
||||
- Reference casts timesteps to `int64`; we keep as float (diff < 1.0)
|
||||
|
||||
---
|
||||
|
||||
## Useful Diagnostic Commands
|
||||
|
||||
### Run I2V-14B generation
|
||||
```bash
|
||||
python -m mlx_video.generate_wan \
|
||||
--prompt "A woman smiles at camera" \
|
||||
--image start.png \
|
||||
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \
|
||||
--num-frames 17 --steps 40 \
|
||||
--height 384 --width 640 \
|
||||
--output output_i2v.mp4
|
||||
```
|
||||
|
||||
### Check VAE encoder output
|
||||
```python
|
||||
import mlx.core as mx, numpy as np
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
# Load VAE and encode an image
|
||||
latents = vae.encode(video_tensor) # [1, 16, T_lat, H_lat, W_lat]
|
||||
for t in range(latents.shape[2]):
|
||||
frame = np.array(latents[0, :, t])
|
||||
print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}")
|
||||
```
|
||||
|
||||
### Analyze video frame quality
|
||||
```python
|
||||
import cv2, numpy as np
|
||||
cap = cv2.VideoCapture("output.mp4")
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret: break
|
||||
# Checkerboard metric: high values indicate patch-boundary artifacts
|
||||
checker = np.abs(frame[::2, ::2].astype(float) - frame[1::2, 1::2].astype(float)).mean()
|
||||
print(f"std={frame.std():.1f} checker={checker:.1f}")
|
||||
```
|
||||
|
||||
### Compare weights between PyTorch and MLX
|
||||
```python
|
||||
import torch, mlx.core as mx, numpy as np
|
||||
pt = torch.load("model.pt", map_location="cpu")
|
||||
mlx_w = mx.load("model.safetensors")
|
||||
for key in sorted(pt.keys()):
|
||||
if key in mlx_w:
|
||||
diff = np.abs(pt[key].float().numpy() - np.array(mlx_w[key])).max()
|
||||
if diff > 0.01:
|
||||
print(f"LARGE DIFF {key}: {diff:.6f}")
|
||||
```
|
||||
@@ -1,285 +0,0 @@
|
||||
# Wan2.2 MLX Implementation Notes
|
||||
|
||||
> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early reports, the T2V/TI2V models do **not** use Mixture-of-Experts — they are dense DiT models with a dual-model architecture for the 14B variant (separate high-noise and low-noise denoisers with a boundary timestep).
|
||||
|
||||
### Key Parameters
|
||||
|
||||
| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
|
||||
|-------|-----|-------|--------|----------|-----------|------------|--------|
|
||||
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
|
||||
| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
|
||||
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
|
||||
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |
|
||||
|
||||
### Codebase Structure (~3900 lines of Wan2.2 code)
|
||||
|
||||
```
|
||||
mlx_video/
|
||||
├── generate_wan.py # 483L - Generation pipeline (T2V + I2V)
|
||||
├── convert_wan.py # 564L - Weight conversion from HuggingFace
|
||||
└── models/wan/
|
||||
├── config.py # 113L - Model configs (dataclass presets)
|
||||
├── model.py # 320L - DiT model (time embed, patchify, unpatchify)
|
||||
├── transformer.py # 91L - Attention block + FFN
|
||||
├── attention.py # 211L - Self-attention + cross-attention
|
||||
├── rope.py # 100L - 3D Rotary Position Embeddings
|
||||
├── text_encoder.py # 240L - T5 encoder (UMT5-XXL)
|
||||
├── scheduler.py # 428L - Euler, DPM++ 2M, UniPC schedulers
|
||||
├── vae.py # 315L - Wan2.1 VAE decoder (4×8×8)
|
||||
├── vae22.py # 836L - Wan2.2 VAE encoder + decoder (4×16×16)
|
||||
├── loading.py # 154L - Model loading utilities
|
||||
└── i2v_utils.py # 58L - I2V mask/preprocessing
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Critical Bugs & Fixes
|
||||
|
||||
### 1. MLX Underscore Attribute Gotcha
|
||||
|
||||
**Problem**: MLX's `nn.Module` silently ignores underscore-prefixed attributes (`_layer_0`, `_layer_1`, etc.) in `parameters()` and `load_weights()`. The Wan2.2 VAE had layers named `_layer_N`, causing **87 out of 110 weights to be silently dropped** during loading.
|
||||
|
||||
**Fix**: Rename all `_layer_N` attributes to `layer_N`. MLX treats underscore-prefixed attributes as "private" and excludes them from the parameter tree.
|
||||
|
||||
**Lesson**: Never use underscore-prefixed names for `nn.Module` sub-modules in MLX.
|
||||
|
||||
### 2. Patchify Channel Ordering
|
||||
|
||||
**Problem**: The patchify/unpatchify operations transposed channels incorrectly — producing `[C fastest]` layout instead of `[C slowest]`, causing completely garbled video output.
|
||||
|
||||
**Fix**: Changed reshape to produce correct `[B, T', H', W', pt*ph*pw*C]` ordering matching PyTorch's contiguous memory layout.
|
||||
|
||||
**Lesson**: When porting PyTorch reshape/view operations to MLX, pay close attention to memory layout — PyTorch is row-major by default, and reshape semantics differ when dimensions are reordered.
|
||||
|
||||
### 3. VAE AttentionBlock Reshape
|
||||
|
||||
**Problem**: Attention block merged batch (B) with channels (C) instead of batch with temporal (T), producing a green checker pattern in output.
|
||||
|
||||
**Fix**: Correct reshape from `[B*C, T, H, W]` to `[B*T, C, H, W]` for spatial attention.
|
||||
|
||||
### 4. RMS Norm vs L2 Norm
|
||||
|
||||
**Problem**: The Wan2.2 VAE uses a class named `RMS_norm` in PyTorch, but it actually computes **L2 normalization** (divide by L2 norm), not RMS normalization (divide by RMS). Using actual RMS norm caused exponential value explosion.
|
||||
|
||||
**Fix**: Implement as `x / ||x||₂` instead of `x / sqrt(mean(x²))`.
|
||||
|
||||
**Lesson**: Don't trust class names in reference code — read the actual computation.
|
||||
|
||||
### 5. Video Codec Green Output
|
||||
|
||||
**Problem**: OpenCV's `mp4v` codec on macOS produces green-tinted video.
|
||||
|
||||
**Fix**: Switch to `imageio` with `libx264` codec. Fallback chain: imageio → cv2 (avc1) → PNG frames.
|
||||
|
||||
---
|
||||
|
||||
## Precision & Dtype Flow
|
||||
|
||||
### The bfloat16 Autocast Pattern
|
||||
|
||||
The official PyTorch implementation uses `torch.autocast("cuda", dtype=torch.bfloat16)` which automatically casts matmul inputs. In MLX, we replicate this manually:
|
||||
|
||||
| Operation | Official (PyTorch) | MLX Implementation |
|
||||
|---|---|---|
|
||||
| Modulation/gates | float32 (explicit `autocast(enabled=False)`) | `x.astype(mx.float32)` before modulation |
|
||||
| QKV projections | bfloat16 (outer autocast) | Cast input to `self.q.weight.dtype` |
|
||||
| RoPE computation | float64 → float32 | float32 (MLX lacks float64 on GPU) |
|
||||
| Q/K after RoPE | bfloat16 (`q.to(v.dtype)`) | Cast back to weight dtype after RoPE |
|
||||
| FFN matmuls | bfloat16 (outer autocast) | Cast input to `self.fc1.weight.dtype` |
|
||||
| Residual stream | float32 | float32 (no cast) |
|
||||
|
||||
**Result**: ~16% speedup (47s vs 56s for 20 steps at 480p) with no quality regression.
|
||||
|
||||
**Key insight**: Modulation parameters (scale, shift, gate) must stay in float32 — they are small values (~0.01–0.1) that lose significant precision in bfloat16. The official code explicitly disables autocast for these computations.
|
||||
|
||||
### T5 Encoder Precision
|
||||
|
||||
The T5 text encoder must run in float32. Bfloat16 weights cause the attention softmax to produce degenerate distributions, which corrupts text conditioning and manifests as blurry patches in generated video. Since T5 only runs once per generation, the performance cost is negligible.
|
||||
|
||||
### VAE Decoder Precision
|
||||
|
||||
VAE weights must be float32. Bfloat16 VAE decode introduces visible quality loss in the decoded video frames.
|
||||
|
||||
---
|
||||
|
||||
## Scheduler Implementation Details
|
||||
|
||||
### Three Schedulers: Euler, DPM++ 2M, UniPC
|
||||
|
||||
All operate in the flow-matching formulation where `sigma` represents the noise level (1.0 = pure noise, 0.0 = clean).
|
||||
|
||||
**Euler**: Simple first-order ODE solver. Most stable, recommended for debugging.
|
||||
|
||||
**DPM++ 2M**: Second-order multistep solver. Uses previous step's model output for higher-order correction. Requires special handling at boundaries (return `±inf` from `_lambda()` when sigma is 0 or 1).
|
||||
|
||||
**UniPC** (default, matches official): Second-order predictor-corrector. The "C" (corrector) part is critical — it refines each step using the already-computed model output at **zero additional model evaluation cost**.
|
||||
|
||||
### UniPC Corrector: Must Be Enabled
|
||||
|
||||
**Discovery**: Our implementation had `use_corrector=False` by default, but the official Wan2.2 code **always** enables it (there's no flag — the corrector runs whenever `step_index > 0`).
|
||||
|
||||
**Impact**: Without the corrector, UniPC degrades to a simple predictor, losing its second-order accuracy advantage.
|
||||
|
||||
### UniPC Corrector Coefficients
|
||||
|
||||
The corrector coefficients (`rhos_c`) must be computed by solving a linear system, not hardcoded. For order ≥ 2, hardcoding `rhos_c[-1] = 0.5` introduces ~6–13% error in the correction term across 47+ steps. The fix uses `np.linalg.solve()` to compute exact coefficients.
|
||||
|
||||
### Sigma Schedule
|
||||
|
||||
```python
|
||||
# Flow-matching sigma schedule with shift
|
||||
sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
```
|
||||
|
||||
Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
|
||||
|
||||
---
|
||||
|
||||
## Image-to-Video (I2V) Pipelines
|
||||
|
||||
Wan2.2 supports two distinct I2V approaches:
|
||||
|
||||
### TI2V-5B: Per-Token Timestep Masking
|
||||
|
||||
I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:
|
||||
|
||||
```python
|
||||
# mask_tokens: [1, L] — 0 for first-frame patches, 1 for rest
|
||||
t_tokens = mask_tokens * current_timestep # first-frame → t=0
|
||||
```
|
||||
|
||||
The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
|
||||
|
||||
#### Mask Re-application
|
||||
|
||||
After each scheduler step, the first-frame latent is re-injected to prevent drift:
|
||||
|
||||
```python
|
||||
latents = (1.0 - mask) * z_img + mask * latents
|
||||
```
|
||||
|
||||
#### VAE Encoder Temporal Downsample Order
|
||||
|
||||
The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
|
||||
- Stage 0: Spatial-only downsampling
|
||||
- Stages 1–2: Spatial + temporal downsampling
|
||||
|
||||
This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
|
||||
|
||||
### I2V-14B: Channel Concatenation
|
||||
|
||||
The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor:
|
||||
|
||||
1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
|
||||
2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
|
||||
3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])` → `[20, T_lat, H_lat, W_lat]`
|
||||
4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
|
||||
|
||||
Key differences from TI2V-5B:
|
||||
- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
|
||||
- Requires the **VAE encoder** (for encoding the reference image)
|
||||
- Uses **scalar timesteps** (same as T2V) — no per-token masking
|
||||
- **Dual model** pipeline with boundary=0.900
|
||||
- Both conditional and unconditional predictions receive the same `y` tensor
|
||||
|
||||
---
|
||||
|
||||
## Dimension Constraints
|
||||
|
||||
### Patchify Alignment
|
||||
|
||||
Video dimensions must be divisible by `patch_size × vae_stride`:
|
||||
- **TI2V-5B**: patch=(1,2,2), stride=(4,16,16) → alignment = **32** pixels
|
||||
- **T2V-14B**: patch=(1,2,2), stride=(4,8,8) → alignment = **16** pixels
|
||||
|
||||
Example: 720p (1280×720) → 720 % 32 ≠ 0, auto-aligns to **704**.
|
||||
|
||||
### Frame Count
|
||||
|
||||
Frames must satisfy `num_frames = 4n + 1` (e.g., 5, 9, 13, ..., 81) due to temporal VAE stride of 4.
|
||||
|
||||
---
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
### Batched CFG
|
||||
|
||||
Instead of two separate forward passes for conditional and unconditional predictions, batch them into a single B=2 forward pass:
|
||||
|
||||
```python
|
||||
preds = model([latents, latents], t=t_batch, context=context_cfg, ...)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
```
|
||||
|
||||
**Result**: ~40% speedup by amortizing attention overhead.
|
||||
|
||||
### Precomputed Text Embeddings & Cross-Attention KV Cache
|
||||
|
||||
Text embeddings and cross-attention K/V projections are constant across all diffusion steps. Computing them once and passing as caches eliminates redundant computation.
|
||||
|
||||
### Memory Management in Diffusion Loop
|
||||
|
||||
```python
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
||||
mx.eval(latents)
|
||||
```
|
||||
|
||||
MLX's lazy evaluation means `mx.eval()` triggers the full computation graph. Deleting intermediate arrays before eval allows MLX to reuse their memory during execution.
|
||||
|
||||
---
|
||||
|
||||
## Weight Conversion
|
||||
|
||||
### Key Mapping Patterns
|
||||
|
||||
The PyTorch → MLX conversion (`convert_wan.py`) handles several systematic transforms:
|
||||
|
||||
1. **Conv3d weight transposition**: PyTorch `(out, in, D, H, W)` → MLX `(out, D, H, W, in)`
|
||||
2. **Linear weight transposition**: PyTorch `(out, in)` → MLX `(out, in)` (same convention for `nn.Linear`)
|
||||
3. **Nested module paths**: `blocks.0.self_attn.q.weight` → same paths, MLX loads by dotted key
|
||||
|
||||
### Dual-Model Splitting
|
||||
|
||||
The T2V-14B uses dual models (high-noise and low-noise). The conversion script splits a single checkpoint into separate files or handles pre-split checkpoints from HuggingFace.
|
||||
|
||||
---
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
332 tests across 10 files, all running in ~5 seconds:
|
||||
|
||||
| File | Focus |
|
||||
|------|-------|
|
||||
| test_wan_config.py | Config presets, field validation |
|
||||
| test_wan_attention.py | Self/cross attention, RMSNorm, bf16 autocast |
|
||||
| test_wan_transformer.py | FFN, attention block, float32 modulation |
|
||||
| test_wan_model.py | Full DiT forward pass, per-token timesteps |
|
||||
| test_wan_t5.py | T5 encoder layers and full encoding |
|
||||
| test_wan_vae.py | VAE 2.1 decoder, VAE 2.2 encoder + decoder |
|
||||
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
|
||||
| test_wan_convert.py | Weight sanitization and conversion |
|
||||
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
|
||||
| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |
|
||||
|
||||
Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.
|
||||
|
||||
---
|
||||
|
||||
## Known Issues
|
||||
|
||||
### I2V Quality Degradation
|
||||
|
||||
Frames 2–13 gradually degrade, and frame 14 often has a "flash" artifact. All implementation details have been verified against the official PyTorch code with no discrepancies found. Possible causes:
|
||||
- Subtle numerical differences from float32 vs float64 RoPE (MLX lacks float64 on GPU)
|
||||
- MLX-specific attention precision behavior
|
||||
- Better prompts and 720p resolution (the model's native resolution) help reduce artifacts
|
||||
|
||||
### Chinese Negative Prompt
|
||||
|
||||
The official Wan2.2 uses a Chinese negative prompt that prevents oversaturation and comic-style artifacts. Correct tokenization requires `ftfy.fix_text()` to normalize fullwidth characters and double HTML unescaping. Without proper text cleaning, the negative prompt tokens don't match the training distribution, causing blurry patches.
|
||||
@@ -11,15 +11,15 @@ import mlx.core as mx
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan.loading import (
|
||||
from mlx_video.models.wan2.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan2.utils import (
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
load_vae_encoder,
|
||||
load_wan_model,
|
||||
)
|
||||
from mlx_video.models.wan.postprocess import save_video
|
||||
from mlx_video.models.wan2.postprocess import save_video
|
||||
|
||||
|
||||
class Colors:
|
||||
@@ -121,8 +121,8 @@ def generate_video(
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
@@ -729,7 +729,7 @@ def generate_video(
|
||||
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
|
||||
# This gives the first real frame a full temporal receptive field of real data.
|
||||
# Select tiling configuration
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig
|
||||
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
@@ -767,7 +767,7 @@ def generate_video(
|
||||
)
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents
|
||||
|
||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||
z = latents.transpose(1, 2, 3, 0)[None]
|
||||
|
||||
@@ -6,7 +6,7 @@ for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale
|
||||
output frames rather than LTX's 1+(T-1)*scale mapping).
|
||||
|
||||
# TODO: This function can be refactored to consolidate with
|
||||
# mlx_video.models.ltx.video_vae.tiling.decode_with_tiling once the
|
||||
# mlx_video.models.ltx_2.video_vae.tiling.decode_with_tiling once the
|
||||
# causal_temporal generalisation is accepted upstream.
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Callable, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
TilingConfig,
|
||||
|
||||
@@ -21,12 +21,12 @@ def load_wan_model(
|
||||
If provided, creates QuantizedLinear stubs before loading.
|
||||
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||
"""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
|
||||
if quantization:
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
@@ -42,7 +42,7 @@ def load_wan_model(
|
||||
if quantization:
|
||||
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
|
||||
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
|
||||
from mlx_video.convert_wan import _load_lora_configs
|
||||
from mlx_video.models.wan2.convert import _load_lora_configs
|
||||
from mlx_video.lora import apply_loras_to_model
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
@@ -53,7 +53,7 @@ def load_wan_model(
|
||||
return model
|
||||
else:
|
||||
# Weight merging: fold LoRA into bf16 weights before loading
|
||||
from mlx_video.convert_wan import load_and_apply_loras
|
||||
from mlx_video.models.wan2.convert import load_and_apply_loras
|
||||
|
||||
weights = load_and_apply_loras(dict(weights), loras)
|
||||
|
||||
@@ -69,7 +69,7 @@ def load_t5_encoder(model_path: Path, config):
|
||||
only runs once per generation, so performance impact is negligible.
|
||||
This matches the official which computes softmax in float32 explicitly.
|
||||
"""
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
from mlx_video.models.wan2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=config.t5_vocab_size,
|
||||
@@ -97,11 +97,11 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
is_wan22 = config is not None and config.vae_z_dim == 48
|
||||
|
||||
if is_wan22:
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
|
||||
|
||||
vae = Wan22VAEDecoder(z_dim=48)
|
||||
else:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
@@ -120,11 +120,11 @@ def load_vae_encoder(model_path: Path, config=None):
|
||||
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
|
||||
"""
|
||||
if config is not None and config.vae_z_dim == 16:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
else:
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
|
||||
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
|
||||
|
||||
@@ -589,7 +589,7 @@ class WanVAE(nn.Module):
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
|
||||
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
@@ -966,7 +966,7 @@ class Wan22VAEDecoder(nn.Module):
|
||||
Returns:
|
||||
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
|
||||
from mlx_video.models.wan2.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
Reference in New Issue
Block a user