feat(wan): Add Wan2.1/2.2 T2V with quantization support

This commit is contained in:
Daniel
2026-02-26 16:16:07 +01:00
parent 7a74946c57
commit e64483a66a
21 changed files with 5309 additions and 35 deletions

158
.github/copilot-instructions.md vendored Normal file

@@ -0,0 +1,158 @@
# MLX-Video Copilot Instructions
## Overview
MLX-Video is a video/audio generation package using Apple MLX framework. It implements the LTX-2 model (19B parameter DiT) for text-to-video, image-to-video, and audio-video generation, optimized for Apple Silicon.
## Build, Test, and Lint
### Testing
```bash
# Install test dependencies first (pytest not in main deps)
pip install pytest
# Run all tests
python -m pytest tests/
# Run specific test file
python -m pytest tests/test_generate_dev.py
# Run specific test
python -m pytest tests/test_generate_dev.py::TestLTX2Scheduler::test_scheduler_output_shape
```
### Linting
Pre-commit hooks configured with:
- **black**: Code formatting
- **isort**: Import sorting (profile: black)
- **autoflake**: Remove unused imports
```bash
# Run pre-commit manually
pre-commit run --all-files
```
### Running Generation
```bash
# Quick test - distilled model (two-stage pipeline)
python -m mlx_video.generate --prompt "test video" --num-frames 33
# Dev model with CFG (single-stage, higher quality)
python -m mlx_video.generate_dev --prompt "test video" --steps 40 --cfg-scale 4.0
# Audio-video generation
python -m mlx_video.generate_av --prompt "test video" --output-path out.mp4 --output-audio out.wav
```
## Architecture
### Two-Stage Pipeline (Distilled Model)
The distilled model (`generate.py`) uses a two-stage approach for efficiency:
1. **Stage 1**: Generate at half resolution with 8 denoising steps using STAGE_1_SIGMAS
2. **Upsampler**: 2x spatial upsampling via LatentUpsampler
3. **Stage 2**: Refine at full resolution with 3 steps using STAGE_2_SIGMAS
4. **VAE Decoder**: Convert latents to RGB video (tiled decoding for memory efficiency)
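The resolution flow through the two stages can be sketched as follows. This is a minimal illustration, not the code in `generate.py`: `StagePlan` and `plan_two_stage` are hypothetical names, and the sketch assumes the 32x spatial compression and the divisible-by-64 rule stated elsewhere in this file.

```python
from dataclasses import dataclass

SPATIAL_COMPRESSION = 32  # VAE spatial compression stated in this document


@dataclass
class StagePlan:
    name: str
    height: int    # pixel height the stage works at
    width: int     # pixel width the stage works at
    latent_h: int  # latent height (pixels / 32)
    latent_w: int  # latent width (pixels / 32)
    steps: int     # number of denoising steps


def plan_two_stage(height: int, width: int) -> list[StagePlan]:
    """Hypothetical helper: derive per-stage sizes for the distilled pipeline."""
    if height % 64 or width % 64:
        raise ValueError("height and width must be divisible by 64")
    half_h, half_w = height // 2, width // 2
    return [
        StagePlan("stage_1", half_h, half_w,
                  half_h // SPATIAL_COMPRESSION, half_w // SPATIAL_COMPRESSION, steps=8),
        StagePlan("stage_2", height, width,
                  height // SPATIAL_COMPRESSION, width // SPATIAL_COMPRESSION, steps=3),
    ]
```

For a 768x768 request this gives a 384x384 first stage (12x12 latents) and a 768x768 second stage (24x24 latents).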
### Single-Stage Pipeline (Dev Model)
The dev model (`generate_dev.py`) uses classifier-free guidance (CFG):
- Full resolution generation with configurable steps (typically 40)
- CFG guidance scale controls prompt adherence vs. diversity
- More flexible but slower than distilled model
### Core Components
**DiT Transformer** (`models/ltx/ltx.py`):
- 48 layers, 32 attention heads, 128 dim per head
- Dual modality support: video (3840-dim) and audio (2048-dim) embeddings
- Uses RoPE (Rotary Position Embeddings) in SPLIT mode with double precision
- AdaLN-Zero conditioning blocks inject timestep/text embeddings
**VAE Architecture**:
- **Video VAE**: 128 latent channels, 8x temporal + 32x spatial compression
  - Encoder: `models/ltx/video_vae/encoder.py`
  - Decoder: `models/ltx/video_vae/decoder.py` (supports tiled decoding)
- **Audio VAE**: 8 latent channels, mel-spectrogram intermediate
  - Decoder: `models/ltx/audio_vae/decoder.py`
  - HiFi-GAN vocoder: `models/ltx/audio_vae/vocoder.py`
**Text Encoder** (`models/ltx/text_encoder.py`):
- Based on Gemma 3 model
- Returns separate embeddings for video (3840-dim) and audio (2048-dim)
- Supports prompt enhancement via `enhance_t2v()` method
**Tiling System** (`models/ltx/video_vae/tiling.py`):
- Memory-efficient decoding for large videos
- Modes: auto, default (512px/64f), aggressive (256px/32f), conservative (768px/96f)
- Supports streaming via `on_frames_ready` callback
### Key Patterns
**Position Grids**:
- Created in pixel space, then converted to latent space internally
- Video: (B, 3, num_patches, 2) with [start, end) bounds for temporal/spatial dims
- Audio: (B, 1, num_patches, 2) for temporal dimension only
- See `create_position_grid()` in generate modules
**Latent Conditioning** (`conditioning/latent.py`):
- `LatentState` tracks clean latents, noise, and sigma values
- `VideoConditionByLatentIndex` enables I2V by conditioning specific frames
- `apply_denoise_mask()` protects conditioned regions during denoising
**Weight Loading**:
- `convert.py`: Downloads from HuggingFace, converts PyTorch → MLX format
- Sanitization functions (`sanitize_transformer_weights`, `sanitize_vae_encoder_weights`) adapt keys
- Uses safetensors for efficient loading
## Key Conventions
### Model Configuration
- Always use `LTXModelConfig` to instantiate models
- `model_type` determines modality: `VideoOnly`, `AudioOnly`, or `AudioVideo`
- `rope_type=LTXRopeType.SPLIT` and `double_precision_rope=True` are standard
### Frame Count Requirements
- **Distilled model**: `num_frames = 1 + 8*k` format (e.g., 33, 65, 97)
- **Dev model**: No strict requirement, but odd numbers work better
- Audio frames auto-computed from video duration via `AUDIO_LATENTS_PER_SECOND`
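The `1 + 8*k` rule for the distilled model can be checked before launching a run. The helpers below are a hypothetical sketch (neither function name comes from the package); ties in `nearest_valid_frames` follow Python's built-in `round`.

```python
def is_valid_distilled_frames(num_frames: int) -> bool:
    """True if num_frames has the form 1 + 8*k with k >= 0."""
    return num_frames >= 1 and (num_frames - 1) % 8 == 0


def nearest_valid_frames(num_frames: int) -> int:
    """Snap an arbitrary frame count to the closest value of the form 1 + 8*k."""
    k = max(0, round((num_frames - 1) / 8))
    return 1 + 8 * k
```

For example, a request for 30 frames would be snapped to 33, while 33, 65 and 97 pass unchanged.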
### Dimension Constraints
- Video height/width must be divisible by 64 (VAE spatial compression)
- Latent dimensions are pixel dimensions divided by 32
### Audio Constants
```python
AUDIO_SAMPLE_RATE = 24000 # Output sample rate
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal rate
AUDIO_HOP_LENGTH = 160 # Mel hop length
AUDIO_LATENT_CHANNELS = 8 # Audio latent channels
AUDIO_MEL_BINS = 16 # Mel frequency bins
```
### Sigma Schedules
Distilled model uses predefined schedules (no scheduler class):
```python
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
```
Dev model computes schedules via `ltx2_scheduler(steps)` function.
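A schedule of `n + 1` sigma values describes `n` denoising steps, one per consecutive pair. The sketch below only illustrates that bookkeeping with the constants listed above; `sigma_steps` is a hypothetical helper, not a function from the package.

```python
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]


def sigma_steps(sigmas):
    """Return one (sigma, sigma_next, delta) triple per denoising step."""
    return [(s, s_next, s_next - s) for s, s_next in zip(sigmas[:-1], sigmas[1:])]


stage_1 = sigma_steps(STAGE_1_SIGMAS)
stage_2 = sigma_steps(STAGE_2_SIGMAS)
print(len(stage_1), len(stage_2))  # prints "8 3"
```

Note that `STAGE_2_SIGMAS` is the tail of `STAGE_1_SIGMAS`, so the refinement stage re-enters the schedule at sigma 0.909375 rather than starting from pure noise.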
### Code Style
- Follow black formatting (configured in pre-commit)
- Import sorting: isort with black profile
- Remove unused imports (autoflake)
- Type hints encouraged but not enforced
### Modality Enum
Use `Modality.VIDEO` and `Modality.AUDIO` from `models/ltx/transformer.py` for multi-modal operations.
### Video Post-Processing
- `postprocess.py`: Contains utilities for frame normalization and video saving
- Always denormalize latents from [-1, 1] to [0, 255] before saving
- Use opencv-python for video I/O
## Python Requirements
- Python >= 3.11
- MLX >= 0.22.0
- Primary dependencies: numpy, safetensors, transformers, opencv-python, Pillow, mlx-vlm, scipy, librosa
- Package manager: uv recommended for faster installs, pip also supported

26
.github/skills/fast-mlx/SKILL.md vendored Normal file

@@ -0,0 +1,26 @@
---
name: fast-mlx
description: Optimize MLX code for performance and memory. Use when asked to implement or speed up MLX models or algorithms, reduce latency/throughput bottlenecks, tune lazy evaluation, type promotion, fast ops, compilation, memory use, or profiling.
---
# Fast MLX
## Workflow
- Look for opportunities to compile functions made up mostly of elementwise operations.
- For models with fixed-shape inputs, or where the shapes rarely change, compile the entire graph.
- Replace slow implementations with MLX fast ops.
- Identify evaluation boundaries and unintended sync points (`mx.eval`, `item()`, NumPy conversions).
- Check dtype promotion and scalar usage; keep precision consistent with intent.
- Review compilation strategy; avoid unnecessary recompiles and closure captures.
- Reduce peak memory via lazy loading order and releasing temporaries before `mx.eval`.
- Suggest profiling steps if the bottleneck is unclear.
## References
- Read `references/fast-mlx-guide.md` for detailed tips and examples. Use it as the source of truth.
## Output expectations
- Provide concrete code changes with brief rationale
- Call out changes that need user confirmation (e.g., enabling async eval or shapeless compile).


@@ -0,0 +1,350 @@
# Making MLX Go Fast
## Table of Contents
- [Graph Evaluation](#graph-evaluation)
- [Type Promotion](#type-promotion)
- [Operations](#operations)
- [Compile](#compile)
- [Memory Use](#memory-use)
- [Profiling](#profiling)
This guide assumes you have some familiarity with MLX and want to make your MLX
model or algorithm as efficient as possible.
### Graph Evaluation
Recall, MLX is lazy. When you call an MLX op, no computation actually happens.
You are simply building a graph. The computation happens when you explicitly or
implicitly evaluate an array. Read more about how this works in the
documentation:
https://ml-explore.github.io/mlx/build/html/usage/lazy_evaluation.html
Evaluating the graph incurs some overhead, so don't do it too frequently.
Conversely you don't want the graph to get too big before evaluating it as this
can also be expensive. Most numerical and machine learning algorithms are
iterative. A good place to evaluate the graph is at the end of each iteration.
Some examples:
- After an iteration of gradient descent
- After producing one token with a language model
- After taking one denoising step in a diffusion model
Overly frequent evaluations sometimes happen by accident. For example:
```python
# output is an mx.array
for x in output:
    do_something(x.item())
```
The same thing can be written more explicitly with operations and `mx.eval` as:
```python
for i in range(len(output)):
    x = output[i]
    mx.eval(x)
    do_something(x.item())
```
Two better options are:
1. When possible avoid calling `item()` and do everything in MLX.
2. Move the entire output to Python or NumPy first.
An example of the second approach:
```python
for x in output.tolist():
    do_something(x)
```
#### Asynchronous Evaluation
For a latency sensitive computation which is run many times, `mx.async_eval`
can be useful. Normally `mx.eval` is synchronous. It returns only when the
computation is complete. Instead `mx.async_eval` asynchronously evaluates the
graph and returns to the main thread immediately. You can use this to pipeline
graph construction with computation like so:
```python
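def generator():
    # mx.async_eval schedules the work and returns None,
    # so keep a reference to the array itself.
    out = my_function()
    mx.async_eval(out)
    while True:
        out_next = my_function()
        mx.async_eval(out_next)
        mx.eval(out)
        yield out
        out = out_next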
```
For this to work `my_function()` cannot do any synchronous evaluations (e.g.
calling `mx.eval`, converting to NumPy, etc.). Furthermore, any work done on
`out` that is synchronous and on the same stream can stall the pipeline:
```python
for out in generator():
    out = out * 2
    # Stalls the pipeline!
    mx.eval(out)
```
An easy fix for this is to put the pipeline in a separate stream:
```python
def generator():
    with mx.stream(mx.new_stream(mx.gpu)):
        # mx.async_eval returns None, so keep the array itself.
        out = my_function()
        mx.async_eval(out)
        while True:
            out_next = my_function()
            mx.async_eval(out_next)
            mx.eval(out)
            yield out
            out = out_next
```
### Type Promotion
One of the most common performance issues comes from accidental up-casting.
Make sure you understand how type promotion works in MLX. The inputs to an MLX
operation are typically promoted to a common type which doesn't lose precision.
For example:
```python
x = mx.array(1.0, mx.float32) * mx.array(2.0, mx.float16)
```
will result in `x` with type `mx.float32`. Similarly:
```python
x = mx.array(1.0, mx.bfloat16) * mx.array(2.0, mx.float16)
```
will result in `x` with type `mx.float32`. A common mistake is to multiply a
half-precision array by a default-typed scalar array which up-casts everything
to `mx.float32`:
```python
# Warning: x has type mx.float32
x = my_fp16_array * mx.array(2.0)
```
To multiply by a scalar while preserving the input type, use Python scalars.
Python scalars are weakly typed and have more relaxed promotion rules when
used with MLX arrays.
```python
# Ok, x has type mx.float16
x = my_fp16_array * 2.0
```
### Operations
#### Use Fast Ops
Use `mx.fast` ops when possible:
- `mx.fast.rms_norm`
- `mx.fast.layer_norm`
- `mx.fast.rope`
- `mx.fast.scaled_dot_product_attention`
A lot of these operations take a variety of parameters so they can be used for
different variations of the function. For example, the weight and bias
parameters are optional in `mx.fast.layer_norm` so it can be used with
different permutations of inputs.
#### Precision
For operations which typically use higher precision there is usually no
need to explicitly upcast. For example, `mx.fast.rms_norm` and
`mx.fast.layer_norm` accumulate in higher precision so it's
wasteful to upcast and downcast into and out of these operations:
```python
# No need for this! (rms_norm takes a weight and eps, but no bias)
mx.fast.rms_norm(x.astype(mx.float32), w, eps).astype(x.dtype)
# This is just as good:
mx.fast.rms_norm(x, w, eps)
```
Similarly, for `mx.softmax` use `precise=True` if you want to do the softmax in
higher precision rather than explicitly casting the input and output.
#### Misc
- For vector-matrix multiplication `x @ W.T` is faster than `x @ W`, for
matrix-vector multiplication `W @ x` is faster than `W.T @ x`
- Use `mx.addmm` for `a @ b + c` (e.g. a linear layer with a bias).
- Where it makes sense, use `mx.take_along_axis` and `mx.put_along_axis`
instead of fancy indexing
- Use broadcasting instead of concatenation. For example, prefer `mx.repeat(a,
n)` over `mx.concatenate([a] * n)`
### Compile
Compiling graphs with `mx.compile` can make them run a lot faster. But there
are some sharp edges that are good to be aware of.
First, be aware of when a function will be recompiled. Recompilation is
relatively expensive and should only be done if there is sufficient work over
which to amortize the cost.
The default behavior of `mx.compile` is to do a shape-dependent compilation.
This means the function will be recompiled if the shape of any input changes.
MLX supports a shapeless compilation by passing `shapeless=True` to
`mx.compile`. It's easy to make hard-to-detect mistakes with shapeless
compilation. Make sure to read and understand the documentation and use it
with care:
https://ml-explore.github.io/mlx/build/html/usage/compile.html#shapeless-compilation
A function will also be recompiled if any constant inputs change:
```python
@mx.compile
def fun(x, scale):
    return scale * x

fun(x, 3)
# Recompiles!
fun(x, 4)
```
In this case a simple fix is to make `scale` an `mx.array`.
#### Compiling Closures
Be careful when compiling a closure where the function encloses any
`mx.array`.
```python
y = some_function()

@mx.compile
def fun(x):
    return x + y
```
Since `y` is not an input to `fun`, the compiled graph will include the entire
computation which produces `y`. Usually you only want to compute `y` one time
and re-use it in the compiled function. Either explicitly pass it as an input
to `fun` or pass it as an implicit input to `mx.compile` like so:
```python
from functools import partial

y = some_function()

@partial(mx.compile, inputs=[y])
def fun(x):
    return x + y
```
### Memory Use
#### Lazy Loading
Loading arrays from a file is lazy in MLX:
```python
weights = mx.load("model.safetensors")
```
The above function returns instantly, regardless of the file size. To actually
load the weights into memory, you can do `mx.eval(weights)`.
Assume the weights are stored on disk in 32-bit precision (i.e. `mx.float32`).
But for your model you only need 16-bit precision:
```python
weights = mx.load("model.safetensors")
mx.eval(weights)
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
```
In the above, the weights will be loaded into memory in full precision and then
cast to 16-bit. This requires memory for all the weights in 32-bit plus memory
for the weights in 16-bit.
This is much better:
```python
weights = mx.load("model.safetensors")
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
mx.eval(weights)
```
Evaluating after the cast to `mx.float16` cuts peak memory to roughly a third
of the previous version. That's because the weights are never all materialized
in 32-bit at once. Right after each weight is loaded in 32-bit precision it is
cast to 16-bit, and the memory for the 32-bit weight can be reused when loading
the next weight.
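The difference between the two orderings can be estimated with simple arithmetic. This is a simplified accounting under stated assumptions (4 bytes per 32-bit value, 2 bytes per 16-bit value, one 32-bit temporary alive at a time, allocator caching ignored); both function names are hypothetical.

```python
def peak_bytes_eval_then_cast(num_params: int) -> int:
    """All fp32 weights are materialized, then fp16 copies are created next to them."""
    return num_params * 4 + num_params * 2


def peak_bytes_cast_then_eval(num_params: int, largest_tensor_params: int) -> int:
    """Only fp16 results persist; one fp32 tensor at a time is live as a temporary."""
    return num_params * 2 + largest_tensor_params * 4


eager = peak_bytes_eval_then_cast(1_000_000_000)
lazy = peak_bytes_cast_then_eval(1_000_000_000, largest_tensor_params=50_000_000)
print(eager, lazy)  # prints "6000000000 2200000000"
```

With these illustrative numbers the lazy ordering peaks at a little over a third of the eager one. The exact ratio depends on the size of the largest single tensor.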
Note, MLX is only able to lazy load from a file when it is given to `mx.load`
as a string path. Due to lifetime management issues, lazy loading from file
handles is not supported. So avoid this:
```python
weights = mx.load(open("model.safetensors", 'rb'))
```
#### Release Temporaries
One way to reduce memory consumption is to avoid holding
temporaries you don't need. This is a typical training loop:
```python
for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model, optimizer.state)
```
It's suboptimal since a reference to `grads` is held during the call to
`mx.eval` which keeps the respective memory from being used for any other part
of the computation.
This is better:
```python
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(x, y)
    mx.eval(model, optimizer.state)
```
In this case the reference to `grads` is released before `mx.eval` and the
memory can be reused. You can achieve the same goal using `del` as long as it's
before the call to `mx.eval`:
```python
for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    del grads
    mx.eval(model, optimizer.state)
```
#### Misc
- MLX will cache memory buffers of recently released arrays rather than
returning them to the system. In some cases, especially for variable shape
computations, the cache can get large. To help with this, MLX has some
functions for logging and customizing the behavior of memory allocation:
https://ml-explore.github.io/mlx/build/html/python/metal.html
### Profiling
A good first step is to check GPU utilization using, for example,
mactop: https://github.com/context-labs/mactop. If it's not pegged at close
to 100% then there is likely a non-MLX bottleneck somewhere in the program. A
common culprit is data loading or preprocessing.
If GPU utilization is good, a good next step is to figure out which operations
are taking up so much time. One way to do this is with the Metal debugger. For
that, see the documentation on profiling MLX with the Metal debugger:
https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html

246
README.md

@@ -18,18 +18,20 @@ uv pip install git+https://github.com/Blaizzy/mlx-video.git
Supported models:
### LTX-2
[LTX-2](https://huggingface.co/Lightricks/LTX-Video) is 19B parameter video generation model from Lightricks
- [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks
- [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline)
- [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — 14B parameter T2V model (dual-model pipeline)
## Features
- Text-to-video generation with the LTX-2 19B DiT model
- Two-stage generation pipeline for high-quality output
- 2x spatial upscaling for images and videos
- Text-to-video generation with multiple model families
- LTX-2: Two-stage pipeline with 2x spatial upscaling
- Wan2.1/2.2: Flow-matching diffusion with classifier-free guidance
- Optimized for Apple Silicon using MLX
---
## Usage
## LTX-2
> **Info:** Currently, only the distilled variant is supported. Full LTX-2 feature support is coming soon.
@@ -53,7 +55,7 @@ python -m mlx_video.generate \
--output my_video.mp4
```
### CLI Options
### LTX-2 CLI Options
| Option | Default | Description |
|--------|---------|-------------|
@@ -67,45 +69,229 @@ python -m mlx_video.generate \
| `--save-frames` | false | Save individual frames as images |
| `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository |
## How It Works
### How It Works (LTX-2)
The pipeline uses a two-stage generation process:
1. **Stage 1**: Generate at half resolution (e.g., 384x384) with 8 denoising steps
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
3. **Stage 2**: Refine at full resolution (e.g., 768x768) with 3 denoising steps
1. **Stage 1**: Generate at half resolution (e.g., 384×384) with 8 denoising steps
2. **Upsample**: 2× spatial upsampling via LatentUpsampler
3. **Stage 2**: Refine at full resolution (e.g., 768×768) with 3 denoising steps
4. **Decode**: VAE decoder converts latents to RGB video
---
## Wan2.1 / Wan2.2
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline:
| | Wan2.1 | Wan2.2 |
|---|--------|--------|
| **Pipeline** | Single model | Dual model (high-noise + low-noise) |
| **Sizes** | 1.3B, 14B | 14B |
| **Steps** | 50 | 40 |
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 (low/high noise) |
| **Shift** | 5.0 | 12.0 |
### Step 1: Download Weights
Download the original PyTorch checkpoints:
**Wan2.1 (14B)**
```bash
# From https://github.com/Wan-Video/Wan2.1 or HuggingFace
# Expected directory structure:
# wan21_checkpoint/
# ├── models_t5_umt5-xxl-enc-bf16.pth
# ├── Wan2.1_VAE.pth
# └── diffusion_pytorch_model*.safetensors # single model
```
**Wan2.1 (1.3B)** — same structure, smaller transformer weights.
**Wan2.2 (14B)**
```bash
# From https://github.com/Wan-Video/Wan2.2 or HuggingFace
# Expected directory structure:
# wan22_checkpoint/
# ├── models_t5_umt5-xxl-enc-bf16.pth
# ├── Wan2.1_VAE.pth
# ├── low_noise_model/ # safetensors
# └── high_noise_model/ # safetensors
```
### Step 2: Convert to MLX Format
The conversion script auto-detects whether the checkpoint is Wan2.1 or Wan2.2 based on the directory structure (presence of `low_noise_model/` subdirectory).
```bash
# Auto-detect version
python -m mlx_video.convert_wan \
--checkpoint-dir /path/to/wan_checkpoint \
--output-dir wan_mlx
# Explicit version
python -m mlx_video.convert_wan \
--checkpoint-dir /path/to/wan21_checkpoint \
--output-dir wan21_mlx \
--model-version 2.1
python -m mlx_video.convert_wan \
--checkpoint-dir /path/to/wan22_checkpoint \
--output-dir wan22_mlx \
--model-version 2.2
```
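The directory-based detection described above can be sketched as below. This is an illustration of the stated rule only, not the converter's actual code: `detect_wan_version` is a hypothetical name, and treating every directory without `low_noise_model/` as Wan2.1 is an assumption.

```python
from pathlib import Path


def detect_wan_version(checkpoint_dir: str) -> str:
    """Return "2.2" if a low_noise_model/ subdirectory exists, else "2.1"."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"checkpoint directory not found: {root}")
    # Wan2.2 ships two transformers in separate subdirectories.
    return "2.2" if (root / "low_noise_model").is_dir() else "2.1"
```

Passing `--model-version` explicitly avoids relying on the layout when a checkpoint has been reorganized on disk.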
#### Conversion Options
| Option | Default | Description |
|--------|---------|-------------|
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
| `--quantize` | off | Quantize transformer weights for reduced memory |
| `--bits` | `4` | Quantization bits: `4` or `8` |
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
The converter produces:
```
wan_mlx/
├── config.json # Model configuration
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
├── vae.safetensors # 3D VAE decoder
├── model.safetensors # (Wan2.1) Single transformer
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
```
### Step 3: Generate Video
```bash
# Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0)
python -m mlx_video.generate_wan \
--model-dir wan21_mlx \
--prompt "A cat playing piano in a cozy room"
# Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0)
python -m mlx_video.generate_wan \
--model-dir wan22_mlx \
--prompt "A cat playing piano in a cozy room"
```
With custom settings:
```bash
python -m mlx_video.generate_wan \
--model-dir wan21_mlx \
--prompt "Ocean waves at sunset, cinematic, 4K" \
--negative-prompt "blurry, low quality" \
--width 1280 \
--height 720 \
--num-frames 81 \
--steps 50 \
--guide-scale 5.0 \
--shift 5.0 \
--seed 42 \
--output-path my_video.mp4
```
The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model). You can also override any parameter via CLI flags.
#### Generation CLI Options
| Option | Default | Description |
|--------|---------|-------------|
| `--model-dir` | (required) | Path to converted MLX model directory |
| `--prompt` | (required) | Text description of the video |
| `--negative-prompt` | `""` | Negative prompt for guidance |
| `--width` | 1280 | Video width |
| `--height` | 720 | Video height |
| `--num-frames` | 81 | Number of frames (must be 4n+1) |
| `--steps` | from config | Number of diffusion steps |
| `--guide-scale` | from config | Guidance scale: float or `low,high` pair |
| `--shift` | from config | Noise schedule shift |
| `--seed` | -1 (random) | Random seed for reproducibility |
| `--output-path` | `output.mp4` | Output video path |
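Two of the options above have format constraints: `--guide-scale` accepts either a single float or a `low,high` pair, and `--num-frames` must be of the form 4n+1. A minimal validation sketch, assuming hypothetical helper names that are not part of `generate_wan.py`:

```python
def parse_guide_scale(text: str) -> tuple[float, float]:
    """Parse "5.0" or "3.0,4.0" into a (low_noise, high_noise) pair."""
    parts = [p.strip() for p in text.split(",")]
    if len(parts) == 1:
        value = float(parts[0])
        return (value, value)  # one value applies to both noise regimes
    if len(parts) == 2:
        return (float(parts[0]), float(parts[1]))
    raise ValueError(f"expected a float or 'low,high', got {text!r}")


def check_num_frames(num_frames: int) -> int:
    """Raise unless num_frames has the form 4n + 1."""
    if num_frames < 1 or (num_frames - 1) % 4 != 0:
        raise ValueError(f"num_frames must be 4n+1, got {num_frames}")
    return num_frames
```

With these helpers, `parse_guide_scale("3.0,4.0")` yields `(3.0, 4.0)` and `check_num_frames(81)` passes, while 80 frames is rejected.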
### Quantization (Reduced Memory)
Quantize the transformer weights to reduce memory usage by ~3.4x. This is especially useful for the 14B model or memory-constrained devices:
```bash
# Convert with 4-bit quantization
python -m mlx_video.convert_wan \
--checkpoint-dir /path/to/Wan2.1-T2V-1.3B \
--output-dir wan21_mlx_q4 \
--quantize --bits 4 --group-size 64
# Generate with quantized model (auto-detected from config.json)
python -m mlx_video.generate_wan \
--model-dir wan21_mlx_q4 \
--prompt "A cat playing piano"
```
**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2) — 10 layers × N blocks = ~95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.
| Model | BF16 Size | 4-bit Size | Notes |
|-------|-----------|------------|-------|
| 1.3B | 2.7 GB | 799 MB | ~3.4x smaller |
| 14B | ~28 GB | ~8 GB | Enables running on 16GB devices |
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
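
As a rough way to reason about these sizes, the effective storage cost per quantized weight can be estimated. The sketch below assumes group-wise affine quantization that stores one 16-bit scale and one 16-bit bias per group (32 bits of metadata per group); that layout, the function names, and the quantized fraction are assumptions for illustration, and the sizes in the table remain the source's figures.

```python
def effective_bits_per_weight(bits: int, group_size: int, meta_bits: int = 32) -> float:
    """Bits per weight including per-group metadata (assumed scale + bias)."""
    return bits + meta_bits / group_size


def estimated_compression(bits: int, group_size: int,
                          quantized_fraction: float, full_bits: int = 16) -> float:
    """Estimated size reduction versus storing every weight at full_bits."""
    q_bits = effective_bits_per_weight(bits, group_size)
    avg_bits = quantized_fraction * q_bits + (1.0 - quantized_fraction) * full_bits
    return full_bits / avg_bits


print(effective_bits_per_weight(4, 64))  # prints "4.5"
```

The result is sensitive to `quantized_fraction`: the closer it is to 1.0, the closer the overall ratio gets to `16 / 4.5`.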
### Wan Model Specifications
**Transformer (14B)**
- 40 layers, 40 attention heads, dim 5120, head dim 128
- 3-way factorized RoPE (temporal + spatial)
- 14.29B parameters
**Transformer (1.3B, Wan2.1 only)**
- 30 layers, 12 attention heads, dim 1536, head dim 128
- Same architecture, smaller scale
**Text Encoder** — UMT5-XXL (5.68B parameters)
- 24 layers, 64 heads, dim 4096, vocab 256K
**VAE** — 3D causal convolution decoder (72.6M parameters)
- Latent channels: 16
- Compression: 4× temporal, 8× spatial
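The compression factors above determine the latent tensor the transformer operates on. The sketch below assumes the first frame is encoded on its own, as is common for causal video VAEs and consistent with the 4n+1 frame rule; `wan_latent_shape` is a hypothetical helper.

```python
def wan_latent_shape(num_frames: int, height: int, width: int,
                     channels: int = 16, t_comp: int = 4, s_comp: int = 8):
    """Return (channels, latent_frames, latent_height, latent_width)."""
    if (num_frames - 1) % t_comp != 0:
        raise ValueError("num_frames must be of the form 4n + 1")
    if height % s_comp or width % s_comp:
        raise ValueError("height and width must be divisible by 8")
    latent_frames = 1 + (num_frames - 1) // t_comp  # first frame kept separately
    return (channels, latent_frames, height // s_comp, width // s_comp)


print(wan_latent_shape(81, 720, 1280))  # prints "(16, 21, 90, 160)"
```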
---
## Requirements
- macOS with Apple Silicon
- Python >= 3.11
- MLX >= 0.22.0
## Model Specifications
- **Transformer**: 48 layers, 32 attention heads, 128 dim per head
- **Latent channels**: 128
- **Text encoder**: Gemma 3 with 3840-dim output
- **RoPE**: Split mode with double precision
- For weight conversion: PyTorch (`pip install torch`)
## Project Structure
```
mlx_video/
├── generate.py # Video generation pipeline
├── convert.py # Weight conversion (PyTorch -> MLX)
├── generate.py # LTX-2 generation pipeline
├── generate_wan.py # Wan2.1/2.2 generation pipeline
├── convert.py # LTX-2 weight conversion
├── convert_wan.py # Wan weight conversion (PyTorch → MLX)
├── postprocess.py # Video post-processing utilities
├── utils.py # Helper functions
└── models/
├── ltx/
├── ltx.py # Main LTXModel (DiT transformer)
├── config.py # Model configuration
├── transformer.py # Transformer blocks
├── attention.py # Multi-head attention with RoPE
├── text_encoder.py # Text encoder
├── upsampler.py # 2x spatial upsampler
└── video_vae/ # VAE encoder/decoder
├── ltx/ # LTX-2 model
├── ltx.py # DiT transformer
├── config.py # Configuration
├── transformer.py # Transformer blocks
├── attention.py # Multi-head attention with RoPE
├── text_encoder.py # Gemma 3 text encoder
├── upsampler.py # 2x spatial upsampler
└── video_vae/ # VAE encoder/decoder
└── wan/ # Wan2.1/2.2 model
├── config.py # Configuration (2.1 & 2.2 presets)
├── model.py # WanModel (DiT transformer)
├── transformer.py # Attention blocks with 6-element modulation
├── attention.py # Self/cross attention with QK-norm
├── rope.py # 3-way factorized RoPE
├── text_encoder.py # T5 UMT5-XXL encoder
├── vae.py # 3D causal VAE decoder
└── scheduler.py # Flow-matching Euler scheduler
```
## License


@@ -1,9 +1,12 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights
import os
__all__ = [
    "LTXModel",
    "LTXModelConfig",
    "WanModel",
    "WanModelConfig",
    "load_transformer_weights",
    "load_vae_weights",
]

556
mlx_video/convert_wan.py Normal file

@@ -0,0 +1,556 @@
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
import logging
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.utils
import numpy as np
def load_torch_weights(path: str) -> Dict[str, mx.array]:
    """Load PyTorch .pth weights and convert to MLX arrays.

    Args:
        path: Path to .pth file

    Returns:
        Dictionary of MLX arrays
    """
    try:
        import torch
    except ImportError:
        raise ImportError("PyTorch is required to load .pth weights: pip install torch")
    logging.info(f"Loading weights from {path}")
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    weights = {}
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            np_val = value.detach().float().numpy()
            weights[key] = mx.array(np_val)
    return weights
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
    """Load safetensors weights as MLX arrays.

    Args:
        path: Path to directory with safetensors files or single file

    Returns:
        Dictionary of MLX arrays
    """
    path = Path(path)
    weights = {}
    if path.is_file():
        weights = mx.load(str(path))
    elif path.is_dir():
        for sf in sorted(path.glob("*.safetensors")):
            weights.update(mx.load(str(sf)))
    return weights
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert Wan2.2 transformer weight keys to MLX model structure.

    Wan2.2 keys follow the pattern:
        patch_embedding.weight/bias
        text_embedding.{0,2}.weight/bias
        time_embedding.{0,2}.weight/bias
        time_projection.1.weight/bias
        blocks.{i}.norm1.weight
        blocks.{i}.self_attn.{q,k,v,o}.weight/bias
        blocks.{i}.self_attn.norm_q.weight
        blocks.{i}.self_attn.norm_k.weight
        blocks.{i}.norm3.weight/bias (if cross_attn_norm)
        blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
        blocks.{i}.cross_attn.norm_q.weight
        blocks.{i}.cross_attn.norm_k.weight
        blocks.{i}.norm2.weight
        blocks.{i}.ffn.{0,2}.weight/bias
        blocks.{i}.modulation
        head.norm.weight
        head.head.weight/bias
        head.modulation
        freqs (buffer)

    MLX model uses:
        patch_embedding_proj.weight/bias (after patchify reshape)
        text_embedding_0.weight/bias, text_embedding_1.weight/bias
        time_embedding_0.weight/bias, time_embedding_1.weight/bias
        time_projection.weight/bias
        blocks.{i}.norm1.weight
        blocks.{i}.self_attn.{q,k,v,o}.weight/bias
        etc.
    """
    sanitized = {}
    for key, value in weights.items():
        new_key = key
        # Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
        # MLX Linear expects [O, I*D*H*W] after we flatten in patchify
        if key == "patch_embedding.weight":
            # Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
            value = value.reshape(value.shape[0], -1)
            new_key = "patch_embedding_proj.weight"
            sanitized[new_key] = value
            continue
        if key == "patch_embedding.bias":
            new_key = "patch_embedding_proj.bias"
            sanitized[new_key] = value
            continue
        # Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
        if key.startswith("text_embedding.0."):
            new_key = key.replace("text_embedding.0.", "text_embedding_0.")
            sanitized[new_key] = value
            continue
        if key.startswith("text_embedding.2."):
            new_key = key.replace("text_embedding.2.", "text_embedding_1.")
            sanitized[new_key] = value
            continue
        # Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
        if key.startswith("time_embedding.0."):
            new_key = key.replace("time_embedding.0.", "time_embedding_0.")
            sanitized[new_key] = value
            continue
        if key.startswith("time_embedding.2."):
            new_key = key.replace("time_embedding.2.", "time_embedding_1.")
            sanitized[new_key] = value
            continue
        # Time projection Sequential: 0=SiLU(no params), 1=Linear
        if key.startswith("time_projection.1."):
            new_key = key.replace("time_projection.1.", "time_projection.")
            sanitized[new_key] = value
            continue
        # FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
        new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
        new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")
        # Skip the freqs buffer (we compute it in the model)
        if key == "freqs":
            continue
        sanitized[new_key] = value
    return sanitized
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.
Wan2.2 T5 keys:
token_embedding.weight
pos_embedding.embedding.weight (if shared_pos)
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate.0.weight (gate linear)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
norm.weight
MLX T5Encoder structure:
token_embedding.weight
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight
norm.weight
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
sanitized[new_key] = value
return sanitized
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 VAE weight keys to MLX WanVAE structure.
Handles Conv3d and Conv2d weight transpositions for MLX format.
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
if "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
if "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Keys are currently passed through unchanged; only conv weight layouts are converted.
# Wan2.2 uses encoder/decoder with downsamples/upsamples, so key names may still
# need adapting for our simplified structure.
sanitized[new_key] = value
return sanitized
def convert_wan_checkpoint(
checkpoint_dir: str,
output_dir: str,
dtype: str = "bfloat16",
model_version: str = "auto",
quantize: bool = False,
bits: int = 4,
group_size: int = 64,
):
"""Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.
Wan2.2 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
low_noise_model/ (safetensors)
high_noise_model/ (safetensors)
Wan2.1 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
diffusion_pytorch_model*.safetensors (single model)
Args:
checkpoint_dir: Path to Wan checkpoint directory
output_dir: Path to output MLX model directory
dtype: Target dtype
model_version: "2.1", "2.2", or "auto" (detect from directory)
quantize: Whether to quantize the transformer weights
bits: Quantization bits (4 or 8)
group_size: Quantization group size (32, 64, or 128)
"""
import json
checkpoint_dir = Path(checkpoint_dir)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
dtype_map = {
"float16": mx.float16,
"float32": mx.float32,
"bfloat16": mx.bfloat16,
}
target_dtype = dtype_map.get(dtype, mx.bfloat16)
# Auto-detect version
if model_version == "auto":
if (checkpoint_dir / "low_noise_model").exists():
model_version = "2.2"
elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
model_version = "2.2"
else:
model_version = "2.1"
print(f"Auto-detected Wan{model_version} checkpoint")
is_dual = (checkpoint_dir / "low_noise_model").exists()
if is_dual:
# Wan2.2: Convert dual transformer models
low_noise_path = checkpoint_dir / "low_noise_model"
if low_noise_path.exists():
print("Converting low-noise transformer...")
weights = load_safetensors_weights(str(low_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "low_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
high_noise_path = checkpoint_dir / "high_noise_model"
if high_noise_path.exists():
print("Converting high-noise transformer...")
weights = load_safetensors_weights(str(high_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "high_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
# Wan2.1: Convert single transformer model
# Try safetensors in the checkpoint dir itself
print("Converting transformer (single model)...")
weights = load_safetensors_weights(str(checkpoint_dir))
if not weights:
# Fallback: look for .pth files
for pth in sorted(checkpoint_dir.glob("*.pth")):
if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
print(f" Loading from {pth.name}...")
weights = load_torch_weights(str(pth))
break
if weights:
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan.config import WanModelConfig
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
if is_dual:
return WanModelConfig.wan22_t2v_14b()
# Try reading source config.json first (most reliable)
src_cfg_path = checkpoint_dir / "config.json"
src_config = None
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
if src_config and "dim" in src_config:
src_dim = src_config.get("dim", 5120)
src_in_dim = src_config.get("in_dim", 16)
src_out_dim = src_config.get("out_dim", 16)
src_ffn_dim = src_config.get("ffn_dim", 13824)
src_num_heads = src_config.get("num_heads", 40)
src_num_layers = src_config.get("num_layers", 40)
src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512)
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
is_22 = model_version == "2.2"
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
vae_z = 48 if is_22 else 16
vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
fps = 24 if is_22 else 16
return WanModelConfig(
model_type=src_model_type,
model_version=model_version,
dim=src_dim,
ffn_dim=src_ffn_dim,
in_dim=src_in_dim,
out_dim=src_out_dim,
num_heads=src_num_heads,
num_layers=src_num_layers,
text_len=src_text_len,
vae_z_dim=vae_z,
vae_stride=vae_s,
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
sample_fps=fps,
)
# Fallback: detect from saved transformer weight shapes
saved_model = output_dir / "model.safetensors"
if saved_model.exists():
det_weights = mx.load(str(saved_model))
dim = None
for k, v in det_weights.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
break
del det_weights
if dim is not None and dim <= 2048:
print(f" Auto-detected 1.3B model (dim={dim})")
return WanModelConfig.wan21_t2v_1_3b()
return WanModelConfig.wan21_t2v_14b()
config = _detect_config()
config_path = output_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config.to_dict(), f, indent=2)
print(f" Saved config to {config_path}")
# Convert T5 encoder
t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
if t5_path.exists():
print("Converting T5 encoder...")
weights = load_torch_weights(str(t5_path))
weights = sanitize_wan_t5_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "t5_encoder.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
# Convert VAE (check both naming conventions)
vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
is_wan22_vae = False
if not vae_path.exists():
vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
is_wan22_vae = True
if vae_path.exists():
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = sanitize_wan22_vae_weights(weights)
else:
weights = sanitize_wan_vae_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "vae.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
# Quantize transformer weights if requested
if quantize:
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}")
def _quantize_predicate(path: str, module) -> bool:
"""Return True for layers that should be quantized.
Targets heavyweight Linear layers in attention and FFN blocks.
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
"""
if not hasattr(module, "to_quantized"):
return False
# Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = (
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
".ffn.fc1", ".ffn.fc2",
)
return any(path.endswith(p) for p in quantize_patterns)
def _quantize_saved_model(
output_dir: Path,
config,
is_dual: bool,
bits: int,
group_size: int,
):
"""Load saved bf16 model, quantize, and re-save."""
import json
import mlx.nn as nn
from mlx_video.models.wan.model import WanModel
model_files = []
if is_dual:
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
p = output_dir / name
if p.exists():
model_files.append(p)
else:
p = output_dir / "model.safetensors"
if p.exists():
model_files.append(p)
for model_path in model_files:
print(f" Quantizing {model_path.name}...")
model = WanModel(config)
weights = mx.load(str(model_path))
model.load_weights(list(weights.items()), strict=False)
# Apply quantization to targeted layers
nn.quantize(
model,
group_size=group_size,
bits=bits,
class_predicate=_quantize_predicate,
)
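# Save quantized weights (tree_flatten is imported here so the call does not rely on a bare `mlx` name)
from mlx.utils import tree_flatten
weights_dict = dict(tree_flatten(model.parameters()))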
mx.save_safetensors(str(model_path), weights_dict)
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
# Update config.json with quantization metadata
config_path = output_dir / "config.json"
with open(config_path) as f:
cfg = json.load(f)
cfg["quantization"] = {
"group_size": group_size,
"bits": bits,
}
with open(config_path, "w") as f:
json.dump(cfg, f, indent=2)
print(" Updated config.json with quantization metadata")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
parser.add_argument(
"--checkpoint-dir",
type=str,
required=True,
help="Path to Wan checkpoint directory",
)
parser.add_argument(
"--output-dir",
type=str,
default="wan_mlx_model",
help="Output path for MLX model",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float16", "float32", "bfloat16"],
default="bfloat16",
help="Target dtype",
)
parser.add_argument(
"--model-version",
type=str,
choices=["2.1", "2.2", "auto"],
default="auto",
help="Wan model version (auto-detect by default)",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize transformer weights for faster inference",
)
parser.add_argument(
"--bits",
type=int,
choices=[4, 8],
default=4,
help="Quantization bits (default: 4)",
)
parser.add_argument(
"--group-size",
type=int,
choices=[32, 64, 128],
default=64,
help="Quantization group size (default: 64)",
)
args = parser.parse_args()
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
)
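
The suffix matching in `_quantize_predicate` decides which layers are quantized. The following is a minimal sketch, in plain Python with no MLX dependency, of how those suffix patterns classify layer paths. The helper name `selects_for_quantization` and the example paths are assumptions for illustration, modeled on the key names in the docstrings rather than read from a real checkpoint.

```python
# Illustrative only: mirrors the suffix patterns used by _quantize_predicate.
QUANTIZE_SUFFIXES = (
    ".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
    ".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
    ".ffn.fc1", ".ffn.fc2",
)


def selects_for_quantization(path: str) -> bool:
    """Return True when a module path ends with one of the targeted suffixes."""
    # str.endswith accepts a tuple and checks each suffix in turn.
    return path.endswith(QUANTIZE_SUFFIXES)


# Hypothetical module paths, for illustration.
example_paths = [
    "blocks.0.self_attn.q",       # attention projection: selected
    "blocks.0.self_attn.norm_q",  # QK norm: skipped (suffix is .norm_q, not .q)
    "blocks.3.ffn.fc2",           # FFN projection: selected
    "patch_embedding_proj",       # embedding: skipped
    "time_projection",            # modulation: skipped
]

selected = [p for p in example_paths if selects_for_quantization(p)]
print(selected)
```

The real predicate additionally requires the module to expose `to_quantized`, so a matching path alone is not sufficient there.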

512
mlx_video/generate_wan.py Normal file

@@ -0,0 +1,512 @@
"""Wan2.2 Text-to-Video generation pipeline for MLX."""
import argparse
import gc
import math
import random
import sys
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
"""Load and initialize WanModel, with optional quantization support.
Args:
model_path: Path to model safetensors file
config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading.
"""
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
if quantization:
from mlx_video.convert_wan import _quantize_predicate
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=_quantize_predicate,
)
weights = mx.load(str(model_path))
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
return model
def load_t5_encoder(model_path: Path, config):
"""Load T5 text encoder."""
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=config.t5_vocab_size,
dim=config.t5_dim,
dim_attn=config.t5_dim_attn,
dim_ffn=config.t5_dim_ffn,
num_heads=config.t5_num_heads,
num_layers=config.t5_num_layers,
num_buckets=config.t5_num_buckets,
shared_pos=False,
)
weights = mx.load(str(model_path))
encoder.load_weights(list(weights.items()))
mx.eval(encoder.parameters())
return encoder
def load_vae_decoder(model_path: Path, config=None):
"""Load VAE decoder (skips encoder weights with strict=False).
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
For Wan2.1 (vae_z_dim=16), uses WanVAE.
"""
is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters())
return vae
def encode_text(
encoder,
tokenizer,
prompt: str,
text_len: int = 512,
) -> mx.array:
"""Encode text prompt using T5 encoder.
Args:
encoder: T5Encoder model
tokenizer: HuggingFace tokenizer
prompt: Text prompt
text_len: Maximum text length
Returns:
Text embeddings [L, dim]
"""
tokens = tokenizer(
prompt,
max_length=text_len,
padding="max_length",
truncation=True,
return_tensors="np",
)
ids = mx.array(tokens["input_ids"])
mask = mx.array(tokens["attention_mask"])
embeddings = encoder(ids, mask=mask)
# Return only non-padding tokens
seq_len = int(mask.sum().item())
return embeddings[0, :seq_len]
def generate_video(
model_dir: str,
prompt: str,
negative_prompt: str = "",
width: int = 1280,
height: int = 720,
num_frames: int = 81,
steps: int | None = None,
guide_scale: str | float | tuple | None = None,
shift: float | None = None,
seed: int = -1,
output_path: str = "output.mp4",
):
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
Args:
model_dir: Path to converted MLX model directory
prompt: Text prompt
negative_prompt: Negative prompt
width: Video width
height: Video height
num_frames: Number of frames (must be 4n+1)
steps: Number of diffusion steps (None = use config default)
guide_scale: Guidance scale: float for single, (low,high) for dual (None = config default)
shift: Noise schedule shift (None = use config default)
seed: Random seed (-1 for random)
output_path: Output video path
"""
import json
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
model_dir = Path(model_dir)
# Load config from model dir if available, otherwise auto-detect
config_path = model_dir / "config.json"
quantization = None
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Extract quantization config (not a model config field)
quantization = config_dict.pop("quantization", None)
# Handle tuple fields stored as lists in JSON
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**{
k: v for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
})
else:
# Auto-detect: dual model files → 2.2, single model → 2.1
if (model_dir / "low_noise_model.safetensors").exists():
config = WanModelConfig.wan22_t2v_14b()
else:
# Detect 1.3B vs 14B from weight shapes
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
if dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
else:
config = WanModelConfig.wan21_t2v_14b()
del probe
else:
config = WanModelConfig.wan21_t2v_14b()
is_dual = config.dual_model
# Validate config against actual weights (handles mismatched config.json)
if not is_dual:
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
actual_dim = v.shape[0]
if actual_dim != config.dim:
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
if actual_dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
del probe
# Auto-correct Wan2.2 VAE params from stale configs
if config.in_dim == 48 and config.vae_z_dim != 48:
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
config = WanModelConfig(**{
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
})
# Apply defaults from config if not overridden
if steps is None:
steps = config.sample_steps
if shift is None:
shift = config.sample_shift
if guide_scale is None:
guide_scale = config.sample_guide_scale
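# Normalize guide_scale
if isinstance(guide_scale, str):
parts = [float(x) for x in guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Dual-model sampling indexes guide_scale[0] and guide_scale[1], so expand a scalar to a (low, high) pair
if isinstance(guide_scale, (int, float)):
guide_scale = (float(guide_scale),) * 2 if is_dual else float(guide_scale)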
# Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
print(f"{Colors.CYAN}{'='*60}")
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
print(f"{'='*60}{Colors.RESET}")
print(f"{Colors.DIM} Prompt: {prompt}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}")
print(f"{Colors.RESET}")
# Seed
if seed < 0:
seed = random.randint(0, 2**32 - 1)
mx.random.seed(seed)
np.random.seed(seed)
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
# Compute target latent shape
vae_stride = config.vae_stride
z_dim = config.vae_z_dim
t_latent = (num_frames - 1) // vae_stride[0] + 1
h_latent = height // vae_stride[1]
w_latent = width // vae_stride[2]
target_shape = (z_dim, t_latent, h_latent, w_latent)
# Sequence length for transformer
patch_size = config.patch_size
seq_len = math.ceil(
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
)
print(f"{Colors.DIM} Latent shape: {target_shape}")
print(f" Sequence length: {seq_len}{Colors.RESET}")
# Load T5 encoder
t1 = time.time()
print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
t5_path = model_dir / "t5_encoder.safetensors"
t5_encoder = load_t5_encoder(t5_path, config)
# Load tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
# Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
if negative_prompt:
context_null = encode_text(t5_encoder, tokenizer, negative_prompt, config.text_len)
else:
context_null = encode_text(t5_encoder, tokenizer, "", config.text_len)
mx.eval(context, context_null)
# Free T5 from memory
del t5_encoder
gc.collect()
mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization:
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
t2 = time.time()
if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization)
high_noise_model = load_wan_model(high_noise_path, config, quantization)
else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
ref_model = single_model if not is_dual else low_noise_model
context_emb = ref_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cond = context_emb[0:1] # [1, text_len, dim]
context_uncond = context_emb[1:2] # [1, text_len, dim]
# Stack for batched CFG: [2, text_len, dim]
context_cfg = mx.concatenate([context_cond, context_uncond], axis=0)
# Precompute cross-attention K/V caches (constant across all steps)
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv)
# Setup scheduler
scheduler = FlowMatchEulerScheduler(num_train_timesteps=config.num_train_timesteps)
scheduler.set_timesteps(steps, shift=shift)
# Generate initial noise
noise = mx.random.normal(target_shape)
# Boundary for model switching (dual model only)
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
# Diffusion loop
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
latents = noise
t3 = time.time()
for i in tqdm(range(steps), desc="Diffusion"):
timestep_val = scheduler.timesteps[i].item()
# Select model, guide scale, and cached K/V
if is_dual:
if timestep_val >= boundary:
model = high_noise_model
gs = guide_scale[1]
kv = cross_kv_high
else:
model = low_noise_model
gs = guide_scale[0]
kv = cross_kv_low
else:
model = single_model
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
kv = cross_kv
# CFG: batch cond + uncond into single B=2 forward pass
preds = model(
[latents, latents],
t=mx.array([timestep_val, timestep_val]),
context=context_cfg,
seq_len=seq_len,
cross_kv_caches=kv,
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
# Classifier-free guidance + scheduler step
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = scheduler.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
mx.eval(latents)
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
# Free transformer models and text embeddings
if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
else:
del single_model, cross_kv
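# ref_model and the context_* slices also hold references, so drop them too
del model, kv, ref_model, context, context_null
del context_emb, context_cond, context_uncond, context_cfg
gc.collect()
mx.clear_cache()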
# Load VAE and decode
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
t4 = time.time()
vae_path = model_dir / "vae.safetensors"
vae = load_vae_decoder(vae_path, config)
is_wan22_vae = config.vae_z_dim == 48
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None] # [1, T, H, W, C]
z = denormalize_latents(z)
video = vae(z) # [1, T', H', W', 3]
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else:
video = vae.decode(latents[None]) # [1, 3, T, H, W]
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [3, T, H, W]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
save_video(video, output_path, fps=config.sample_fps)
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""Save video frames to MP4.
Args:
frames: Video frames [T, H, W, 3] uint8
output_path: Output file path
fps: Frames per second
"""
try:
import imageio
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
for frame in frames:
writer.append_data(frame)
writer.close()
except ImportError:
try:
import cv2
h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"avc1")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
for frame in frames:
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
writer.release()
except Exception:  # covers ImportError as well as encoder failures
# Last resort: save as individual PNGs
from PIL import Image
out_dir = Path(output_path).parent / Path(output_path).stem
out_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames):
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
parser.add_argument("--width", type=int, default=1280, help="Video width")
parser.add_argument("--height", type=int, default=720, help="Video height")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
args = parser.parse_args()
# Parse guide scale
guide_scale = None
if args.guide_scale is not None:
parts = [float(x) for x in args.guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
generate_video(
model_dir=args.model_dir,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
width=args.width,
height=args.height,
num_frames=args.num_frames,
steps=args.steps,
guide_scale=guide_scale,
shift=args.shift,
seed=args.seed,
output_path=args.output_path,
)
if __name__ == "__main__":
main()
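
The latent shape and transformer sequence length computed in `generate_video` follow directly from the VAE stride and the patch size. Below is a small worked sketch of that arithmetic in plain Python, assuming the default 1280x720, 81-frame request. `latent_geometry` is a hypothetical helper name, not part of the package.

```python
import math


def latent_geometry(width, height, num_frames, vae_stride, patch_size, z_dim):
    """Reproduce the latent-shape and sequence-length arithmetic of generate_video."""
    t_latent = (num_frames - 1) // vae_stride[0] + 1
    h_latent = height // vae_stride[1]
    w_latent = width // vae_stride[2]
    seq_len = math.ceil(
        (h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
    )
    return (z_dim, t_latent, h_latent, w_latent), seq_len


# Stride (4, 8, 8) with z_dim=16, as in the WanModelConfig defaults.
shape_a, seq_a = latent_geometry(1280, 720, 81, (4, 8, 8), (1, 2, 2), 16)
# Stride (4, 16, 16) with z_dim=48, the values used for the z_dim=48 VAE.
shape_b, seq_b = latent_geometry(1280, 720, 81, (4, 16, 16), (1, 2, 2), 48)

print(shape_a, seq_a)  # (16, 21, 90, 160) 75600
print(shape_b, seq_b)  # (48, 21, 45, 80) 18900
```

For the same output resolution, doubling the spatial stride cuts the sequence length by a factor of four, which is where most of the attention cost goes.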


@@ -1,2 +1,3 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig


@@ -0,0 +1,2 @@
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel


@@ -0,0 +1,201 @@
import mlx.core as mx
import mlx.nn as nn
from .rope import rope_apply
class WanRMSNorm(nn.Module):
"""RMS normalization with learnable scale."""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class WanLayerNorm(nn.Module):
"""LayerNorm computed in float32, with optional affine."""
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if elementwise_affine:
self.weight = mx.ones((dim,))
self.bias = mx.zeros((dim,))
def __call__(self, x: mx.array) -> mx.array:
if self.elementwise_affine:
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
else:
return mx.fast.layer_norm(x, None, None, self.eps)
class WanSelfAttention(nn.Module):
"""Self-attention with QK normalization and 3-way factorized RoPE."""
def __init__(
self,
dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def __call__(
self,
x: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
) -> mx.array:
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
q = self.q(x)
k = self.k(x)
if self.norm_q is not None:
q = self.norm_q(q)
if self.norm_k is not None:
k = self.norm_k(k)
q = q.reshape(b, s, n, d)
k = k.reshape(b, s, n, d)
v = self.v(x).reshape(b, s, n, d)
# Apply RoPE
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
q = q.transpose(0, 2, 1, 3)
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Build attention mask from seq_lens
max_len = s
mask = None
if any(sl < max_len for sl in seq_lens):
mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
for i, sl in enumerate(seq_lens):
mask[i, :, :, sl:] = -1e9
# Use memory-efficient scaled dot-product attention
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]; mask=None means no masking
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out)
class WanCrossAttention(nn.Module):
"""Cross-attention: Q from hidden states, K/V from text context."""
def __init__(
self,
dim: int,
num_heads: int,
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def prepare_kv(self, context: mx.array) -> tuple:
"""Pre-compute K and V projections for caching.
Args:
context: [B, L_ctx, dim]
Returns:
(k, v) each [B, N, L_ctx, D] ready for attention
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
k = self.k(context)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
return k, v
def __call__(
self,
x: mx.array,
context: mx.array,
context_lens: list | None = None,
kv_cache: tuple | None = None,
) -> mx.array:
b = x.shape[0]
n, d = self.num_heads, self.head_dim
q = self.q(x)
if self.norm_q is not None:
q = self.norm_q(q)
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
if kv_cache is not None:
k, v = kv_cache
else:
k = self.k(context)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
# Optional context masking
mask = None
if context_lens is not None:
ctx_len = k.shape[2]
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
for i, cl in enumerate(context_lens):
mask[i, :, :, cl:] = -1e9
if mask is not None:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out)
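Both attention classes above hide padded positions by adding a large negative constant to the attention scores before the softmax. The sketch below illustrates that idea in plain Python rather than MLX; the helper names `additive_padding_mask` and `masked_softmax` are hypothetical and are not part of this commit.

```python
import math

NEG_INF = -1e9  # same large negative constant the attention code writes into padded slots


def additive_padding_mask(valid_lens, max_len):
    """One row per batch element: 0.0 for valid keys, NEG_INF for padded keys."""
    return [
        [0.0 if j < n else NEG_INF for j in range(max_len)]
        for n in valid_lens
    ]


def masked_softmax(scores, mask_row):
    """Softmax over one row of attention scores after adding the mask."""
    shifted = [s + m for s, m in zip(scores, mask_row)]
    peak = max(shifted)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


# Two sequences padded to length 4; the second has only 2 valid tokens.
mask = additive_padding_mask(valid_lens=[4, 2], max_len=4)
weights = masked_softmax([0.5, 1.0, 2.0, 3.0], mask[1])
print(weights)  # the two padded positions end up with zero attention weight
```

Because the mask is added rather than multiplied, it can be passed directly as the `mask` argument of a scaled dot-product attention call, which is what the code above does.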


@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Tuple, Union
from mlx_video.models.ltx.config import BaseModelConfig
@dataclass
class WanModelConfig(BaseModelConfig):
"""Configuration for Wan T2V models (supports both 2.1 and 2.2)."""
model_type: str = "t2v"
model_version: str = "2.2"
patch_size: Tuple[int, int, int] = (1, 2, 2)
text_len: int = 512
in_dim: int = 16
dim: int = 5120
ffn_dim: int = 13824
freq_dim: int = 256
text_dim: int = 4096
out_dim: int = 16
num_heads: int = 40
num_layers: int = 40
window_size: Tuple[int, int] = (-1, -1)
qk_norm: bool = True
cross_attn_norm: bool = True
eps: float = 1e-6
# VAE
vae_stride: Tuple[int, int, int] = (4, 8, 8)
vae_z_dim: int = 16
# Inference
dual_model: bool = True
boundary: float = 0.875
sample_shift: float = 12.0
sample_steps: int = 40
sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
num_train_timesteps: int = 1000
sample_fps: int = 16
frame_num: int = 81
# T5
t5_vocab_size: int = 256384
t5_dim: int = 4096
t5_dim_attn: int = 4096
t5_dim_ffn: int = 10240
t5_num_heads: int = 64
t5_num_layers: int = 24
t5_num_buckets: int = 32
@property
def head_dim(self) -> int:
return self.dim // self.num_heads
@classmethod
def wan21_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
return cls(
model_version="2.1",
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
)
@classmethod
def wan21_t2v_1_3b(cls) -> "WanModelConfig":
"""Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
return cls(
model_version="2.1",
dim=1536,
ffn_dim=8960,
num_heads=12,
num_layers=30,
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
)
@classmethod
def wan22_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
return cls()
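The three presets above differ in `dim` and `num_heads`, yet the derived `head_dim` is the same for all of them. A minimal check, using a hypothetical stand-in dataclass (`TinyWanConfig`) that copies only the fields needed here so it runs without the package installed:

```python
from dataclasses import dataclass


@dataclass
class TinyWanConfig:
    """Hypothetical stand-in for WanModelConfig with only the relevant fields."""

    dim: int = 5120
    num_heads: int = 40
    num_layers: int = 40
    dual_model: bool = True

    @property
    def head_dim(self) -> int:
        return self.dim // self.num_heads


# Values copied from the classmethods above.
presets = {
    "wan21_t2v_1_3b": TinyWanConfig(dim=1536, num_heads=12, num_layers=30, dual_model=False),
    "wan21_t2v_14b": TinyWanConfig(dual_model=False),
    "wan22_t2v_14b": TinyWanConfig(),
}

for name, cfg in presets.items():
    # dim must divide evenly across heads, as the attention modules assert.
    assert cfg.dim % cfg.num_heads == 0
    print(name, cfg.head_dim)
```

A shared head dimension means the RoPE frequency split computed in the model is identical across the presets listed here.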


@@ -0,0 +1,307 @@
import math
import mlx.core as mx
import mlx.nn as nn
from .attention import WanLayerNorm
from .config import WanModelConfig
from .rope import rope_params
from .transformer import WanAttentionBlock
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
"""Compute sinusoidal positional embeddings.
Args:
dim: Embedding dimension (must be even).
position: 1D tensor of positions.
Returns:
Embeddings of shape [len(position), dim].
"""
assert dim % 2 == 0
half = dim // 2
pos = position.astype(mx.float32)
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
sinusoid = pos[:, None] * inv_freq[None, :]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
class Head(nn.Module):
"""Output projection head with learned modulation."""
def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
super().__init__()
self.out_dim = out_dim
self.patch_size = patch_size
proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim)
self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
"""
Args:
x: [B, L, dim]
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
e_f32 = e.astype(mx.float32)
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
x_norm = self.norm(x).astype(mx.float32)
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
return self.head(x_mod.astype(x.dtype))
class WanModel(nn.Module):
"""Wan diffusion backbone for text-to-video generation (Wan2.1 and Wan2.2)."""
def __init__(self, config: WanModelConfig):
super().__init__()
self.config = config
dim = config.dim
self.dim = dim
self.num_heads = config.num_heads
self.out_dim = config.out_dim
self.patch_size = config.patch_size
self.text_len = config.text_len
self.freq_dim = config.freq_dim
# Patch embedding: Conv3d implemented as a reshaped linear
# For kernel (1,2,2) and stride (1,2,2): reshape input then linear
patch_dim = config.in_dim * math.prod(config.patch_size)
self.patch_embedding_proj = nn.Linear(patch_dim, dim)
self._patch_size = config.patch_size
# Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="precise")
self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP
self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
self.time_embedding_act = nn.SiLU()
self.time_embedding_1 = nn.Linear(dim, dim)
# Time projection for modulation (6x dim)
self.time_projection_act = nn.SiLU()
self.time_projection = nn.Linear(dim, dim * 6)
# Transformer blocks
self.blocks = [
WanAttentionBlock(
dim=dim,
ffn_dim=config.ffn_dim,
num_heads=config.num_heads,
window_size=config.window_size,
qk_norm=config.qk_norm,
cross_attn_norm=config.cross_attn_norm,
eps=config.eps,
)
for _ in range(config.num_layers)
]
# Output head
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
# Precompute RoPE frequencies
d = dim // config.num_heads
d_t = d - 4 * (d // 6)
d_h = 2 * (d // 6)
d_w = 2 * (d // 6)
# Each rope_params returns [1024, d_x//2, 2]
freqs_t = rope_params(1024, d_t)
freqs_h = rope_params(1024, d_h)
freqs_w = rope_params(1024, d_w)
# Concatenate along the frequency dimension: [1024, d//2, 2]
self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1)
def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings.
Args:
x: Video latent [C, F, H, W]
Returns:
(patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
"""
c, f, h, w = x.shape
pt, ph, pw = self._patch_size
f_out = f // pt
h_out = h // ph
w_out = w // pw
# Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
# Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw]
x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw]
# Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out)
def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
"""Reconstruct video from patch embeddings.
Args:
x: [B, L, out_dim * prod(patch_size)]
grid_sizes: List of (F', H', W') per batch element
Returns:
List of tensors [C, F, H, W]
"""
c = self.out_dim
pt, ph, pw = self.patch_size
out = []
for i, (f, h, w) in enumerate(grid_sizes):
seq_len = f * h * w
u = x[i, :seq_len] # [L, out_dim * pt * ph * pw]
u = u.reshape(f, h, w, pt, ph, pw, c)
# Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw]
u = u.reshape(c, f * pt, h * ph, w * pw)
out.append(u)
return out
def embed_text(self, context: list) -> mx.array:
"""Precompute text embeddings (call once, reuse across steps).
Args:
context: List of text embeddings [L_text, text_dim]
Returns:
Embedded context [B, text_len, dim] in model dtype
"""
model_dtype = self.patch_embedding_proj.weight.dtype
context_padded = []
for ctx in context:
pad_len = self.text_len - ctx.shape[0]
if pad_len > 0:
ctx = mx.concatenate(
[ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
axis=0,
)
context_padded.append(ctx)
context_batch = mx.stack(context_padded) # [B, text_len, text_dim]
context_batch = self.text_embedding_1(
self.text_embedding_act(self.text_embedding_0(context_batch))
)
return context_batch.astype(model_dtype)
def prepare_cross_kv(self, context: mx.array) -> list:
"""Pre-compute cross-attention K/V for all blocks.
Call once before the diffusion loop to cache K/V projections,
eliminating redundant computation at each denoising step.
Args:
context: Pre-embedded text [B, text_len, dim]
Returns:
List of (k, v) tuples, one per block
"""
kv_caches = []
for block in self.blocks:
kv_caches.append(block.cross_attn.prepare_kv(context))
return kv_caches
def __call__(
self,
x_list: list,
t: mx.array,
context: list | mx.array,
seq_len: int,
cross_kv_caches: list | None = None,
) -> list:
"""Forward pass.
Args:
x_list: List of video latent tensors [C, F, H, W]
t: Timestep tensor [B]
context: List of raw text embeddings, OR pre-embedded tensor
from embed_text() [B, text_len, dim]
seq_len: Maximum sequence length for padding
cross_kv_caches: Optional list of (k, v) tuples from
prepare_cross_kv(), one per block.
Returns:
List of denoised tensors [C, F, H, W]
"""
# Patchify each video
patches = []
grid_sizes = []
seq_lens_list = []
for vid in x_list:
p, gs = self._patchify(vid) # [1, L, dim]
patches.append(p)
grid_sizes.append(gs)
seq_lens_list.append(p.shape[1])
# Pad and batch
batch_size = len(patches)
x = mx.concatenate(
[
mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
if p.shape[1] < seq_len
else p
for p in patches
],
axis=0,
) # [B, seq_len, dim]
# Time embedding: compute once per sample, then broadcast to all tokens
if t.ndim == 0:
t = t[None]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
model_dtype = self.patch_embedding_proj.weight.dtype
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
e = e.astype(model_dtype)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):
# Pre-embedded: expand to batch size if needed
context_batch = context
if context_batch.shape[0] == 1 and batch_size > 1:
context_batch = mx.broadcast_to(
context_batch, (batch_size,) + context_batch.shape[1:]
)
else:
context_batch = self.embed_text(context)
# Run transformer blocks
kwargs = dict(
e=e0,
seq_lens=seq_lens_list,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context_batch,
context_lens=None,
)
for i, block in enumerate(self.blocks):
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
x = block(x, cross_kv_cache=kv, **kwargs)
# Output head
x = self.head(x, e)
# Unpatchify
outputs = self.unpatchify(x, grid_sizes)
return [u.astype(mx.float32) for u in outputs]
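The shape bookkeeping in `_patchify` and `unpatchify` can be checked without MLX. The sketch below reproduces only the arithmetic; `patch_grid` is a hypothetical helper, and the latent shape `(16, 21, 60, 104)` is an illustrative assumption, not a value taken from this commit. Unlike `_patchify`, which floor-divides silently, the sketch rejects shapes that are not divisible by the patch size.

```python
import math


def patch_grid(latent_shape, patch_size=(1, 2, 2)):
    """Return (grid, seq_len, patch_dim) for a latent of shape [C, F, H, W]."""
    c, f, h, w = latent_shape
    pt, ph, pw = patch_size
    if f % pt or h % ph or w % pw:
        raise ValueError("latent dimensions must be divisible by the patch size")
    grid = (f // pt, h // ph, w // pw)    # (F', H', W')
    seq_len = math.prod(grid)             # number of tokens L
    patch_dim = c * math.prod(patch_size)  # input width of the patch projection
    return grid, seq_len, patch_dim


grid, seq_len, patch_dim = patch_grid((16, 21, 60, 104))
print(grid, seq_len, patch_dim)
```

`seq_len` here is the unpadded token count for one video; the `seq_len` argument of `WanModel.__call__` must be at least this large for every video in the batch.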


@@ -0,0 +1,100 @@
import mlx.core as mx
import numpy as np
def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
"""Precompute RoPE rotation angles as (cos, sin) pairs.
Returns:
Float32 tensor of shape [max_seq_len, dim // 2, 2], where [..., 0] holds cos and [..., 1] holds sin.
"""
assert dim % 2 == 0
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
1.0
/ np.power(
theta,
np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
cos_freqs = np.cos(freqs).astype(np.float32)
sin_freqs = np.sin(freqs).astype(np.float32)
return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))
def rope_apply(
x: mx.array,
grid_sizes: list,
freqs: mx.array,
) -> mx.array:
"""Apply 3-way factorized RoPE to Q or K tensor.
Args:
x: Shape [B, L, num_heads, head_dim]
grid_sizes: List of (F, H, W) tuples per batch element
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
"""
b, s, n, d = x.shape
half_d = d // 2
# Cast freqs to input dtype to prevent float32 promotion cascade
if freqs.dtype != x.dtype:
freqs = freqs.astype(x.dtype)
# Split frequency dimensions: temporal gets more capacity
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
# Split freqs along dim axis
freqs_t = freqs[:, :d_t] # [1024, d_t, 2]
freqs_h = freqs[:, d_t : d_t + d_h] # [1024, d_h, 2]
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] # [1024, d_w, 2]
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
seq_len = f * h * w
# Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)
# Build per-position frequencies by expanding along grid dims
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
ft = mx.broadcast_to(
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
)
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
fh = mx.broadcast_to(
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
)
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
fw = mx.broadcast_to(
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
)
# Concatenate: [f*h*w, half_d, 2]
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
# Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
cos_f = freqs_i[..., 0] # [seq_len, 1, half_d]
sin_f = freqs_i[..., 1] # [seq_len, 1, half_d]
x_real = x_i[..., 0] # [seq_len, n, half_d]
x_imag = x_i[..., 1] # [seq_len, n, half_d]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
# Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)
# Handle padding: keep non-rotated tokens after seq_len
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)
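The split of rotary pairs across the temporal, height, and width axes, and the rotation itself, can be sketched in plain Python. `rope_split` and `rotate_pair` are hypothetical helper names; they follow the same formulas as `rope_apply` and `rope_params` above but operate on single numbers.

```python
import math


def rope_split(head_dim):
    """Number of (real, imag) pairs assigned to the time, height, and width axes."""
    half = head_dim // 2
    d_h = half // 3
    d_w = half // 3
    d_t = half - d_h - d_w  # the temporal axis receives the remainder
    return d_t, d_h, d_w


def rotate_pair(real, imag, position, pair_index, axis_dim, theta=10000.0):
    """Rotate one (real, imag) pair by the RoPE angle for the given position."""
    angle = position / theta ** (2 * pair_index / axis_dim)
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    # (a + bi) * (cos + i sin) = (a cos - b sin) + (a sin + b cos) i
    return real * cos_a - imag * sin_a, real * sin_a + imag * cos_a


d_t, d_h, d_w = rope_split(128)
print(d_t, d_h, d_w)

# Rotating a pair changes its direction but not its length.
rotated = rotate_pair(3.0, 4.0, position=7, pair_index=2, axis_dim=2 * d_t)
print(math.hypot(*rotated))
```

For a head dimension of 128 the pair counts add up to 64, matching the concatenated frequency table that the model builds from three `rope_params` calls.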


@@ -0,0 +1,76 @@
"""Flow matching scheduler for Wan2.2 inference."""
import numpy as np
import mlx.core as mx
class FlowMatchEulerScheduler:
"""Simple Euler scheduler for flow matching diffusion.
Implements the flow matching formulation where the model predicts
velocity (flow) and we use Euler steps to denoise.
"""
def __init__(self, num_train_timesteps: int = 1000):
self.num_train_timesteps = num_train_timesteps
self.timesteps = None
self.sigmas = None
self._step_index = 0
def set_timesteps(self, num_steps: int, shift: float = 1.0):
"""Compute sigma schedule with shift.
Args:
num_steps: Number of inference steps.
shift: Noise schedule shift factor.
"""
# Linear sigma grid from 1 - 1/N down to 0 (alphas run from 1/N up to 1, sigma = 1 - alpha)
sigmas = np.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps)[::-1]
sigmas = 1.0 - sigmas
# Select evenly spaced subset
indices = np.linspace(0, len(sigmas) - 1, num_steps + 1).astype(int)
sigmas = sigmas[indices[:-1]]
# Apply shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
# Convert to timesteps
timesteps = sigmas * self.num_train_timesteps
self.timesteps = mx.array(timesteps.astype(np.float32))
# Append terminal sigma=0
sigmas = np.append(sigmas, 0.0)
self.sigmas = mx.array(sigmas.astype(np.float32))
self._step_index = 0
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""Euler step for flow matching.
In flow matching, model predicts velocity v, and:
x_{t-1} = sample + (sigma_{t-1} - sigma_t) * v
Args:
model_output: Predicted velocity [B, C, T, H, W]
timestep: Current timestep (unused, step index is tracked internally)
sample: Current noisy sample [B, C, T, H, W]
Returns:
Updated sample
"""
# Use Python floats to avoid creating mx.array scalars that
# could trigger type promotion (per fast-mlx guide)
dt = float(self.sigmas[self._step_index + 1].item()) - float(self.sigmas[self._step_index].item())
x_next = sample + dt * model_output
self._step_index += 1
return x_next
def reset(self):
"""Reset step counter for new generation."""
self._step_index = 0
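The schedule and the Euler update above can be reproduced with plain Python floats. This is a sketch under stated assumptions: `shifted_sigmas` and `euler_denoise` are hypothetical names, the index selection approximates `np.linspace(...).astype(int)`, and the toy problem assumes the common flow matching convention `x_sigma = (1 - sigma) * data + sigma * noise`, whose velocity is `noise - data`.

```python
def shifted_sigmas(num_steps, shift=1.0, num_train_timesteps=1000):
    """Sigma schedule with the shift applied, ending in a terminal 0.0."""
    n = num_train_timesteps
    base = [1.0 - (k + 1) / n for k in range(n)]  # 1 - 1/N down to 0
    # Evenly spaced subset; the final index is dropped, as in set_timesteps.
    picked = [base[int(i * (n - 1) / num_steps)] for i in range(num_steps)]
    shifted = [shift * s / (1.0 + (shift - 1.0) * s) for s in picked]
    return shifted + [0.0]


def euler_denoise(sample, velocity_fn, sigmas):
    """Apply x <- x + (sigma_next - sigma_cur) * v for each pair of sigmas."""
    x = sample
    for cur, nxt in zip(sigmas[:-1], sigmas[1:]):
        x = x + (nxt - cur) * velocity_fn(x, cur)
    return x


sigmas = shifted_sigmas(4, shift=5.0)
print(sigmas)

# Toy 1-D problem with an exact, constant velocity field.
data, noise = 2.0, -1.0
start = (1.0 - sigmas[0]) * data + sigmas[0] * noise
result = euler_denoise(start, lambda x, sigma: noise - data, sigmas)
print(result)  # close to `data`, since the velocity is exact
```

A larger `shift` keeps more of the steps near sigma = 1, so more of the step budget is spent at high noise levels.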


@@ -0,0 +1,234 @@
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
import math
import mlx.core as mx
import mlx.nn as nn
class T5LayerNorm(nn.Module):
"""RMS-based layer normalization (T5 style)."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class T5RelativeEmbedding(nn.Module):
"""T5-style relative position bias with bucketing."""
def __init__(
self,
num_buckets: int,
num_heads: int,
bidirectional: bool = True,
max_dist: int = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
self.embedding = nn.Embedding(num_buckets, num_heads)
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array:
positions_k = mx.arange(lk)[None, :] # [1, lk]
positions_q = mx.arange(lq)[:, None] # [lq, 1]
rel_pos = positions_k - positions_q # [lq, lk]
buckets = self._relative_position_bucket(rel_pos)
embeds = self.embedding(buckets) # [lq, lk, num_heads]
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
return embeds
class T5Attention(nn.Module):
"""T5-style multi-head attention (no scaling)."""
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert dim_attn % num_heads == 0
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
def __call__(
self,
x: mx.array,
context: mx.array | None = None,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
context = x if context is None else context
b, n, c = x.shape[0], self.num_heads, self.head_dim
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
v = self.v(context).reshape(b, -1, n, c)
# T5 does not use scaling
# attn = einsum('binc,bjnc->bnij', q, k)
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Combine position bias and attention mask for SDPA
attn_mask = None
if pos_bias is not None:
attn_mask = pos_bias.astype(q.dtype)
if mask is not None:
if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
elif mask.ndim == 3:
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask
# T5 uses no scaling (scale=1.0)
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=1.0, mask=attn_mask
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
return self.o(out)
class T5FeedForward(nn.Module):
"""Gated feed-forward: fc2(fc1(x) * GELU(gate_proj(x)))."""
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = nn.GELU(approx="precise")
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
class T5SelfAttentionBlock(nn.Module):
"""T5 encoder block: self-attention + FFN."""
def __init__(
self,
dim: int,
dim_attn: int,
dim_ffn: int,
num_heads: int,
num_buckets: int,
shared_pos: bool = True,
):
super().__init__()
self.shared_pos = shared_pos
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = x + self.ffn(self.norm2(x))
return x
class T5Encoder(nn.Module):
"""T5 Encoder (UMT5-XXL configuration)."""
def __init__(
self,
vocab_size: int = 256384,
dim: int = 4096,
dim_attn: int = 4096,
dim_ffn: int = 10240,
num_heads: int = 64,
num_layers: int = 24,
num_buckets: int = 32,
shared_pos: bool = False,
):
super().__init__()
self.dim = dim
self.token_embedding = nn.Embedding(vocab_size, dim)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
if shared_pos
else None
)
self.blocks = [
T5SelfAttentionBlock(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
)
for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
"""
Args:
ids: Token IDs [B, L]
mask: Attention mask [B, L]
Returns:
Hidden states [B, L, dim]
"""
x = self.token_embedding(ids)
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
for block in self.blocks:
x = block(x, mask=mask, pos_bias=e)
x = self.norm(x)
return x
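The bucketing in `T5RelativeEmbedding._relative_position_bucket` is easier to follow on single integers. The sketch below mirrors the bidirectional branch in plain Python; `relative_bucket` is a hypothetical helper, and float32 rounding in the MLX version may differ at exact bucket boundaries.

```python
import math


def relative_bucket(rel_pos, num_buckets=32, max_dist=128):
    """Bidirectional T5 bucket index for one relative position (key - query)."""
    half = num_buckets // 2
    offset = half if rel_pos > 0 else 0  # keys after the query use the upper half
    distance = abs(rel_pos)
    max_exact = half // 2
    if distance < max_exact:
        return offset + distance  # small distances get one bucket each
    # Larger distances share logarithmically spaced buckets up to max_dist.
    log_ratio = math.log(distance / max_exact) / math.log(max_dist / max_exact)
    bucket = max_exact + int(log_ratio * (half - max_exact))
    return offset + min(bucket, half - 1)


for rel in (0, 3, -3, -20, 200, -200):
    print(rel, relative_bucket(rel))
```

Nearby tokens therefore get distinct biases, while everything beyond `max_dist` collapses into the last bucket of its half.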


@@ -0,0 +1,89 @@
import mlx.core as mx
import mlx.nn as nn
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
class WanAttentionBlock(nn.Module):
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
cross_attn_norm: bool = False,
eps: float = 1e-6,
):
super().__init__()
# Self-attention
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
# Cross-attention (with optional norm on context)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else None
)
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
# Feed-forward
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate
self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5)
def __call__(
self,
x: mx.array,
e: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
context: mx.array,
context_lens: list | None = None,
cross_kv_cache: tuple | None = None,
) -> mx.array:
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
e0 = mod[:, :, 0, :] # shift for self-attn
e1 = mod[:, :, 1, :] # scale for self-attn
e2 = mod[:, :, 2, :] # gate for self-attn
e3 = mod[:, :, 3, :] # shift for ffn
e4 = mod[:, :, 4, :] # scale for ffn
e5 = mod[:, :, 5, :] # gate for ffn
# Self-attention with modulation
x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
x = x + y * e2
# Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation
x_mod = self.norm2(x) * (1 + e4) + e3
y = self.ffn(x_mod)
x = x + y * e5
return x
class WanFFN(nn.Module):
"""Feed-forward network (not gated): fc1 -> GELU(tanh) -> fc2."""
def __init__(self, dim: int, ffn_dim: int):
super().__init__()
self.fc1 = nn.Linear(dim, ffn_dim)
self.act = nn.GELU(approx="precise")
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
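The block above applies the same pattern twice: normalize, scale and shift, run a sublayer, then add the gated result back to the residual stream. A plain Python sketch on a single token vector follows; `layer_norm` and `modulated_residual` are hypothetical helpers, and the normalization is assumed to be a layer norm without learned affine parameters, since `WanLayerNorm` is defined outside this section.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize a vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulated_residual(x, sublayer, shift, scale, gate):
    """x + gate * sublayer(norm(x) * (1 + scale) + shift), element-wise."""
    normed = layer_norm(x)
    x_mod = [n * (1.0 + s) + b for n, s, b in zip(normed, scale, shift)]
    y = sublayer(x_mod)
    return [xi + yi * g for xi, yi, g in zip(x, y, gate)]


x = [1.0, 2.0, 3.0, 6.0]
double = lambda v: [2.0 * t for t in v]  # stand-in for attention or the FFN

# With a zero gate the sublayer is switched off and the block is the identity.
closed = modulated_residual(x, double, shift=[0.1] * 4, scale=[0.5] * 4, gate=[0.0] * 4)
print(closed)

# With a unit gate and neutral modulation the sublayer output is added in full.
opened = modulated_residual(x, double, shift=[0.0] * 4, scale=[0.0] * 4, gate=[1.0] * 4)
print(opened)
```

In the block above, the shift, scale, and gate vectors come from the timestep embedding plus the learned `modulation` parameter, so the strength of each sublayer can vary with the noise level.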

315
mlx_video/models/wan/vae.py Normal file

@@ -0,0 +1,315 @@
"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8).
Module structure mirrors original PyTorch checkpoint key hierarchy
so weights load directly without key sanitization.
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization statistics for z_dim=16
VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
]
VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
]
class CausalConv3d(nn.Module):
"""3D convolution with causal temporal padding."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple,
stride: int | tuple = 1,
padding: int | tuple = 0,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
self.kernel_size = kernel_size
self.stride = stride
self._causal_pad_t = 2 * padding[0]
self._pad_h = padding[1]
self._pad_w = padding[2]
# MLX Conv3d: weight shape [O, D, H, W, I]
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W] (channel-first)"""
b, c, t, h, w = x.shape
if self._causal_pad_t > 0:
pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype)
x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
out = self._conv3d(x)
return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W']
def _conv3d(self, x: mx.array) -> mx.array:
"""3D conv via sliding window + 2D conv per time step.
x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
"""
b, t, h, w, c_in = x.shape
kt, kh, kw = self.kernel_size
st, sh, sw = self.stride
t_out = (t - kt) // st + 1
# Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
self.weight.shape[0], kh, kw, kt * c_in
)
outputs = []
for t_i in range(t_out):
t_start = t_i * st
window = x[:, t_start : t_start + kt]
window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
outputs.append(out_2d)
return mx.stack(outputs, axis=1)
class RMS_norm(nn.Module):
"""Channel-first L2 normalization matching original Wan VAE.
Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
super().__init__()
self.channel_first = channel_first
self.scale = dim**0.5
if channel_first:
broadcastable = (1, 1) if images else (1, 1, 1)
self.gamma = mx.ones((dim, *broadcastable))
else:
self.gamma = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
norm_dim = 1 if self.channel_first else -1
# L2 normalize along channel dim (matches F.normalize)
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
return (x / norm) * self.scale * self.gamma
class ResidualBlock(nn.Module):
"""Residual block with causal 3D convolutions.
Uses `residual` list with None gaps to match original PyTorch
nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
[4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
"""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.residual = [
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x: mx.array) -> mx.array:
h = x if self.shortcut is None else self.shortcut(x)
x = nn.silu(self.residual[0](x))
x = self.residual[2](x)
x = nn.silu(self.residual[3](x))
x = self.residual[6](x)
return x + h
class AttentionBlock(nn.Module):
"""Single-head spatial self-attention."""
def __init__(self, dim: int):
super().__init__()
self.norm = RMS_norm(dim, images=True)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W]"""
identity = x
b, c, t, h, w = x.shape
# [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.norm(x)
x = x.transpose(0, 2, 3, 1) # [BT, H, W, C]
qkv = self.to_qkv(x) # [BT, H, W, 3C]
qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q[:, None, :, :] # [BT, 1, HW, C]
k = k[:, None, :, :]
v = v[:, None, :, :]
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C]
out = self.proj(out) # [BT, H, W, C]
out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W]
return out + identity
class Resample(nn.Module):
"""Upsample block matching original Wan VAE structure.
Uses `resample` list with [None, Conv2d] to match original
nn.Sequential(Upsample, Conv2d) where index 1 has the conv params.
"""
def __init__(self, dim: int, mode: str):
super().__init__()
assert mode in ("upsample2d", "upsample3d")
self.mode = mode
self.dim = dim
# resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W]"""
b, c, t, h, w = x.shape
if self.mode == "upsample3d":
# Temporal upsample via learned conv
x_t = self.time_conv(x) # [B, 2C, T, H, W]
x_t = x_t.reshape(b, 2, c, t, h, w)
# Interleave along time: [B, C, 2T, H, W]
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
t = t * 2
# Per-frame spatial upsample: nearest 2x + Conv2d
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.repeat(x, 2, axis=1)
x = mx.repeat(x, 2, axis=2)
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
c_out = x.shape[-1]
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
class Decoder3d(nn.Module):
"""3D VAE Decoder matching Wan2.1 architecture.
Uses flat `middle` and `upsamples` lists to match original
PyTorch nn.Sequential weight key hierarchy.
"""
def __init__(
self,
dim: int = 96,
z_dim: int = 16,
dim_mult: list | None = None,
num_res_blocks: int = 2,
temporal_upsample: list | None = None,
):
super().__init__()
if dim_mult is None:
dim_mult = [1, 2, 4, 4]
if temporal_upsample is None:
temporal_upsample = [True, True, False]
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# Middle: [ResBlock, AttentionBlock, ResBlock]
self.middle = [
ResidualBlock(dims[0], dims[0]),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0]),
]
# Flat upsample list matching original nn.Sequential indexing
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
if i in (1, 2, 3):
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = upsamples
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
]
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
x = self.conv1(x)
for layer in self.middle:
x = layer(x)
for layer in self.upsamples:
x = layer(x)
x = nn.silu(self.head[0](x))
x = self.head[2](x)
return x
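The channel bookkeeping in the Wan2.1 decoder constructor is easy to misread, so here is a minimal pure-Python sketch of it. It assumes the default arguments shown in this diff and that each Wan2.1 `Resample` halves the channel count (as its `Conv2d(dim, dim // 2, ...)` suggests). `decoder_channel_plan` is a hypothetical helper for illustration, not part of the package.

```python
def decoder_channel_plan(dim=96, dim_mult=(1, 2, 4, 4), num_res_blocks=2):
    """Return the stage widths and the (in_dim, out_dim) of every ResidualBlock."""
    mults = [dim_mult[-1]] + list(reversed(dim_mult))
    dims = [dim * m for m in mults]
    plan = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        if i in (1, 2, 3):
            # The Resample that closed the previous stage halved the channels
            in_dim //= 2
        for _ in range(num_res_blocks + 1):
            plan.append((in_dim, out_dim))
            in_dim = out_dim
    return dims, plan


dims, plan = decoder_channel_plan()
for pair in plan:
    print(pair)
```

With the defaults, the first block of the second stage takes 192 input channels rather than 384, which is why the constructor adjusts `in_dim` before building each later stage.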

class WanVAE(nn.Module):
    """Wan2.1 VAE wrapper with per-channel normalization."""

    def __init__(self, z_dim: int = 16):
        super().__init__()
        self.z_dim = z_dim
        self.mean = mx.array(VAE_MEAN)
        self.std = mx.array(VAE_STD)
        self.inv_std = 1.0 / self.std
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim=96, z_dim=z_dim)

    def decode(self, z: mx.array) -> mx.array:
        """Decode latent to video.

        Args:
            z: Normalized latent [B, z_dim, T, H, W]

        Returns:
            Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
        """
        mean = self.mean.reshape(1, -1, 1, 1, 1)
        inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
        z = z / inv_std + mean
        x = self.conv2(z)
        out = self.decoder(x)
        return mx.clip(out, -1, 1)


@@ -0,0 +1,584 @@
"""Wan2.2 VAE Decoder (compression 4×16×16, z_dim=48).
Architecture differs from Wan2.1 VAE: uses RMS_norm, DupUp3D shortcuts,
spatial patchify (2×2), and different temporal upsampling pattern.
Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])
VAE22_STD = mx.array([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])

class CausalConv3d(nn.Module):
    """3D causal convolution. Input/output: [B, T, H, W, C] (channels-last).

    Decomposes the 3D conv into per-frame 2D convolutions to avoid
    excessive memory usage from MLX's conv3d implementation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        self.kernel_size = kernel_size
        self.stride = stride
        # Causal: all temporal padding goes on the left (past) side
        self._causal_pad_t = 2 * padding[0]
        self._pad_h = padding[1]
        self._pad_w = padding[2]
        # Weight: [O, D, H, W, I] for MLX
        self.weight = mx.zeros((
            out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
        ))
        self.bias = mx.zeros((out_channels,))

    def __call__(self, x):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape
        kd, kh, kw = self.kernel_size
        # Pointwise (1x1x1) kernel: apply as a per-frame 1x1 conv2d
        # (stride and padding are not applied on this path)
        if kd == 1 and kh == 1 and kw == 1:
            # Merge batch and time: [B*T, H, W, C]
            x_flat = x.reshape(B * T, H, W, C)
            w2d = self.weight[:, 0, :, :, :]  # [O, kH, kW, I]
            y = mx.conv_general(x_flat, w2d) + self.bias
            return y.reshape(B, T, y.shape[1], y.shape[2], -1)
        # Causal temporal padding (left only)
        if self._causal_pad_t > 0:
            pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
            x = mx.concatenate([pad_t, x], axis=1)
        # Spatial padding
        if self._pad_h > 0 or self._pad_w > 0:
            x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
                           (self._pad_w, self._pad_w), (0, 0)])
        T_padded = x.shape[1]
        H_padded, W_padded = x.shape[2], x.shape[3]
        T_out = (T_padded - kd) // self.stride[0] + 1
        # Decompose 3D conv into sum of 2D convolutions over temporal kernel
        # weight shape: [O, kd, kh, kw, I] → split into kd 2D kernels [O, kh, kw, I]
        outputs = []
        for t in range(T_out):
            t_start = t * self.stride[0]
            # Sum 2D convs for each temporal kernel position
            accum = None
            for d in range(kd):
                frame = x[:, t_start + d]  # [B, H_padded, W_padded, C]
                w2d = self.weight[:, d, :, :, :]  # [O, kh, kw, I]
                conv_out = mx.conv_general(frame, w2d,
                                           stride=(self.stride[1], self.stride[2]))
                accum = conv_out if accum is None else accum + conv_out
            outputs.append(accum + self.bias)
        return mx.stack(outputs, axis=1)  # [B, T_out, H_out, W_out, O]
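To see why summing per-frame 2D convolutions reproduces a causal 3D convolution, the following NumPy sketch compares the decomposition against a direct evaluation and then checks causality. It is an illustration under simplifying assumptions (stride 1, no spatial padding, no bias), it uses NumPy in place of MLX, and all function names are hypothetical.

```python
import numpy as np


def conv2d_valid(x, w):
    """x: [B, H, W, I], w: [O, kh, kw, I] -> [B, Ho, Wo, O]; stride 1, no padding."""
    O, kh, kw, _ = w.shape
    Ho, Wo = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((x.shape[0], Ho, Wo, O))
    for i in range(kh):
        for j in range(kw):
            # Contract the input-channel axis for one kernel tap
            out += x[:, i:i + Ho, j:j + Wo, :] @ w[:, i, j, :].T
    return out


def pad_time_left(x, pad_t):
    """Zero-pad only on the past side of the time axis (axis 1)."""
    zeros = np.zeros((x.shape[0], pad_t) + x.shape[2:])
    return np.concatenate([zeros, x], axis=1)


def causal_conv3d_decomposed(x, w, pad_t):
    """Sum of kd per-frame 2D convolutions, one per temporal kernel tap."""
    kd = w.shape[1]
    x = pad_time_left(x, pad_t)
    frames = []
    for t in range(x.shape[1] - kd + 1):
        frames.append(sum(conv2d_valid(x[:, t + d], w[:, d]) for d in range(kd)))
    return np.stack(frames, axis=1)


def causal_conv3d_direct(x, w, pad_t):
    """Reference: accumulate every (d, i, j) tap over the whole padded volume."""
    O, kd, kh, kw, _ = w.shape
    x = pad_time_left(x, pad_t)
    To, Ho, Wo = x.shape[1] - kd + 1, x.shape[2] - kh + 1, x.shape[3] - kw + 1
    out = np.zeros((x.shape[0], To, Ho, Wo, O))
    for d in range(kd):
        for i in range(kh):
            for j in range(kw):
                out += x[:, d:d + To, i:i + Ho, j:j + Wo, :] @ w[:, d, i, j, :].T
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 5, 5, 2))  # [B, T, H, W, I]
w = rng.standard_normal((3, 3, 3, 3, 2))  # [O, kd, kh, kw, I]
y_dec = causal_conv3d_decomposed(x, w, pad_t=2)  # pad_t = 2 * padding[0]
y_ref = causal_conv3d_direct(x, w, pad_t=2)

# Causality: changing frames 2 and later must not change outputs 0 and 1
x_future = x.copy()
x_future[:, 2:] += 1.0
y_future = causal_conv3d_decomposed(x_future, w, pad_t=2)
```

Because the padding is `kd - 1` frames on the left only, the output keeps the input's frame count and output frame `t` depends on input frames `t - 2` through `t`.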

class RMS_norm(nn.Module):
    """RMS normalization along the channel (last) dimension."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        # Weight stored as (dim,); PyTorch stores (dim, 1, 1, 1) but we squeeze
        self.gamma = mx.ones((dim,))

    def __call__(self, x):
        # x: [..., C] (channels-last)
        # PyTorch computes F.normalize (L2 norm): x / max(||x||_2, eps), then
        # multiplies by sqrt(dim); that product equals RMS normalization.
        # The 1e-24 floor is eps**2 for F.normalize's default eps=1e-12.
        l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
        return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
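The class is named for RMS normalization but computes an L2 normalization scaled by `sqrt(dim)`. A short NumPy check, offered as a sketch rather than a statement about the upstream implementation, shows that the two formulations agree whenever the norm is above the epsilon floor. Both function names are hypothetical.

```python
import numpy as np


def l2_normalize_scaled(x, gamma, eps=1e-12):
    """x / max(||x||_2, eps) * sqrt(dim) * gamma, along the last axis."""
    dim = x.shape[-1]
    norm = np.sqrt(np.sum(x * x, axis=-1, keepdims=True))
    return x / np.maximum(norm, eps) * dim ** 0.5 * gamma


def rms_norm(x, gamma):
    """x / sqrt(mean(x^2)) * gamma, along the last axis (no epsilon)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True))
    return x / rms * gamma


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4, 16))
gamma = rng.standard_normal(16)
a = l2_normalize_scaled(x, gamma)
b = rms_norm(x, gamma)
```

The identity follows from `||x||_2 = sqrt(dim) * rms(x)`, so dividing by the L2 norm and multiplying by `sqrt(dim)` is the same as dividing by the RMS.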

class ResidualBlock(nn.Module):
    """Residual block: RMS_norm → SiLU → CausalConv3d × 2 + shortcut."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
        # We store as named layers matching PyTorch's indices
        self.residual = ResidualBlockLayers(in_dim, out_dim)
        self.shortcut = (
            CausalConv3d(in_dim, out_dim, 1)
            if in_dim != out_dim
            else None
        )

    def __call__(self, x):
        h = self.shortcut(x) if self.shortcut is not None else x
        return self.residual(x) + h

class ResidualBlockLayers(nn.Module):
    """The sequential layers inside a ResidualBlock.

    PyTorch stores these as nn.Sequential with indices 0-6:
    [0] RMS_norm, [1] SiLU, [2] CausalConv3d, [3] RMS_norm, [4] SiLU, [5] Dropout, [6] CausalConv3d
    We use matching attribute names for weight compatibility.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Indices match PyTorch nn.Sequential indices for weight key compat
        # Index 0: RMS_norm
        self.layer_0 = RMS_norm(in_dim)
        # Index 2: CausalConv3d
        self.layer_2 = CausalConv3d(in_dim, out_dim, 3, padding=1)
        # Index 3: RMS_norm
        self.layer_3 = RMS_norm(out_dim)
        # Index 6: CausalConv3d
        self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)

    def __call__(self, x):
        x = self.layer_0(x)
        x = nn.silu(x)
        x = self.layer_2(x)
        mx.eval(x)  # Eval between convolutions to limit graph size
        x = self.layer_3(x)
        x = nn.silu(x)
        x = self.layer_6(x)
        return x

class AttentionBlock(nn.Module):
    """2D self-attention applied per frame. Input: [B, T, H, W, C]."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = RMS_norm(dim)
        # Conv2d as linear per spatial position; weight [O, H, W, I] for MLX
        # to_qkv: dim -> 3*dim, proj: dim -> dim (1x1 conv2d)
        self.to_qkv_weight = mx.zeros((3 * dim, 1, 1, dim))
        self.to_qkv_bias = mx.zeros((3 * dim,))
        self.proj_weight = mx.zeros((dim, 1, 1, dim))
        self.proj_bias = mx.zeros((dim,))

    def __call__(self, x):
        # x: [B, T, H, W, C]
        identity = x
        B, T, H, W, C = x.shape
        # Apply per frame: merge B and T
        x = x.reshape(B * T, H, W, C)
        x = self.norm(x)
        # QKV via 1x1 conv2d (equivalent to linear on last dim)
        qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias  # [BT, H, W, 3C]
        qkv = qkv.reshape(B * T, H * W, 3 * C)
        q, k, v = mx.split(qkv, 3, axis=-1)  # each [BT, HW, C]
        # Single-head attention
        q = q[:, None, :, :]  # [BT, 1, HW, C]
        k = k[:, None, :, :]
        v = v[:, None, :, :]
        scale = C ** -0.5
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)  # [BT, 1, HW, C]
        out = out.squeeze(1).reshape(B * T, H, W, C)
        # Project output
        out = mx.conv_general(out, self.proj_weight) + self.proj_bias  # [BT, H, W, C]
        out = out.reshape(B, T, H, W, C)
        return out + identity
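For readers unfamiliar with what the fused attention call computes here, this is a minimal NumPy sketch of single-head attention over the `H*W` positions of each frame. It is a reference formulation under the assumption of plain softmax attention with scale `C ** -0.5`, not a description of MLX internals, and `per_frame_attention` is a hypothetical name.

```python
import numpy as np


def per_frame_attention(q, k, v):
    """q, k, v: [B*T, H*W, C]; single head, softmax over the key positions."""
    scale = q.shape[-1] ** -0.5
    scores = (q @ np.swapaxes(k, -1, -2)) * scale  # [BT, HW, HW]
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
B, T, H, W, C = 1, 2, 3, 3, 4
q = rng.standard_normal((B * T, H * W, C))
k = rng.standard_normal((B * T, H * W, C))
v = rng.standard_normal((B * T, H * W, C))
out, weights = per_frame_attention(q, k, v)
```

Merging `B` and `T` into one leading axis means positions only attend within their own frame, so the score matrix is `HW x HW` per frame rather than `THW x THW`.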

class DupUp3D(nn.Module):
    """Upsample by duplicating channels and reshaping. No learnable parameters."""

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s
        self.repeats = out_channels * self.factor // in_channels

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape
        # Repeat channels
        x = mx.repeat(x, self.repeats, axis=-1)  # [B, T, H, W, C*repeats]
        # Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
        x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
        # Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
        x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
        # Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
        x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
        if first_chunk:
            x = x[:, self.factor_t - 1:, :, :, :]
        return x
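The repeat, reshape, and transpose sequence in `DupUp3D` is a channel-to-space rearrangement. The NumPy sketch below mirrors the same steps on a small tensor so the resulting shapes and index mapping can be inspected. It assumes `np.repeat` and `mx.repeat` share element-wise repeat semantics, and `dup_up_3d` is a hypothetical stand-in for the module's `__call__`.

```python
import numpy as np


def dup_up_3d(x, out_channels, factor_t, factor_s, first_chunk=False):
    """x: [B, T, H, W, C] -> [B, T*factor_t, H*factor_s, W*factor_s, out_channels]."""
    B, T, H, W, C = x.shape
    factor = factor_t * factor_s * factor_s
    repeats = out_channels * factor // C
    x = np.repeat(x, repeats, axis=-1)  # each channel duplicated `repeats` times
    x = x.reshape(B, T, H, W, out_channels, factor_t, factor_s, factor_s)
    x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)  # [B, T, ft, H, fs, W, fs, out_C]
    x = x.reshape(B, T * factor_t, H * factor_s, W * factor_s, out_channels)
    if first_chunk:
        x = x[:, factor_t - 1:]  # drop the leading duplicated frame(s)
    return x


x = np.arange(1 * 2 * 3 * 3 * 8, dtype=np.float64).reshape(1, 2, 3, 3, 8)
out = dup_up_3d(x, out_channels=4, factor_t=2, factor_s=2)
out_first = dup_up_3d(x, out_channels=4, factor_t=2, factor_s=2, first_chunk=True)
```

With 8 input channels, 4 output channels, and a total factor of 8, each input channel is repeated 4 times, and every output element is a copy of some input element, so the shortcut adds no parameters.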

class Resample(nn.Module):
    """Spatial 2x upsampling with optional temporal 2x upsampling.

    Decoder-only: downsample modes are not supported.
    """

    def __init__(self, dim, mode):
        super().__init__()
        self.dim = dim
        self.mode = mode
        if mode == "upsample2d":
            # resample.0 = Upsample (no params), resample.1 = Conv2d
            self.resample_weight = mx.zeros((dim, 3, 3, dim))  # Conv2d [O, H, W, I]
            self.resample_bias = mx.zeros((dim,))
        elif mode == "upsample3d":
            self.resample_weight = mx.zeros((dim, 3, 3, dim))
            self.resample_bias = mx.zeros((dim,))
            # time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        else:
            raise ValueError(f"Unsupported mode: {mode}")

    def _upsample2x(self, x):
        """Nearest-neighbor 2x spatial upsample. x: [N, H, W, C]."""
        N, H, W, C = x.shape
        # Repeat along H and W axes separately
        x = mx.repeat(x, repeats=2, axis=1)  # [N, 2H, W, C]
        x = mx.repeat(x, repeats=2, axis=2)  # [N, 2H, 2W, C]
        return x

    def _conv2d(self, x):
        """Apply the Conv2d with padding=1. x: [N, H, W, C]."""
        x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
        return mx.conv_general(x, self.resample_weight) + self.resample_bias

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape
        if self.mode == "upsample3d":
            # Temporal upsample via time_conv
            tc_out = self.time_conv(x)  # [B, T, H, W, 2C]
            # Split into two interleaved temporal streams
            tc_out = tc_out.reshape(B, T, H, W, 2, C)
            # Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C]
            stream0 = tc_out[:, :, :, :, 0, :]  # [B, T, H, W, C]
            stream1 = tc_out[:, :, :, :, 1, :]  # [B, T, H, W, C]
            x = mx.stack([stream0, stream1], axis=2)  # [B, T, 2, H, W, C]
            x = x.reshape(B, T * 2, H, W, C)
            if first_chunk:
                # PyTorch skips time_conv for first chunk entirely. In all-at-once
                # mode, we trim the first frame to match (the first interleaved
                # frame is from zero-padded causal context and shouldn't be kept).
                x = x[:, 1:, :, :, :]
            mx.eval(x)
            T = x.shape[1]
        # Spatial upsample in temporal chunks to limit peak memory
        chunk_size = 8
        chunks = []
        for t_start in range(0, T, chunk_size):
            t_end = min(t_start + chunk_size, T)
            x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
            x_chunk = self._upsample2x(x_chunk)
            x_chunk = self._conv2d(x_chunk)
            # Restore the batch axis so chunks join along time; concatenating the
            # flattened [B*t, ...] chunks on axis 0 would mix batches when B > 1
            x_chunk = x_chunk.reshape(B, t_end - t_start, 2 * H, 2 * W, C)
            mx.eval(x_chunk)
            chunks.append(x_chunk)
        return mx.concatenate(chunks, axis=1)
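The temporal interleave in `Resample` is the part most likely to be gotten wrong when porting between channel layouts, so here is a small NumPy sketch of just that step with tagged values. It assumes the `time_conv` output stores the two streams as the first and second halves of the channel axis, as the reshape in this file implies. `interleave_time` is a hypothetical helper.

```python
import numpy as np


def interleave_time(tc_out, first_chunk=False):
    """tc_out: [B, T, H, W, 2C] -> [B, 2T, H, W, C] (or 2T - 1 frames if first_chunk)."""
    B, T, H, W, C2 = tc_out.shape
    C = C2 // 2
    streams = tc_out.reshape(B, T, H, W, 2, C)
    stream0 = streams[..., 0, :]  # first half of the channels
    stream1 = streams[..., 1, :]  # second half of the channels
    x = np.stack([stream0, stream1], axis=2)  # [B, T, 2, H, W, C]
    x = x.reshape(B, 2 * T, H, W, C)
    return x[:, 1:] if first_chunk else x


# One pixel, C=1: frame t carries the pair (10*t, 10*t + 1)
tc = np.array([[0, 1], [10, 11], [20, 21]], dtype=np.float64).reshape(1, 3, 1, 1, 2)
frames = interleave_time(tc)[0, :, 0, 0, 0].tolist()
frames_first = interleave_time(tc, first_chunk=True)[0, :, 0, 0, 0].tolist()
```

Each input frame contributes two consecutive output frames, and `first_chunk=True` drops the very first of them, giving `2T - 1` frames.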

class Up_ResidualBlock(nn.Module):
    """Upsampling residual block with optional DupUp3D shortcut."""

    def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
        super().__init__()
        self.up_flag = up_flag
        # DupUp3D shortcut (no learnable params)
        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim, out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2 if up_flag else 1,
            )
        else:
            self.avg_shortcut = None
        # Main path: ResidualBlocks + optional Resample
        blocks = []
        dim_in = in_dim
        for _ in range(num_res_blocks):
            blocks.append(ResidualBlock(dim_in, out_dim))
            dim_in = out_dim
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            blocks.append(Resample(out_dim, mode=mode))
        self.upsamples = blocks

    def __call__(self, x, first_chunk=False):
        x_main = x
        for module in self.upsamples:
            if isinstance(module, Resample):
                x_main = module(x_main, first_chunk)
            else:
                x_main = module(x_main)
            mx.eval(x_main)  # Limit graph size per sub-block
        if self.avg_shortcut is not None:
            x_shortcut = self.avg_shortcut(x, first_chunk)
            mx.eval(x_shortcut)
            return x_main + x_shortcut
        return x_main

class Decoder3d(nn.Module):
    """Wan2.2 3D VAE Decoder."""

    def __init__(
        self,
        dim=256,
        z_dim=48,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temperal_upsample=(True, True, False),
    ):
        super().__init__()
        # Compute layer dimensions
        dims = [dim * dim_mult[-1]] + [dim * m for m in reversed(dim_mult)]
        # Initial conv
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
        # Middle blocks
        self.middle = [
            ResidualBlock(dims[0], dims[0]),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0]),
        ]
        # Upsample blocks
        self.upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
            self.upsamples.append(Up_ResidualBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks + 1,
                temperal_upsample=t_up,
                up_flag=(i != len(dim_mult) - 1),
            ))
        # Output head: [RMS_norm, SiLU, CausalConv3d]
        self.head = Head22(dims[-1])

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C=z_dim]
        x = self.conv1(x)
        for layer in self.middle:
            x = layer(x)
        mx.eval(x)  # Evaluate to limit graph size
        for i, layer in enumerate(self.upsamples):
            x = layer(x, first_chunk)
            mx.eval(x)  # Evaluate after each upsample block
        x = self.head(x)
        return x

class Head22(nn.Module):
    """Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).

    PyTorch key mapping: head.0 = RMS_norm, head.2 = CausalConv3d
    (index 1 = SiLU has no params)
    """

    def __init__(self, dim, out_channels=12):
        super().__init__()
        # Index 0: RMS_norm
        self.layer_0 = RMS_norm(dim)
        # Index 2: CausalConv3d
        self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)

    def __call__(self, x):
        x = self.layer_0(x)
        x = nn.silu(x)
        x = self.layer_2(x)
        return x

class Wan22VAEDecoder(nn.Module):
    """Full Wan2.2 VAE decoder with normalization and unpatchify."""

    def __init__(self, z_dim=48, dim=160, dec_dim=256):
        super().__init__()
        self.z_dim = z_dim
        # conv2: 1x1x1 conv before decoder
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dim=dec_dim,
            z_dim=z_dim,
            dim_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            temperal_upsample=(True, True, False),
        )

    def __call__(self, z):
        """Decode latents to video.

        Args:
            z: [B, T, H, W, C=48] latent tensor (already denormalized)

        Returns:
            video: [B, T', H', W', 3] decoded RGB in [-1, 1]
        """
        x = self.conv2(z)
        # All-at-once decode with first_chunk=True to trim extra temporal
        # frames from causal padding (matches PyTorch's chunked behavior)
        out = self.decoder(x, first_chunk=True)
        # Unpatchify: 12 channels → 3 RGB (spatial 2×2)
        out = _unpatchify(out, patch_size=2)
        return mx.clip(out, -1.0, 1.0)

def denormalize_latents(z, mean=None, std=None):
    """Denormalize latents: z = z / (1/std) + mean."""
    if mean is None:
        mean = VAE22_MEAN
    if std is None:
        std = VAE22_STD
    inv_scale = std  # scale was 1/std, so divide by scale = multiply by std
    return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
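As a sanity check on the sign and direction of the denormalization, the NumPy sketch below round-trips a tensor through a normalize step and the inverse used here. The forward `normalize` is an assumption (the encoder side is not part of this file), and the 3-channel statistics are made-up placeholders, not the real 48-entry tables.

```python
import numpy as np

# Hypothetical 3-channel statistics for illustration only
mean = np.array([0.1, -0.2, 0.3])
std = np.array([0.5, 1.0, 2.0])


def normalize(z):
    """Assumed forward transform: subtract the mean, multiply by scale = 1/std."""
    return (z - mean) * (1.0 / std)


def denormalize(z):
    """Inverse: dividing by 1/std is multiplying by std, then add the mean back."""
    return z * std + mean


rng = np.random.default_rng(0)
z = rng.standard_normal((1, 2, 4, 4, 3))  # channels-last, C=3
restored = denormalize(normalize(z))
```

Broadcasting over the last axis plays the role of the `reshape(1, 1, 1, 1, -1)` calls in the channels-last code.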

def _unpatchify(x, patch_size=2):
    """Unpack channels into space: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C].

    For the decoder head's output: [B, T, H, W, 12] → [B, T, H*2, W*2, 3].

    PyTorch: b (c r q) f h w -> b c f (h q) (w r) with q=p, r=p
    In channels-last: [B, T, H, W, C*r*q] -> [B, T, H*q, W*r, C]
    """
    if patch_size == 1:
        return x
    B, T, H, W, Cpacked = x.shape
    C = Cpacked // (patch_size * patch_size)
    # PyTorch patchify: "b c f (h q) (w r) -> b (c r q) f h w", so the packed
    # channel axis is ordered (c, r, q)
    # Split it: [B, T, H, W, (C, r, q)] -> [B, T, H, W, C, r, q]
    x = x.reshape(B, T, H, W, C, patch_size, patch_size)
    # Rearrange: put q next to H, r next to W
    x = x.transpose(0, 1, 2, 6, 3, 5, 4)  # [B, T, H, q, W, r, C]
    x = x.reshape(B, T, H * patch_size, W * patch_size, C)
    return x
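A round trip is a quick way to confirm the axis order in `_unpatchify`. The sketch below writes a channels-last `patchify` from the einops pattern quoted in the docstring and checks that unpatchify inverts it. Both functions are NumPy stand-ins with hypothetical names, and the forward `patchify` is reconstructed from that pattern rather than taken from this diff.

```python
import numpy as np


def patchify(x, p=2):
    """[B, T, H*p, W*p, C] -> [B, T, H, W, C*p*p], channel axis packed as (c, r, q)."""
    B, T, Hp, Wp, C = x.shape
    H, W = Hp // p, Wp // p
    x = x.reshape(B, T, H, p, W, p, C)  # [B, T, H, q, W, r, C]
    x = x.transpose(0, 1, 2, 4, 6, 5, 3)  # [B, T, H, W, C, r, q]
    return x.reshape(B, T, H, W, C * p * p)


def unpatchify(x, p=2):
    """[B, T, H, W, C*p*p] -> [B, T, H*p, W*p, C]; inverse of patchify."""
    B, T, H, W, Cpacked = x.shape
    C = Cpacked // (p * p)
    x = x.reshape(B, T, H, W, C, p, p)  # [B, T, H, W, C, r, q]
    x = x.transpose(0, 1, 2, 6, 3, 5, 4)  # [B, T, H, q, W, r, C]
    return x.reshape(B, T, H * p, W * p, C)


rng = np.random.default_rng(0)
video = rng.standard_normal((1, 2, 6, 6, 3))  # [B, T, H*2, W*2, RGB]
packed = patchify(video)
restored = unpatchify(packed)
```

If `r` and `q` were swapped in the transpose, the shapes would still match but the round trip would fail, which is why an equality check is more informative than a shape check here.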

def sanitize_wan22_vae_weights(weights: dict) -> dict:
    """Convert PyTorch Wan2.2 VAE weights to MLX format.

    Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
    Transposes conv weights from channels-first to channels-last.
    Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
    Maps PyTorch nn.Sequential indices to our named layers.
    """
    sanitized = {}
    for key, value in weights.items():
        # Skip encoder and conv1 (encoder-only)
        if key.startswith("encoder.") or key.startswith("conv1."):
            continue
        new_key = key
        # Map nn.Sequential indexed layers to our named attributes
        # ResidualBlockLayers: indices 0, 2, 3, 6 → layer_0, layer_2, layer_3, layer_6
        # Head22: indices 0, 2 → layer_0, layer_2
        for idx in ["0", "2", "3", "6"]:
            # Match patterns like "residual.0.gamma" → "residual.layer_0.gamma"
            old_pattern = f".residual.{idx}."
            new_pattern = f".residual.layer_{idx}."
            new_key = new_key.replace(old_pattern, new_pattern)
        # Head layer mapping: head.0.gamma → head.layer_0.gamma, head.2.weight → head.layer_2.weight
        for idx in ["0", "2"]:
            old_pattern = f".head.{idx}."
            new_pattern = f".head.layer_{idx}."
            new_key = new_key.replace(old_pattern, new_pattern)
        # Map Resample Conv2d: resample.1.weight → resample_weight, resample.1.bias → resample_bias
        if ".resample.1.weight" in new_key:
            new_key = new_key.replace(".resample.1.weight", ".resample_weight")
        elif ".resample.1.bias" in new_key:
            new_key = new_key.replace(".resample.1.bias", ".resample_bias")
        # Map AttentionBlock Conv2d weights
        if ".to_qkv.weight" in new_key:
            new_key = new_key.replace(".to_qkv.weight", ".to_qkv_weight")
        elif ".to_qkv.bias" in new_key:
            new_key = new_key.replace(".to_qkv.bias", ".to_qkv_bias")
        elif ".proj.weight" in new_key and "time_projection" not in new_key:
            new_key = new_key.replace(".proj.weight", ".proj_weight")
        elif ".proj.bias" in new_key and "time_projection" not in new_key:
            new_key = new_key.replace(".proj.bias", ".proj_bias")
        # Transpose conv weights to channels-last
        is_weight = new_key.endswith(".weight") or new_key.endswith("_weight")
        if is_weight:
            if value.ndim == 5:
                # Conv3d: [O, I, D, H, W] → [O, D, H, W, I]
                value = np.transpose(np.array(value), (0, 2, 3, 4, 1))
                value = mx.array(value)
            elif value.ndim == 4:
                # Conv2d: [O, I, H, W] → [O, H, W, I]
                value = np.transpose(np.array(value), (0, 2, 3, 1))
                value = mx.array(value)
        # Squeeze RMS_norm gamma: (dim, 1, 1, 1) or (dim, 1, 1) → (dim,)
        if "gamma" in new_key:
            value = mx.array(np.array(value).squeeze())
        sanitized[new_key] = value
    return sanitized
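The conversion combines two independent steps: renaming `nn.Sequential` indices and moving the input-channel axis last. The NumPy-only sketch below isolates both on a toy dictionary. The keys and shapes are illustrative assumptions rather than entries from a real checkpoint, and `rename_key` and `to_channels_last` are hypothetical helpers that cover only a subset of the mappings in the function.

```python
import numpy as np


def rename_key(key: str) -> str:
    """Map nn.Sequential-style indices to the named attributes used in this file."""
    for idx in ("0", "2", "3", "6"):
        key = key.replace(f".residual.{idx}.", f".residual.layer_{idx}.")
    for idx in ("0", "2"):
        key = key.replace(f".head.{idx}.", f".head.layer_{idx}.")
    key = key.replace(".resample.1.weight", ".resample_weight")
    key = key.replace(".resample.1.bias", ".resample_bias")
    return key


def to_channels_last(value: np.ndarray) -> np.ndarray:
    """Move the input-channel axis of a conv weight to the end."""
    if value.ndim == 5:  # Conv3d: [O, I, D, H, W] -> [O, D, H, W, I]
        return np.transpose(value, (0, 2, 3, 4, 1))
    if value.ndim == 4:  # Conv2d: [O, I, H, W] -> [O, H, W, I]
        return np.transpose(value, (0, 2, 3, 1))
    return value


# Illustrative keys and shapes only; real checkpoint keys may differ
example = {
    "decoder.middle.0.residual.2.weight": np.zeros((8, 4, 3, 3, 3)),
    "decoder.head.0.gamma": np.ones((4, 1, 1, 1)),
}
converted = {}
for key, value in example.items():
    new_key = rename_key(key)
    if new_key.endswith("weight"):
        value = to_channels_last(value)
    if "gamma" in new_key:
        value = value.squeeze()
    converted[new_key] = value
```

The patterns include the surrounding dots so that an index like `2` is only rewritten when it is a whole path component directly under `residual` or `head`.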


@@ -19,7 +19,9 @@ dependencies = [
"tqdm",
"opencv-python>=4.12.0.88",
"Pillow>=10.3.0",
-"mlx-vlm"
+"mlx-vlm",
+"imageio>=2.37.2",
+"imageio-ffmpeg>=0.6.0",
]
license = {text="MIT"}
authors = [
@@ -42,6 +44,7 @@ Issues = "https://github.com/Blaizzy/mlx-video/issues"
[project.scripts]
"mlx_video.generate" = "mlx_video.generate:main"
+"mlx_video.generate_wan" = "mlx_video.generate_wan:main"
[tool.setuptools.packages.find]
include = ["mlx_video*"]

1453
tests/test_wan.py Normal file

File diff suppressed because it is too large

34
uv.lock generated

@@ -2,7 +2,8 @@ version = 1
revision = 3
requires-python = ">=3.11"
resolution-markers = [
-"python_full_version >= '3.12'",
+"python_full_version >= '3.13'",
+"python_full_version == '3.12.*'",
"python_full_version < '3.12'",
]
@@ -614,6 +615,33 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
]
[[package]]
name = "imageio"
version = "2.37.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
{ name = "pillow" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/6f/606be632e37bf8d05b253e8626c2291d74c691ddc7bcdf7d6aaf33b32f6a/imageio-2.37.2.tar.gz", hash = "sha256:0212ef2727ac9caa5ca4b2c75ae89454312f440a756fcfc8ef1993e718f50f8a", size = 389600, upload-time = "2025-11-04T14:29:39.898Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl", hash = "sha256:ad9adfb20335d718c03de457358ed69f141021a333c40a53e57273d8a5bd0b9b", size = 317646, upload-time = "2025-11-04T14:29:37.948Z" },
]
[[package]]
name = "imageio-ffmpeg"
version = "0.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/44/bd/c3343c721f2a1b0c9fc71c1aebf1966a3b7f08c2eea8ed5437a2865611d6/imageio_ffmpeg-0.6.0.tar.gz", hash = "sha256:e2556bed8e005564a9f925bb7afa4002d82770d6b08825078b7697ab88ba1755", size = 25210, upload-time = "2025-01-16T21:34:32.747Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/da/58/87ef68ac83f4c7690961bce288fd8e382bc5f1513860fc7f90a9c1c1c6bf/imageio_ffmpeg-0.6.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.whl", hash = "sha256:9d2baaf867088508d4a3458e61eeb30e945c4ad8016025545f66c4b5aaef0a61", size = 24932969, upload-time = "2025-01-16T21:34:20.464Z" },
{ url = "https://files.pythonhosted.org/packages/40/5c/f3d8a657d362cc93b81aab8feda487317da5b5d31c0e1fdfd5e986e55d17/imageio_ffmpeg-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b1ae3173414b5fc5f538a726c4e48ea97edc0d2cdc11f103afee655c463fa742", size = 21113891, upload-time = "2025-01-16T21:34:00.277Z" },
{ url = "https://files.pythonhosted.org/packages/33/e7/1925bfbc563c39c1d2e82501d8372734a5c725e53ac3b31b4c2d081e895b/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1d47bebd83d2c5fc770720d211855f208af8a596c82d17730aa51e815cdee6dc", size = 25632706, upload-time = "2025-01-16T21:33:53.475Z" },
{ url = "https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c7e46fcec401dd990405049d2e2f475e2b397779df2519b544b8aab515195282", size = 29498237, upload-time = "2025-01-16T21:34:13.726Z" },
{ url = "https://files.pythonhosted.org/packages/a0/13/59da54728351883c3c1d9fca1710ab8eee82c7beba585df8f25ca925f08f/imageio_ffmpeg-0.6.0-py3-none-win32.whl", hash = "sha256:196faa79366b4a82f95c0f4053191d2013f4714a715780f0ad2a68ff37483cc2", size = 19652251, upload-time = "2025-01-16T21:34:06.812Z" },
{ url = "https://files.pythonhosted.org/packages/2c/c6/fa760e12a2483469e2bf5058c5faff664acf66cadb4df2ad6205b016a73d/imageio_ffmpeg-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02fa47c83703c37df6bfe4896aab339013f62bf02c5ebf2dce6da56af04ffc0a", size = 31246824, upload-time = "2025-01-16T21:34:28.6Z" },
]
[[package]]
name = "iniconfig"
version = "2.3.0"
@@ -772,6 +800,8 @@ name = "mlx-video"
source = { editable = "." }
dependencies = [
{ name = "huggingface-hub" },
{ name = "imageio" },
{ name = "imageio-ffmpeg" },
{ name = "mlx" },
{ name = "mlx-vlm" },
{ name = "numpy" },
@@ -790,6 +820,8 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "huggingface-hub" },
{ name = "imageio", specifier = ">=2.37.2" },
{ name = "imageio-ffmpeg", specifier = ">=0.6.0" },
{ name = "mlx", specifier = ">=0.22.0" },
{ name = "mlx-vlm" },
{ name = "numpy" },