feat(wan): Add Wan2.1/2.2 T2V with quantization support
158
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
# MLX-Video Copilot Instructions
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
MLX-Video is a video and audio generation package built on Apple's MLX framework. It implements the LTX-2 model (a 19B-parameter DiT) for text-to-video, image-to-video, and audio-video generation, optimized for Apple Silicon.
|
||||||
|
|
||||||
|
## Build, Test, and Lint
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
```bash
|
||||||
|
# Install test dependencies first (pytest not in main deps)
|
||||||
|
pip install pytest
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
python -m pytest tests/
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
python -m pytest tests/test_generate_dev.py
|
||||||
|
|
||||||
|
# Run specific test
|
||||||
|
python -m pytest tests/test_generate_dev.py::TestLTX2Scheduler::test_scheduler_output_shape
|
||||||
|
```
|
||||||
|
|
||||||
|
### Linting
|
||||||
|
Pre-commit hooks configured with:
|
||||||
|
- **black**: Code formatting
|
||||||
|
- **isort**: Import sorting (profile: black)
|
||||||
|
- **autoflake**: Remove unused imports
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run pre-commit manually
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running Generation
|
||||||
|
```bash
|
||||||
|
# Quick test - distilled model (two-stage pipeline)
|
||||||
|
python -m mlx_video.generate --prompt "test video" --num-frames 33
|
||||||
|
|
||||||
|
# Dev model with CFG (single-stage, higher quality)
|
||||||
|
python -m mlx_video.generate_dev --prompt "test video" --steps 40 --cfg-scale 4.0
|
||||||
|
|
||||||
|
# Audio-video generation
|
||||||
|
python -m mlx_video.generate_av --prompt "test video" --output-path out.mp4 --output-audio out.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Two-Stage Pipeline (Distilled Model)
|
||||||
|
The distilled model (`generate.py`) uses a two-stage approach for efficiency:
|
||||||
|
1. **Stage 1**: Generate at half resolution with 8 denoising steps using STAGE_1_SIGMAS
|
||||||
|
2. **Upsampler**: 2x spatial upsampling via LatentUpsampler
|
||||||
|
3. **Stage 2**: Refine at full resolution with 3 steps using STAGE_2_SIGMAS
|
||||||
|
4. **VAE Decoder**: Convert latents to RGB video (tiled decoding for memory efficiency)
|
||||||
|
|
||||||
|
### Single-Stage Pipeline (Dev Model)
|
||||||
|
The dev model (`generate_dev.py`) uses classifier-free guidance (CFG):
|
||||||
|
- Full resolution generation with configurable steps (typically 40)
|
||||||
|
- CFG guidance scale controls prompt adherence vs. diversity
|
||||||
|
- More flexible but slower than distilled model
|
||||||
|
|
||||||
|
### Core Components
|
||||||
|
|
||||||
|
**DiT Transformer** (`models/ltx/ltx.py`):
|
||||||
|
- 48 layers, 32 attention heads, 128 dim per head
|
||||||
|
- Dual modality support: video (3840-dim) and audio (2048-dim) embeddings
|
||||||
|
- Uses RoPE (Rotary Position Embeddings) in SPLIT mode with double precision
|
||||||
|
- AdaLN-Zero conditioning blocks inject timestep/text embeddings
|
||||||
|
|
||||||
|
**VAE Architecture**:
|
||||||
|
- **Video VAE**: 128 latent channels, 8x temporal + 32x spatial compression
|
||||||
|
- Encoder: `models/ltx/video_vae/encoder.py`
|
||||||
|
- Decoder: `models/ltx/video_vae/decoder.py` (supports tiled decoding)
|
||||||
|
- **Audio VAE**: 8 latent channels, mel-spectrogram intermediate
|
||||||
|
- Decoder: `models/ltx/audio_vae/decoder.py`
|
||||||
|
- HiFi-GAN vocoder: `models/ltx/audio_vae/vocoder.py`
|
||||||
|
|
||||||
|
**Text Encoder** (`models/ltx/text_encoder.py`):
|
||||||
|
- Based on Gemma 3 model
|
||||||
|
- Returns separate embeddings for video (3840-dim) and audio (2048-dim)
|
||||||
|
- Supports prompt enhancement via `enhance_t2v()` method
|
||||||
|
|
||||||
|
**Tiling System** (`models/ltx/video_vae/tiling.py`):
|
||||||
|
- Memory-efficient decoding for large videos
|
||||||
|
- Modes: auto, default (512px/64f), aggressive (256px/32f), conservative (768px/96f)
|
||||||
|
- Supports streaming via `on_frames_ready` callback
|
||||||
|
|
||||||
|
### Key Patterns
|
||||||
|
|
||||||
|
**Position Grids**:
|
||||||
|
- Created in pixel space, then converted to latent space internally
|
||||||
|
- Video: (B, 3, num_patches, 2) with [start, end) bounds for temporal/spatial dims
|
||||||
|
- Audio: (B, 1, num_patches, 2) for temporal dimension only
|
||||||
|
- See `create_position_grid()` in generate modules
|
||||||
|
|
||||||
|
**Latent Conditioning** (`conditioning/latent.py`):
|
||||||
|
- `LatentState` tracks clean latents, noise, and sigma values
|
||||||
|
- `VideoConditionByLatentIndex` enables I2V by conditioning specific frames
|
||||||
|
- `apply_denoise_mask()` protects conditioned regions during denoising
|
||||||
|
|
||||||
|
**Weight Loading**:
|
||||||
|
- `convert.py`: Downloads from HuggingFace, converts PyTorch → MLX format
|
||||||
|
- Sanitization functions (`sanitize_transformer_weights`, `sanitize_vae_encoder_weights`) adapt keys
|
||||||
|
- Uses safetensors for efficient loading
|
||||||
|
|
||||||
|
## Key Conventions
|
||||||
|
|
||||||
|
### Model Configuration
|
||||||
|
- Always use `LTXModelConfig` to instantiate models
|
||||||
|
- `model_type` determines modality: `VideoOnly`, `AudioOnly`, or `AudioVideo`
|
||||||
|
- `rope_type=LTXRopeType.SPLIT` and `double_precision_rope=True` are standard
|
||||||
|
|
||||||
|
### Frame Count Requirements
|
||||||
|
- **Distilled model**: `num_frames = 1 + 8*k` format (e.g., 33, 65, 97)
|
||||||
|
- **Dev model**: No strict requirement, but odd numbers work better
|
||||||
|
- Audio frames auto-computed from video duration via `AUDIO_LATENTS_PER_SECOND`
|
||||||
|
|
||||||
|
### Dimension Constraints
|
||||||
|
- Video height/width must be divisible by 64 (VAE spatial compression)
|
||||||
|
- Latent dimensions are pixel dimensions divided by 32
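The frame-count and dimension rules above can be checked before launching a run. The sketch below is a hypothetical helper (`check_ltx_request` is not part of the package) for the distilled model. It derives latent sizes from the 32x spatial and 8x temporal compression factors quoted earlier; the mapping of `1 + 8*k` frames to `1 + k` latent frames is an assumption, not something this document states.

```python
def check_ltx_request(num_frames: int, height: int, width: int) -> dict:
    """Validate distilled-model inputs and report the implied latent shape.

    Hypothetical helper for illustration; not part of mlx_video.
    """
    if num_frames < 1 or (num_frames - 1) % 8 != 0:
        raise ValueError(f"num_frames must be 1 + 8*k, got {num_frames}")
    for name, value in (("height", height), ("width", width)):
        if value <= 0 or value % 64 != 0:
            raise ValueError(f"{name} must be a positive multiple of 64, got {value}")
    return {
        # Assumption: 8x temporal compression maps 1 + 8*k frames to 1 + k latents.
        "latent_frames": 1 + (num_frames - 1) // 8,
        # Latent dimensions are pixel dimensions divided by 32.
        "latent_height": height // 32,
        "latent_width": width // 32,
    }


shape = check_ltx_request(num_frames=33, height=512, width=768)
print(shape)
```

Failing early on an invalid request is cheaper than discovering a shape mismatch after the text encoder and transformer weights are already loaded.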
|
||||||
|
|
||||||
|
### Audio Constants
|
||||||
|
```python
|
||||||
|
AUDIO_SAMPLE_RATE = 24000 # Output sample rate
|
||||||
|
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal rate
|
||||||
|
AUDIO_HOP_LENGTH = 160 # Mel hop length
|
||||||
|
AUDIO_LATENT_CHANNELS = 8 # Audio latent channels
|
||||||
|
AUDIO_MEL_BINS = 16 # Mel frequency bins
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sigma Schedules
|
||||||
|
Distilled model uses predefined schedules (no scheduler class):
|
||||||
|
```python
|
||||||
|
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||||
|
```
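As a quick sanity check on these schedules, the sketch below pairs consecutive sigma values into denoising steps. It assumes the common convention that a list of N+1 sigmas defines N steps, which matches the 8 and 3 step counts quoted for the two stages; `step_pairs` is a hypothetical helper, not a function from the package.

```python
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]


def step_pairs(sigmas):
    """Return (sigma_current, sigma_next) pairs, one per denoising step."""
    return list(zip(sigmas[:-1], sigmas[1:]))


stage_1_steps = step_pairs(STAGE_1_SIGMAS)
stage_2_steps = step_pairs(STAGE_2_SIGMAS)

# Number of denoising steps per stage.
print(len(stage_1_steps), len(stage_2_steps))

# Stage 2 reuses the tail of the stage 1 schedule.
stage_2_is_tail = STAGE_2_SIGMAS == STAGE_1_SIGMAS[-len(STAGE_2_SIGMAS):]
print(stage_2_is_tail)
```

Stage 2 starts at sigma 0.909375 rather than 1.0, which is consistent with refining an upsampled latent instead of denoising from pure noise.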
|
||||||
|
|
||||||
|
Dev model computes schedules via `ltx2_scheduler(steps)` function.
|
||||||
|
|
||||||
|
### Code Style
|
||||||
|
- Follow black formatting (configured in pre-commit)
|
||||||
|
- Import sorting: isort with black profile
|
||||||
|
- Remove unused imports (autoflake)
|
||||||
|
- Type hints encouraged but not enforced
|
||||||
|
|
||||||
|
### Modality Enum
|
||||||
|
Use `Modality.VIDEO` and `Modality.AUDIO` from `models/ltx/transformer.py` for multi-modal operations.
|
||||||
|
|
||||||
|
### Video Post-Processing
|
||||||
|
- `postprocess.py`: Contains utilities for frame normalization and video saving
|
||||||
|
- Always denormalize latents from [-1, 1] to [0, 255] before saving
|
||||||
|
- Use opencv-python for video I/O
|
||||||
|
|
||||||
|
## Python Requirements
|
||||||
|
- Python >= 3.11
|
||||||
|
- MLX >= 0.22.0
|
||||||
|
- Primary dependencies: numpy, safetensors, transformers, opencv-python, Pillow, mlx-vlm, scipy, librosa
|
||||||
|
- Package manager: uv recommended for faster installs, pip also supported
|
||||||
26
.github/skills/fast-mlx/SKILL.md
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
---
|
||||||
|
name: fast-mlx
|
||||||
|
description: Optimize MLX code for performance and memory. Use when asked to implement or speed up MLX models or algorithms, reduce latency/throughput bottlenecks, tune lazy evaluation, type promotion, fast ops, compilation, memory use, or profiling.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Fast MLX
|
||||||
|
|
||||||
|
## Workflow
|
||||||
|
|
||||||
|
- Look for opportunities to compile functions of mostly elementwise operations.
- For models with fixed-shape inputs, or where the shapes don't change much, compile the entire graph.
- Replace slow implementations with MLX fast ops.
|
||||||
|
- Identify evaluation boundaries and unintended sync points (`mx.eval`, `item()`, NumPy conversions).
|
||||||
|
- Check dtype promotion and scalar usage; keep precision consistent with intent.
|
||||||
|
- Review compilation strategy; avoid unnecessary recompiles and closure captures.
|
||||||
|
- Reduce peak memory via lazy loading order and releasing temporaries before `mx.eval`.
|
||||||
|
- Suggest profiling steps if the bottleneck is unclear.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- Read `references/fast-mlx-guide.md` for detailed tips and examples. Use it as the source of truth.
|
||||||
|
|
||||||
|
## Output expectations
|
||||||
|
|
||||||
|
- Provide concrete code changes with a brief rationale.
|
||||||
|
- Call out changes that need user confirmation (e.g., enabling async eval or shapeless compile).
|
||||||
350
.github/skills/fast-mlx/references/fast-mlx-guide.md
vendored
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
# Making MLX Go Fast
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Graph Evaluation](#graph-evaluation)
|
||||||
|
- [Type Promotion](#type-promotion)
|
||||||
|
- [Operations](#operations)
|
||||||
|
- [Compile](#compile)
|
||||||
|
- [Memory Use](#memory-use)
|
||||||
|
- [Profiling](#profiling)
|
||||||
|
|
||||||
|
This guide assumes you have some familiarity with MLX and want to make your MLX
|
||||||
|
model or algorithm as efficient as possible.
|
||||||
|
|
||||||
|
### Graph Evaluation
|
||||||
|
|
||||||
|
Recall, MLX is lazy. When you call an MLX op, no computation actually happens.
|
||||||
|
You are simply building a graph. The computation happens when you explicitly or
|
||||||
|
implicitly evaluate an array. Read more about how this works in the
|
||||||
|
documentation:
|
||||||
|
https://ml-explore.github.io/mlx/build/html/usage/lazy_evaluation.html
|
||||||
|
|
||||||
|
Evaluating the graph incurs some overhead, so don't do it too frequently.
|
||||||
|
Conversely you don't want the graph to get too big before evaluating it as this
|
||||||
|
can also be expensive. Most numerical and machine learning algorithms are
|
||||||
|
iterative. A good place to evaluate the graph is at the end of each iteration.
|
||||||
|
Some examples:
|
||||||
|
|
||||||
|
- After an iteration of gradient descent
|
||||||
|
- After producing one token with a language model
|
||||||
|
- After taking one denoising step in a diffusion model
|
||||||
|
|
||||||
|
Overly frequent evaluations sometimes happen by accident. For example:
|
||||||
|
|
||||||
|
```python
# output is an mx.array
for x in output:
    do_something(x.item())
```
|
||||||
|
|
||||||
|
The same thing can be written more explicitly with operations and `mx.eval` as:
|
||||||
|
|
||||||
|
```python
for i in range(len(output)):
    x = output[i]
    mx.eval(x)
    do_something(x.item())
```
|
||||||
|
|
||||||
|
Two better options are:
|
||||||
|
|
||||||
|
1. When possible avoid calling `item()` and do everything in MLX.
|
||||||
|
2. Move the entire output to Python or NumPy first.
|
||||||
|
|
||||||
|
An example of the second approach:
|
||||||
|
|
||||||
|
```python
for x in output.tolist():
    do_something(x)
```
|
||||||
|
|
||||||
|
#### Asynchronous Evaluation
|
||||||
|
|
||||||
|
For a latency sensitive computation which is run many times, `mx.async_eval`
|
||||||
|
can be useful. Normally `mx.eval` is synchronous. It returns only when the
|
||||||
|
computation is complete. Instead `mx.async_eval` asynchronously evaluates the
|
||||||
|
graph and returns to the main thread immediately. You can use this to pipeline
|
||||||
|
graph construction with computation like so:
|
||||||
|
|
||||||
|
```python
def generator():
    # mx.async_eval returns None, so keep a handle to the array itself.
    out = my_function()
    mx.async_eval(out)

    while True:
        out_next = my_function()
        mx.async_eval(out_next)
        mx.eval(out)
        yield out
        out = out_next
```
|
||||||
|
|
||||||
|
For this to work `my_function()` cannot do any synchronous evaluations (e.g.
|
||||||
|
calling `mx.eval`, converting to NumPy, etc.). Furthermore, any work done on
|
||||||
|
`out` that is synchronous and on the same stream can stall the pipeline:
|
||||||
|
|
||||||
|
```python
for out in generator():
    out = out * 2
    # Stalls the pipeline!
    mx.eval(out)
```
|
||||||
|
|
||||||
|
An easy fix for this is to put the pipeline in a separate stream:
|
||||||
|
|
||||||
|
```python
def generator():
    with mx.stream(mx.new_stream(mx.gpu)):
        # mx.async_eval returns None, so keep a handle to the array itself.
        out = my_function()
        mx.async_eval(out)

        while True:
            out_next = my_function()
            mx.async_eval(out_next)
            mx.eval(out)
            yield out
            out = out_next
```
|
||||||
|
|
||||||
|
### Type Promotion
|
||||||
|
|
||||||
|
One of the most common performance issues comes from accidental up-casting.
|
||||||
|
Make sure you understand how type promotion works in MLX. The inputs to an MLX
|
||||||
|
operation are typically promoted to a common type which doesn't lose precision.
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
x = mx.array(1.0, mx.float32) * mx.array(2.0, mx.float16)
|
||||||
|
```
|
||||||
|
|
||||||
|
will result in `x` with type `mx.float32`. Similarly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
x = mx.array(1.0, mx.bfloat16) * mx.array(2.0, mx.float16)
|
||||||
|
```
|
||||||
|
|
||||||
|
will result in `x` with type `mx.float32`. A common mistake is to multiply a
|
||||||
|
half-precision array by a default-typed scalar array which up-casts everything
|
||||||
|
to `mx.float32`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Warning: x has type mx.float32
|
||||||
|
x = my_fp16_array * mx.array(2.0)
|
||||||
|
```
|
||||||
|
|
||||||
|
To multiply by a scalar while preserving the input type, use Python scalars.
|
||||||
|
Python scalars are weakly typed and have more relaxed promotion rules when
|
||||||
|
used with MLX arrays.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Ok, x has type mx.float16
|
||||||
|
x = my_fp16_array * 2.0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Operations
|
||||||
|
|
||||||
|
#### Use Fast Ops
|
||||||
|
|
||||||
|
Use `mx.fast` ops when possible:
|
||||||
|
|
||||||
|
- `mx.fast.rms_norm`
|
||||||
|
- `mx.fast.layer_norm`
|
||||||
|
- `mx.fast.rope`
|
||||||
|
- `mx.fast.scaled_dot_product_attention`
|
||||||
|
|
||||||
|
A lot of these operations take a variety of parameters so they can be used for
|
||||||
|
different variations of the function. For example, the weight and bias
|
||||||
|
parameters are optional in `mx.fast.layer_norm` so it can be used with
|
||||||
|
different permutations of inputs.
|
||||||
|
|
||||||
|
#### Precision
|
||||||
|
|
||||||
|
For operations which typically use higher precision there is usually no
|
||||||
|
need to explicitly upcast. For example, `mx.fast.rms_norm` and
|
||||||
|
`mx.fast.layer_norm` accumulate in higher precision so it's
|
||||||
|
wasteful to upcast and downcast into and out of these operations:
|
||||||
|
|
||||||
|
```python
# No need for this!
mx.fast.rms_norm(x.astype(mx.float32), w, eps).astype(x.dtype)

# This is just as good:
mx.fast.rms_norm(x, w, eps)
```
|
||||||
|
|
||||||
|
Similarly, for `mx.softmax` use `precise=True` if you want to do the softmax in
|
||||||
|
higher precision rather than explicitly casting the input and output.
|
||||||
|
|
||||||
|
#### Misc
|
||||||
|
|
||||||
|
- For vector-matrix multiplication `x @ W.T` is faster than `x @ W`, for
|
||||||
|
matrix-vector multiplication `W @ x` is faster than `W.T @ x`
|
||||||
|
- Use `mx.addmm` for `a @ b + c` (e.g. a linear layer with a bias).
|
||||||
|
- Where it makes sense, use `mx.take_along_axis` and `mx.put_along_axis`
|
||||||
|
instead of fancy indexing
|
||||||
|
- Use broadcasting instead of concatenation. For example, prefer `mx.repeat(a,
|
||||||
|
n)` over `mx.concatenate([a] * n)`
|
||||||
|
|
||||||
|
### Compile
|
||||||
|
|
||||||
|
Compiling graphs with `mx.compile` can make them run a lot faster. But there
are some sharp edges that are good to be aware of.
|
||||||
|
|
||||||
|
First, be aware of when a function will be recompiled. Recompilation is
|
||||||
|
relatively expensive and should only be done if there is sufficient work over
|
||||||
|
which to amortize the cost.
|
||||||
|
|
||||||
|
The default behavior of `mx.compile` is to do a shape-dependent compilation.
|
||||||
|
This means the function will be recompiled if the shape of any input changes.
|
||||||
|
|
||||||
|
MLX supports a shapeless compilation by passing `shapeless=True` to
|
||||||
|
`mx.compile`. It's easy to make hard-to-detect mistakes with shapeless
|
||||||
|
compilation. Make sure to read and understand the documentation and use it
|
||||||
|
with care:
|
||||||
|
https://ml-explore.github.io/mlx/build/html/usage/compile.html#shapeless-compilation
|
||||||
|
|
||||||
|
A function will also be recompiled if any constant inputs change:
|
||||||
|
|
||||||
|
```python
@mx.compile
def fun(x, scale):
    return scale * x

fun(x, 3)

# Recompiles!
fun(x, 4)
```
|
||||||
|
|
||||||
|
In this case a simple fix is to make `scale` an `mx.array`.
|
||||||
|
|
||||||
|
#### Compiling Closures
|
||||||
|
|
||||||
|
Be careful when compiling a closure where the function encloses any
|
||||||
|
`mx.array`.
|
||||||
|
|
||||||
|
```python
y = some_function()

@mx.compile
def fun(x):
    return x + y
```
|
||||||
|
|
||||||
|
Since `y` is not an input to `fun`, the compiled graph will include the entire
|
||||||
|
computation which produces `y`. Usually you only want to compute `y` one time
|
||||||
|
and re-use it in the compiled function. Either explicitly pass it as an input
|
||||||
|
to `fun` or pass it as an implicit input to `mx.compile` like so:
|
||||||
|
|
||||||
|
```python
from functools import partial

y = some_function()

@partial(mx.compile, inputs=[y])
def fun(x):
    return x + y
```
|
||||||
|
|
||||||
|
### Memory Use
|
||||||
|
|
||||||
|
#### Lazy Loading
|
||||||
|
|
||||||
|
Loading arrays from a file is lazy in MLX:
|
||||||
|
|
||||||
|
```python
|
||||||
|
weights = mx.load("model.safetensors")
|
||||||
|
```
|
||||||
|
|
||||||
|
The above function returns instantly, regardless of the file size. To actually
|
||||||
|
load the weights into memory, you can do `mx.eval(weights)`.
|
||||||
|
|
||||||
|
Assume the weights are stored on disk in 32-bit precision (i.e. `mx.float32`).
|
||||||
|
But for your model you only need 16-bit precision:
|
||||||
|
|
||||||
|
```python
|
||||||
|
weights = mx.load("model.safetensors")
|
||||||
|
mx.eval(weights)
|
||||||
|
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
|
||||||
|
```
|
||||||
|
|
||||||
|
In the above, the weights will be loaded into memory in full precision and then
|
||||||
|
cast to 16-bit. This requires memory for all the weights in 32-bit plus memory
|
||||||
|
for the weights in 16-bit.
|
||||||
|
|
||||||
|
This is much better:
|
||||||
|
|
||||||
|
```python
|
||||||
|
weights = mx.load("model.safetensors")
|
||||||
|
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
|
||||||
|
mx.eval(weights)
|
||||||
|
```
|
||||||
|
|
||||||
|
Evaluating after the cast to `mx.float16` cuts peak memory to roughly a third
of the previous version. That's because the weights are never all materialized in 32-bit at once.
|
||||||
|
Right after each weight is loaded in 32-bit precision it is cast to 16-bit.
|
||||||
|
The memory for the 32-bit weight can be reused when loading the next weight.
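The arithmetic behind this can be written out. The sketch below is a simplified model, not a measurement: it assumes the eager version holds every 32-bit weight and every 16-bit copy at its peak, while the lazy version holds every 16-bit weight plus one 32-bit tensor in flight. The parameter counts are made-up illustrative numbers.

```python
def peak_bytes(num_params: int, largest_tensor_params: int, cast_before_eval: bool) -> int:
    """Rough peak memory for loading fp32 weights and casting to fp16."""
    fp32_bytes, fp16_bytes = 4, 2
    if cast_before_eval:
        # All fp16 weights, plus one fp32 tensor being converted at a time.
        return num_params * fp16_bytes + largest_tensor_params * fp32_bytes
    # All fp32 weights are materialized first, then all fp16 copies are created.
    return num_params * (fp32_bytes + fp16_bytes)


# Hypothetical model: 1B parameters, largest single tensor 50M parameters.
eager = peak_bytes(1_000_000_000, 50_000_000, cast_before_eval=False)
lazy = peak_bytes(1_000_000_000, 50_000_000, cast_before_eval=True)
print(f"eval then cast: {eager / 1e9:.1f} GB")
print(f"cast then eval: {lazy / 1e9:.1f} GB")
```

Under these assumptions the peak falls from 1.5x the 32-bit file size to a little over 0.5x, and the gap between the two grows with model size because the single in-flight tensor is a fixed cost.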
|
||||||
|
|
||||||
|
Note, MLX is only able to lazy load from a file when it is given to `mx.load`
|
||||||
|
as a string path. Due to lifetime management issues, lazy loading from file
|
||||||
|
handles is not supported. So avoid this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
weights = mx.load(open("model.safetensors", 'rb'))
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Release Temporaries
|
||||||
|
|
||||||
|
One way to reduce memory consumption is to avoid holding
|
||||||
|
temporaries you don't need. This is a typical training loop:
|
||||||
|
|
||||||
|
```python
for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model, optimizer.state)
```
|
||||||
|
|
||||||
|
It's suboptimal since a reference to `grads` is held during the call to
|
||||||
|
`mx.eval` which keeps the respective memory from being used for any other part
|
||||||
|
of the computation.
|
||||||
|
|
||||||
|
This is better:
|
||||||
|
|
||||||
|
```python
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(x, y)
    mx.eval(model, optimizer.state)
```
|
||||||
|
|
||||||
|
In this case the reference to `grads` is released before `mx.eval` and the
|
||||||
|
memory can be reused. You can achieve the same goal using `del` as long as it's
|
||||||
|
before the call to `mx.eval`:
|
||||||
|
|
||||||
|
```python
for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    del grads
    mx.eval(model, optimizer.state)
```
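The scoping argument here is ordinary Python reference counting, and it can be observed without MLX. The sketch below uses a stand-in class (`FakeGrads` is hypothetical, not an MLX type) and `weakref` to show that a local created inside a function is gone once the function returns, while a local in the loop body is still alive at the point where `mx.eval` would run. It relies on CPython's immediate reference-counted cleanup.

```python
import weakref


class FakeGrads:
    """Stand-in for a large gradient tree; not an MLX type."""


def step_with_local_grads():
    grads = FakeGrads()
    # Return only a weak reference, so nothing keeps grads alive after return.
    return weakref.ref(grads)


def loop_body_holding_grads():
    grads = FakeGrads()
    ref = weakref.ref(grads)
    # This is the point where mx.eval would run in the first training loop.
    return ref() is not None


released_after_step = step_with_local_grads()() is None
alive_during_eval = loop_body_holding_grads()
print(released_after_step, alive_during_eval)
```

Wrapping the update in a function and calling `del grads` reach the same end: no Python name refers to the gradients when evaluation starts.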
|
||||||
|
|
||||||
|
#### Misc
|
||||||
|
|
||||||
|
- MLX will cache memory buffers of recently released arrays rather than
|
||||||
|
returning them to the system. In some cases, especially for variable shape
|
||||||
|
computations, the cache can get large. To help with this, MLX has some
|
||||||
|
functions for logging and customizing the behavior of memory allocation:
|
||||||
|
https://ml-explore.github.io/mlx/build/html/python/metal.html
|
||||||
|
|
||||||
|
### Profiling
|
||||||
|
|
||||||
|
A good first step is to check GPU utilization using, for example,
|
||||||
|
mactop: https://github.com/context-labs/mactop. If it's not pegged at close
|
||||||
|
to 100% then there is likely a non-MLX bottleneck somewhere in the program. A
|
||||||
|
common culprit is data loading or preprocessing.
|
||||||
|
|
||||||
|
If GPU utilization is good, a good next step is to figure out which operations
|
||||||
|
are taking up so much time. One way to do this is with the Metal debugger. For
|
||||||
|
that, see the documentation on profiling MLX with the Metal debugger:
|
||||||
|
https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html
|
||||||
250
README.md
@@ -18,18 +18,20 @@ uv pip install git+https://github.com/Blaizzy/mlx-video.git
|
|||||||
|
|
||||||
Supported models:
|
Supported models:
|
||||||
|
|
||||||
### LTX-2
|
- [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks
|
||||||
[LTX-2](https://huggingface.co/Lightricks/LTX-Video) is 19B parameter video generation model from Lightricks
|
- [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline)
|
||||||
|
- [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — 14B parameter T2V model (dual-model pipeline)
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Text-to-video generation with the LTX-2 19B DiT model
|
- Text-to-video generation with multiple model families
|
||||||
- Two-stage generation pipeline for high-quality output
|
- LTX-2: Two-stage pipeline with 2x spatial upscaling
|
||||||
- 2x spatial upscaling for images and videos
|
- Wan2.1/2.2: Flow-matching diffusion with classifier-free guidance
|
||||||
- Optimized for Apple Silicon using MLX
|
- Optimized for Apple Silicon using MLX
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Usage
|
## LTX-2
|
||||||
|
|
||||||
> **ℹ️ Info:** Currently, only the distilled variant is supported. Full LTX-2 feature support is coming soon.
|
> **ℹ️ Info:** Currently, only the distilled variant is supported. Full LTX-2 feature support is coming soon.
|
||||||
|
|
||||||
@@ -53,7 +55,7 @@ python -m mlx_video.generate \
|
|||||||
--output my_video.mp4
|
--output my_video.mp4
|
||||||
```
|
```
|
||||||
|
|
||||||
### CLI Options
|
### LTX-2 CLI Options
|
||||||
|
|
||||||
| Option | Default | Description |
|
| Option | Default | Description |
|
||||||
|--------|---------|-------------|
|
|--------|---------|-------------|
|
||||||
@@ -67,45 +69,229 @@ python -m mlx_video.generate \
|
|||||||
| `--save-frames` | false | Save individual frames as images |
|
| `--save-frames` | false | Save individual frames as images |
|
||||||
| `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository |
|
| `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository |
|
||||||
|
|
||||||
## How It Works
|
### How It Works (LTX-2)
|
||||||
|
|
||||||
The pipeline uses a two-stage generation process:
|
1. **Stage 1**: Generate at half resolution (e.g., 384×384) with 8 denoising steps
|
||||||
|
2. **Upsample**: 2× spatial upsampling via LatentUpsampler
|
||||||
1. **Stage 1**: Generate at half resolution (e.g., 384x384) with 8 denoising steps
|
3. **Stage 2**: Refine at full resolution (e.g., 768×768) with 3 denoising steps
|
||||||
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
|
|
||||||
3. **Stage 2**: Refine at full resolution (e.g., 768x768) with 3 denoising steps
|
|
||||||
4. **Decode**: VAE decoder converts latents to RGB video
|
4. **Decode**: VAE decoder converts latents to RGB video
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Wan2.1 / Wan2.2
|
||||||
|
|
||||||
|
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline:
|
||||||
|
|
||||||
|
| | Wan2.1 | Wan2.2 |
|
||||||
|
|---|--------|--------|
|
||||||
|
| **Pipeline** | Single model | Dual model (high-noise + low-noise) |
|
||||||
|
| **Sizes** | 1.3B, 14B | 14B |
|
||||||
|
| **Steps** | 50 | 40 |
|
||||||
|
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 (low/high noise) |
|
||||||
|
| **Shift** | 5.0 | 12.0 |
|
||||||
|
|
||||||
|
### Step 1: Download Weights
|
||||||
|
|
||||||
|
Download the original PyTorch checkpoints:
|
||||||
|
|
||||||
|
**Wan2.1 (14B)**
|
||||||
|
```bash
|
||||||
|
# From https://github.com/Wan-Video/Wan2.1 or HuggingFace
|
||||||
|
# Expected directory structure:
|
||||||
|
# wan21_checkpoint/
|
||||||
|
# ├── models_t5_umt5-xxl-enc-bf16.pth
|
||||||
|
# ├── Wan2.1_VAE.pth
|
||||||
|
# └── diffusion_pytorch_model*.safetensors # single model
|
||||||
|
```
|
||||||
|
|
||||||
|
**Wan2.1 (1.3B)** — same structure, smaller transformer weights.
|
||||||
|
|
||||||
|
**Wan2.2 (14B)**
|
||||||
|
```bash
|
||||||
|
# From https://github.com/Wan-Video/Wan2.2 or HuggingFace
|
||||||
|
# Expected directory structure:
|
||||||
|
# wan22_checkpoint/
|
||||||
|
# ├── models_t5_umt5-xxl-enc-bf16.pth
|
||||||
|
# ├── Wan2.1_VAE.pth
|
||||||
|
# ├── low_noise_model/ # safetensors
|
||||||
|
# └── high_noise_model/ # safetensors
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Convert to MLX Format
|
||||||
|
|
||||||
|
The conversion script auto-detects whether the checkpoint is Wan2.1 or Wan2.2 based on the directory structure (presence of `low_noise_model/` subdirectory).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Auto-detect version
|
||||||
|
python -m mlx_video.convert_wan \
|
||||||
|
--checkpoint-dir /path/to/wan_checkpoint \
|
||||||
|
--output-dir wan_mlx
|
||||||
|
|
||||||
|
# Explicit version
|
||||||
|
python -m mlx_video.convert_wan \
|
||||||
|
--checkpoint-dir /path/to/wan21_checkpoint \
|
||||||
|
--output-dir wan21_mlx \
|
||||||
|
--model-version 2.1
|
||||||
|
|
||||||
|
python -m mlx_video.convert_wan \
|
||||||
|
--checkpoint-dir /path/to/wan22_checkpoint \
|
||||||
|
--output-dir wan22_mlx \
|
||||||
|
--model-version 2.2
|
||||||
|
```
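The auto-detection rule described above can be sketched in a few lines. This is an illustration of the stated rule only (a `low_noise_model/` subdirectory implies Wan2.2); `detect_wan_version` is a hypothetical name, and the real converter may check more than this.

```python
import tempfile
from pathlib import Path


def detect_wan_version(checkpoint_dir: str) -> str:
    """Guess the Wan version from the checkpoint directory layout.

    Illustrative only; not the converter's actual implementation.
    """
    root = Path(checkpoint_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"not a directory: {root}")
    return "2.2" if (root / "low_noise_model").is_dir() else "2.1"


# Build two throwaway layouts that mimic the structures shown in Step 1.
with tempfile.TemporaryDirectory() as tmp:
    wan21 = Path(tmp) / "wan21_checkpoint"
    wan22 = Path(tmp) / "wan22_checkpoint"
    wan21.mkdir()
    (wan22 / "low_noise_model").mkdir(parents=True)
    (wan22 / "high_noise_model").mkdir()
    detected = (detect_wan_version(str(wan21)), detect_wan_version(str(wan22)))

print(detected)
```

Passing `--model-version` explicitly avoids relying on the layout, which helps if a checkpoint has been reorganized on disk.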
|
||||||
|
|
||||||
|
#### Conversion Options
|
||||||
|
|
||||||
|
| Option | Default | Description |
|
||||||
|
|--------|---------|-------------|
|
||||||
|
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
|
||||||
|
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
|
||||||
|
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
|
||||||
|
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
|
||||||
|
| `--quantize` | off | Quantize transformer weights for reduced memory |
|
||||||
|
| `--bits` | `4` | Quantization bits: `4` or `8` |
|
||||||
|
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
|
||||||
|
|
||||||
|
The converter produces:
|
||||||
|
```
|
||||||
|
wan_mlx/
|
||||||
|
├── config.json # Model configuration
|
||||||
|
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
|
||||||
|
├── vae.safetensors # 3D VAE decoder
|
||||||
|
├── model.safetensors # (Wan2.1) Single transformer
|
||||||
|
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
|
||||||
|
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3: Generate Video
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0)
|
||||||
|
python -m mlx_video.generate_wan \
|
||||||
|
--model-dir wan21_mlx \
|
||||||
|
--prompt "A cat playing piano in a cozy room"
|
||||||
|
|
||||||
|
# Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0)
|
||||||
|
python -m mlx_video.generate_wan \
|
||||||
|
--model-dir wan22_mlx \
|
||||||
|
--prompt "A cat playing piano in a cozy room"
|
||||||
|
```
|
||||||
|
|
||||||
|
With custom settings:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m mlx_video.generate_wan \
|
||||||
|
--model-dir wan21_mlx \
|
||||||
|
--prompt "Ocean waves at sunset, cinematic, 4K" \
|
||||||
|
--negative-prompt "blurry, low quality" \
|
||||||
|
--width 1280 \
|
||||||
|
--height 720 \
|
||||||
|
--num-frames 81 \
|
||||||
|
--steps 50 \
|
||||||
|
--guide-scale 5.0 \
|
||||||
|
--shift 5.0 \
|
||||||
|
--seed 42 \
|
||||||
|
--output-path my_video.mp4
|
||||||
|
```
|
||||||
|
|
||||||
|
The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model). You can also override any parameter via CLI flags.
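One way to picture how config defaults and CLI overrides combine is the sketch below. The default values come from the comparison table earlier in this section; the names `WanSettings` and `resolve_settings` are hypothetical, and the actual `config.json` keys are not shown in this document, so the sketch keys on a plain version string.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class WanSettings:
    steps: int
    shift: float
    guide_scale: Tuple[float, ...]  # one value, or (low_noise, high_noise)
    dual_model: bool


# Defaults taken from the Wan2.1 / Wan2.2 comparison table.
DEFAULTS = {
    "2.1": WanSettings(steps=50, shift=5.0, guide_scale=(5.0,), dual_model=False),
    "2.2": WanSettings(steps=40, shift=12.0, guide_scale=(3.0, 4.0), dual_model=True),
}


def resolve_settings(
    version: str,
    steps: Optional[int] = None,
    shift: Optional[float] = None,
    guide_scale: Optional[Tuple[float, ...]] = None,
) -> WanSettings:
    """Start from the version defaults and apply only the flags that were given."""
    given = {"steps": steps, "shift": shift, "guide_scale": guide_scale}
    overrides = {key: value for key, value in given.items() if value is not None}
    return replace(DEFAULTS[version], **overrides)


settings = resolve_settings("2.2", steps=30)
print(settings)
```

Only `steps` changes in this call; the shift, the guidance pair, and the dual-model mode still come from the Wan2.2 defaults.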

#### Generation CLI Options

| Option | Default | Description |
|--------|---------|-------------|
| `--model-dir` | (required) | Path to converted MLX model directory |
| `--prompt` | (required) | Text description of the video |
| `--negative-prompt` | `""` | Negative prompt for guidance |
| `--width` | 1280 | Video width |
| `--height` | 720 | Video height |
| `--num-frames` | 81 | Number of frames (must be 4n+1) |
| `--steps` | from config | Number of diffusion steps |
| `--guide-scale` | from config | Guidance scale: float or `low,high` pair |
| `--shift` | from config | Noise schedule shift |
| `--seed` | -1 (random) | Random seed for reproducibility |
| `--output-path` | `output.mp4` | Output video path |
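
Two of these options have constraints worth spelling out: `--guide-scale` accepts either one float or a `low,high` pair, and `--num-frames` must have the form 4n+1. The helpers below are a sketch of how such values could be validated; `parse_guide_scale` and `check_num_frames` are hypothetical names and may not match how `generate_wan.py` handles them.

```python
def parse_guide_scale(value: str):
    """Parse '5.0' into a float, or '3.0,4.0' into a (low, high) tuple."""
    parts = [part.strip() for part in value.split(",")]
    if len(parts) == 1:
        return float(parts[0])
    if len(parts) == 2:
        return (float(parts[0]), float(parts[1]))
    raise ValueError(f"expected a float or 'low,high', got {value!r}")


def check_num_frames(num_frames: int) -> int:
    """Return num_frames unchanged if it has the form 4n+1, else raise."""
    if num_frames < 1 or (num_frames - 1) % 4 != 0:
        raise ValueError(f"num_frames must be 4n+1 (e.g. 17, 49, 81), got {num_frames}")
    return num_frames
```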

### Quantization (Reduced Memory)

Quantize the transformer weights to reduce memory usage by ~3.4x. This is especially useful for the 14B model or memory-constrained devices:

```bash
# Convert with 4-bit quantization
python -m mlx_video.convert_wan \
    --checkpoint-dir /path/to/Wan2.1-T2V-1.3B \
    --output-dir wan21_mlx_q4 \
    --quantize --bits 4 --group-size 64

# Generate with quantized model (auto-detected from config.json)
python -m mlx_video.generate_wan \
    --model-dir wan21_mlx_q4 \
    --prompt "A cat playing piano"
```

**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2). That is 10 linear layers in each of the N transformer blocks, about 95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.

| Model | BF16 Size | 4-bit Size | Notes |
|-------|-----------|------------|-------|
| 1.3B | 2.7 GB | 799 MB | ~3.4x smaller |
| 14B | ~28 GB | ~8 GB | Enables running on 16 GB devices |

> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
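
The size figures can be roughly cross-checked with a back-of-the-envelope estimate. The sketch below assumes group-wise affine quantization that stores one 16-bit scale and one 16-bit bias per group alongside the packed weights; the function names are illustrative and the quantized fraction is an input, not a measured value.

```python
def bits_per_weight(bits: int, group_size: int, meta_bits: int = 16) -> float:
    """Effective storage per quantized weight, including per-group scale and bias."""
    return bits + 2 * meta_bits / group_size


def compression_ratio(
    quantized_fraction: float,
    bits: int = 4,
    group_size: int = 64,
    base_bits: int = 16,
) -> float:
    """Estimated size reduction when only part of the model is quantized."""
    avg_bits = (
        quantized_fraction * bits_per_weight(bits, group_size)
        + (1.0 - quantized_fraction) * base_bits
    )
    return base_bits / avg_bits
```

With 4 bits and a group size of 64 this gives 4.5 bits per quantized weight. If 95% of the weights are quantized, the estimate comes out at about 3.15x, in the same range as the ~3.4x in the table; the exact figure depends on the true quantized fraction.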

### Wan Model Specifications

**Transformer (14B)**
- 40 layers, 40 attention heads, dim 5120, head dim 128
- 3-way factorized RoPE (temporal + spatial)
- 14.29B parameters

**Transformer (1.3B, Wan2.1 only)**
- 30 layers, 12 attention heads, dim 1536, head dim 128
- Same architecture, smaller scale

**Text Encoder**: UMT5-XXL (5.68B parameters)
- 24 layers, 64 heads, dim 4096, vocab 256K

**VAE**: 3D causal convolution decoder (72.6M parameters)
- Latent channels: 16
- Compression: 4× temporal, 8× spatial
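
The compression factors explain the 4n+1 frame constraint. The sketch below assumes that a causal VAE keeps the first frame and compresses each following group of 4 frames into one latent frame, and that the transformer patchifies latents with a (1, 2, 2) patch as noted in `convert_wan.py`; both helper names are illustrative.

```python
def latent_shape(num_frames: int, height: int, width: int,
                 z_dim: int = 16, stride=(4, 8, 8)):
    """Latent tensor shape (C, T, H, W) for a given video size."""
    st, sh, sw = stride
    if (num_frames - 1) % st != 0:
        raise ValueError("num_frames must be a multiple of the temporal stride plus 1")
    if height % sh != 0 or width % sw != 0:
        raise ValueError("height and width must be divisible by the spatial stride")
    # First frame maps to one latent frame; every further `st` frames add one more.
    return (z_dim, (num_frames - 1) // st + 1, height // sh, width // sw)


def token_count(latent, patch=(1, 2, 2)) -> int:
    """Number of transformer tokens after patchifying the latent."""
    _, t, h, w = latent
    return (t // patch[0]) * (h // patch[1]) * (w // patch[2])
```

For the default 81 frames at 1280×720 this gives a latent of shape (16, 21, 90, 160) and 75,600 tokens per sample, which is why resolution and frame count dominate memory use and runtime.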
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- macOS with Apple Silicon
|
- macOS with Apple Silicon
|
||||||
- Python >= 3.11
|
- Python >= 3.11
|
||||||
- MLX >= 0.22.0
|
- MLX >= 0.22.0
|
||||||
|
- For weight conversion: PyTorch (`pip install torch`)
|
||||||
## Model Specifications
|
|
||||||
|
|
||||||
- **Transformer**: 48 layers, 32 attention heads, 128 dim per head
|
|
||||||
- **Latent channels**: 128
|
|
||||||
- **Text encoder**: Gemma 3 with 3840-dim output
|
|
||||||
- **RoPE**: Split mode with double precision
|
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
mlx_video/
|
mlx_video/
|
||||||
├── generate.py # Video generation pipeline
|
├── generate.py # LTX-2 generation pipeline
|
||||||
├── convert.py # Weight conversion (PyTorch -> MLX)
|
├── generate_wan.py # Wan2.1/2.2 generation pipeline
|
||||||
├── postprocess.py # Video post-processing utilities
|
├── convert.py # LTX-2 weight conversion
|
||||||
├── utils.py # Helper functions
|
├── convert_wan.py # Wan weight conversion (PyTorch → MLX)
|
||||||
|
├── postprocess.py # Video post-processing utilities
|
||||||
|
├── utils.py # Helper functions
|
||||||
└── models/
|
└── models/
|
||||||
└── ltx/
|
├── ltx/ # LTX-2 model
|
||||||
├── ltx.py # Main LTXModel (DiT transformer)
|
│ ├── ltx.py # DiT transformer
|
||||||
├── config.py # Model configuration
|
│ ├── config.py # Configuration
|
||||||
├── transformer.py # Transformer blocks
|
│ ├── transformer.py # Transformer blocks
|
||||||
├── attention.py # Multi-head attention with RoPE
|
│ ├── attention.py # Multi-head attention with RoPE
|
||||||
├── text_encoder.py # Text encoder
|
│ ├── text_encoder.py # Gemma 3 text encoder
|
||||||
├── upsampler.py # 2x spatial upsampler
|
│ ├── upsampler.py # 2x spatial upsampler
|
||||||
└── video_vae/ # VAE encoder/decoder
|
│ └── video_vae/ # VAE encoder/decoder
|
||||||
|
└── wan/ # Wan2.1/2.2 model
|
||||||
|
├── config.py # Configuration (2.1 & 2.2 presets)
|
||||||
|
├── model.py # WanModel (DiT transformer)
|
||||||
|
├── transformer.py # Attention blocks with 6-element modulation
|
||||||
|
├── attention.py # Self/cross attention with QK-norm
|
||||||
|
├── rope.py # 3-way factorized RoPE
|
||||||
|
├── text_encoder.py # T5 UMT5-XXL encoder
|
||||||
|
├── vae.py # 3D causal VAE decoder
|
||||||
|
└── scheduler.py # Flow-matching Euler scheduler
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||||
|
from mlx_video.models.wan import WanModel, WanModelConfig
|
||||||
from mlx_video.convert import load_transformer_weights, load_vae_weights
|
from mlx_video.convert import load_transformer_weights, load_vae_weights
|
||||||
import os
|
import os
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LTXModel",
|
"LTXModel",
|
||||||
"LTXModelConfig",
|
"LTXModelConfig",
|
||||||
|
"WanModel",
|
||||||
|
"WanModelConfig",
|
||||||
"load_transformer_weights",
|
"load_transformer_weights",
|
||||||
"load_vae_weights",
|
"load_vae_weights",
|
||||||
]
|
]
|
||||||
|
|||||||
556
mlx_video/convert_wan.py
Normal file
@@ -0,0 +1,556 @@
"""Weight conversion for Wan2.1 and Wan2.2 models (PyTorch -> MLX)."""

import logging
from pathlib import Path
from typing import Dict

import mlx.core as mx
import mlx.utils
import numpy as np


def load_torch_weights(path: str) -> Dict[str, mx.array]:
    """Load PyTorch .pth weights and convert to MLX arrays.

    Args:
        path: Path to .pth file

    Returns:
        Dictionary of MLX arrays
    """
    try:
        import torch
    except ImportError as exc:
        # Chain the original error so the root cause stays visible.
        raise ImportError(
            "PyTorch is required to load .pth weights: pip install torch"
        ) from exc

    logging.info(f"Loading weights from {path}")
    state_dict = torch.load(path, map_location="cpu", weights_only=True)

    weights = {}
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            # Upcast to float32 first: NumPy has no bfloat16 dtype.
            np_val = value.detach().float().numpy()
            weights[key] = mx.array(np_val)

    return weights


def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
    """Load safetensors weights as MLX arrays.

    Args:
        path: Path to directory with safetensors files or single file

    Returns:
        Dictionary of MLX arrays
    """
    path = Path(path)
    weights = {}
    if path.is_file():
        weights = mx.load(str(path))
    elif path.is_dir():
        for sf in sorted(path.glob("*.safetensors")):
            weights.update(mx.load(str(sf)))
    return weights


def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert Wan2.2 transformer weight keys to MLX model structure.

    Wan2.2 keys follow the pattern:
        patch_embedding.weight/bias
        text_embedding.{0,2}.weight/bias
        time_embedding.{0,2}.weight/bias
        time_projection.1.weight/bias
        blocks.{i}.norm1.weight
        blocks.{i}.self_attn.{q,k,v,o}.weight/bias
        blocks.{i}.self_attn.norm_q.weight
        blocks.{i}.self_attn.norm_k.weight
        blocks.{i}.norm3.weight/bias (if cross_attn_norm)
        blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
        blocks.{i}.cross_attn.norm_q.weight
        blocks.{i}.cross_attn.norm_k.weight
        blocks.{i}.norm2.weight
        blocks.{i}.ffn.{0,2}.weight/bias
        blocks.{i}.modulation
        head.norm.weight
        head.head.weight/bias
        head.modulation
        freqs (buffer)

    MLX model uses:
        patch_embedding_proj.weight/bias (after patchify reshape)
        text_embedding_0.weight/bias, text_embedding_1.weight/bias
        time_embedding_0.weight/bias, time_embedding_1.weight/bias
        time_projection.weight/bias
        blocks.{i}.norm1.weight
        blocks.{i}.self_attn.{q,k,v,o}.weight/bias
        etc.
    """
    sanitized = {}

    for key, value in weights.items():
        new_key = key

        # Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
        # MLX Linear expects [O, I*D*H*W] after we flatten in patchify
        if key == "patch_embedding.weight":
            # Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
            value = value.reshape(value.shape[0], -1)
            new_key = "patch_embedding_proj.weight"
            sanitized[new_key] = value
            continue
        if key == "patch_embedding.bias":
            new_key = "patch_embedding_proj.bias"
            sanitized[new_key] = value
            continue

        # Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
        if key.startswith("text_embedding.0."):
            new_key = key.replace("text_embedding.0.", "text_embedding_0.")
            sanitized[new_key] = value
            continue
        if key.startswith("text_embedding.2."):
            new_key = key.replace("text_embedding.2.", "text_embedding_1.")
            sanitized[new_key] = value
            continue

        # Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
        if key.startswith("time_embedding.0."):
            new_key = key.replace("time_embedding.0.", "time_embedding_0.")
            sanitized[new_key] = value
            continue
        if key.startswith("time_embedding.2."):
            new_key = key.replace("time_embedding.2.", "time_embedding_1.")
            sanitized[new_key] = value
            continue

        # Time projection Sequential: 0=SiLU(no params), 1=Linear
        if key.startswith("time_projection.1."):
            new_key = key.replace("time_projection.1.", "time_projection.")
            sanitized[new_key] = value
            continue

        # FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
        new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
        new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")

        # Skip the freqs buffer (we compute it in the model)
        if key == "freqs":
            continue

        sanitized[new_key] = value

    return sanitized


def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.

    Wan2.2 T5 keys:
        token_embedding.weight
        pos_embedding.embedding.weight (if shared_pos)
        blocks.{i}.norm1.weight
        blocks.{i}.attn.{q,k,v,o}.weight
        blocks.{i}.norm2.weight
        blocks.{i}.ffn.gate.0.weight (gate linear)
        blocks.{i}.ffn.fc1.weight
        blocks.{i}.ffn.fc2.weight
        blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
        norm.weight

    MLX T5Encoder structure:
        token_embedding.weight
        blocks.{i}.norm1.weight
        blocks.{i}.attn.{q,k,v,o}.weight
        blocks.{i}.norm2.weight
        blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
        blocks.{i}.ffn.fc1.weight
        blocks.{i}.ffn.fc2.weight
        blocks.{i}.pos_embedding.embedding.weight
        norm.weight
    """
    sanitized = {}

    for key, value in weights.items():
        new_key = key

        # Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
        new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")

        sanitized[new_key] = value

    return sanitized


def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Convert Wan2.2 VAE weight keys to MLX WanVAE structure.

    Handles Conv3d and Conv2d weight transpositions for MLX format.
    """
    sanitized = {}

    for key, value in weights.items():
        new_key = key

        # Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
        if "weight" in key and value.ndim == 5:
            value = mx.transpose(value, (0, 2, 3, 4, 1))

        # Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
        if "weight" in key and value.ndim == 4:
            value = mx.transpose(value, (0, 2, 3, 1))

        # Map decoder keys to MLX decoder structure
        # Wan2.2 uses encoder/decoder with downsamples/upsamples
        # Need to adapt naming for our simplified structure

        sanitized[new_key] = value

    return sanitized


def convert_wan_checkpoint(
    checkpoint_dir: str,
    output_dir: str,
    dtype: str = "bfloat16",
    model_version: str = "auto",
    quantize: bool = False,
    bits: int = 4,
    group_size: int = 64,
):
    """Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.

    Wan2.2 expected structure:
        checkpoint_dir/
            models_t5_umt5-xxl-enc-bf16.pth
            Wan2.1_VAE.pth
            low_noise_model/ (safetensors)
            high_noise_model/ (safetensors)

    Wan2.1 expected structure:
        checkpoint_dir/
            models_t5_umt5-xxl-enc-bf16.pth
            Wan2.1_VAE.pth
            diffusion_pytorch_model*.safetensors (single model)

    Args:
        checkpoint_dir: Path to Wan checkpoint directory
        output_dir: Path to output MLX model directory
        dtype: Target dtype
        model_version: "2.1", "2.2", or "auto" (detect from directory)
        quantize: Whether to quantize the transformer weights
        bits: Quantization bits (4 or 8)
        group_size: Quantization group size (32, 64, or 128)
    """
    import json

    checkpoint_dir = Path(checkpoint_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    dtype_map = {
        "float16": mx.float16,
        "float32": mx.float32,
        "bfloat16": mx.bfloat16,
    }
    target_dtype = dtype_map.get(dtype, mx.bfloat16)

    # Auto-detect version
    if model_version == "auto":
        if (checkpoint_dir / "low_noise_model").exists():
            model_version = "2.2"
        elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
            model_version = "2.2"
        else:
            model_version = "2.1"
        print(f"Auto-detected Wan{model_version} checkpoint")

    is_dual = (checkpoint_dir / "low_noise_model").exists()

    if is_dual:
        # Wan2.2: Convert dual transformer models
        low_noise_path = checkpoint_dir / "low_noise_model"
        if low_noise_path.exists():
            print("Converting low-noise transformer...")
            weights = load_safetensors_weights(str(low_noise_path))
            weights = sanitize_wan_transformer_weights(weights)
            weights = {k: v.astype(target_dtype) for k, v in weights.items()}
            out_path = output_dir / "low_noise_model.safetensors"
            mx.save_safetensors(str(out_path), weights)
            print(f" Saved {len(weights)} weight tensors to {out_path}")

        high_noise_path = checkpoint_dir / "high_noise_model"
        if high_noise_path.exists():
            print("Converting high-noise transformer...")
            weights = load_safetensors_weights(str(high_noise_path))
            weights = sanitize_wan_transformer_weights(weights)
            weights = {k: v.astype(target_dtype) for k, v in weights.items()}
            out_path = output_dir / "high_noise_model.safetensors"
            mx.save_safetensors(str(out_path), weights)
            print(f" Saved {len(weights)} weight tensors to {out_path}")
    else:
        # Wan2.1: Convert single transformer model
        # Try safetensors in the checkpoint dir itself
        print("Converting transformer (single model)...")
        weights = load_safetensors_weights(str(checkpoint_dir))
        if not weights:
            # Fallback: look for .pth files
            for pth in sorted(checkpoint_dir.glob("*.pth")):
                if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
                    print(f" Loading from {pth.name}...")
                    weights = load_torch_weights(str(pth))
                    break
        if weights:
            weights = sanitize_wan_transformer_weights(weights)
            weights = {k: v.astype(target_dtype) for k, v in weights.items()}
            out_path = output_dir / "model.safetensors"
            mx.save_safetensors(str(out_path), weights)
            print(f" Saved {len(weights)} weight tensors to {out_path}")
        else:
            print(" Warning: No transformer weights found!")

    # Save config: detect model size from source config.json or transformer weights
    from mlx_video.models.wan.config import WanModelConfig

    def _detect_config():
        """Detect config from source config.json or transformer weight shapes."""
        if is_dual:
            return WanModelConfig.wan22_t2v_14b()

        # Try reading source config.json first (most reliable)
        src_cfg_path = checkpoint_dir / "config.json"
        src_config = None
        if src_cfg_path.exists():
            with open(src_cfg_path) as f:
                src_config = json.load(f)

        if src_config and "dim" in src_config:
            src_dim = src_config.get("dim", 5120)
            src_in_dim = src_config.get("in_dim", 16)
            src_out_dim = src_config.get("out_dim", 16)
            src_ffn_dim = src_config.get("ffn_dim", 13824)
            src_num_heads = src_config.get("num_heads", 40)
            src_num_layers = src_config.get("num_layers", 40)
            src_model_type = src_config.get("model_type", "t2v")
            src_text_len = src_config.get("text_len", 512)

            print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
                  f"heads={src_num_heads}, type={src_model_type}")

            is_22 = model_version == "2.2"

            # Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
            vae_z = 48 if is_22 else 16
            vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
            fps = 24 if is_22 else 16

            return WanModelConfig(
                model_type=src_model_type,
                model_version=model_version,
                dim=src_dim,
                ffn_dim=src_ffn_dim,
                in_dim=src_in_dim,
                out_dim=src_out_dim,
                num_heads=src_num_heads,
                num_layers=src_num_layers,
                text_len=src_text_len,
                vae_z_dim=vae_z,
                vae_stride=vae_s,
                dual_model=False,
                boundary=0.0,
                sample_shift=5.0,
                sample_steps=50,
                sample_guide_scale=5.0,
                sample_fps=fps,
            )

        # Fallback: detect from saved transformer weight shapes
        saved_model = output_dir / "model.safetensors"
        if saved_model.exists():
            det_weights = mx.load(str(saved_model))
            dim = None
            for k, v in det_weights.items():
                if "patch_embedding_proj.weight" in k:
                    dim = v.shape[0]
                    break
            del det_weights
            if dim is not None and dim <= 2048:
                print(f" Auto-detected 1.3B model (dim={dim})")
                return WanModelConfig.wan21_t2v_1_3b()

        return WanModelConfig.wan21_t2v_14b()

    config = _detect_config()
    config_path = output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config.to_dict(), f, indent=2)
    print(f" Saved config to {config_path}")

    # Convert T5 encoder
    t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
    if t5_path.exists():
        print("Converting T5 encoder...")
        weights = load_torch_weights(str(t5_path))
        weights = sanitize_wan_t5_weights(weights)
        weights = {k: v.astype(target_dtype) for k, v in weights.items()}
        out_path = output_dir / "t5_encoder.safetensors"
        mx.save_safetensors(str(out_path), weights)
        print(f" Saved {len(weights)} weight tensors to {out_path}")

    # Convert VAE (check both naming conventions)
    vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
    is_wan22_vae = False
    if not vae_path.exists():
        vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
        is_wan22_vae = True
    if vae_path.exists():
        print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
        weights = load_torch_weights(str(vae_path))
        if is_wan22_vae:
            from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
            weights = sanitize_wan22_vae_weights(weights)
        else:
            weights = sanitize_wan_vae_weights(weights)
        weights = {k: v.astype(target_dtype) for k, v in weights.items()}
        out_path = output_dir / "vae.safetensors"
        mx.save_safetensors(str(out_path), weights)
        print(f" Saved {len(weights)} weight tensors to {out_path}")

    # Quantize transformer weights if requested
    if quantize:
        print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
        _quantize_saved_model(output_dir, config, is_dual, bits, group_size)

    print(f"\nConversion complete! Output: {output_dir}")


def _quantize_predicate(path: str, module) -> bool:
    """Return True for layers that should be quantized.

    Targets heavyweight Linear layers in attention and FFN blocks.
    Skips embeddings, norms, head, and modulation (small, precision-sensitive).
    """
    if not hasattr(module, "to_quantized"):
        return False
    # Quantize attention Q/K/V/O and FFN fc1/fc2
    quantize_patterns = (
        ".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
        ".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
        ".ffn.fc1", ".ffn.fc2",
    )
    return any(path.endswith(p) for p in quantize_patterns)


def _quantize_saved_model(
    output_dir: Path,
    config,
    is_dual: bool,
    bits: int,
    group_size: int,
):
    """Load saved bf16 model, quantize, and re-save."""
    import json

    import mlx.nn as nn

    from mlx_video.models.wan.model import WanModel

    model_files = []
    if is_dual:
        for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
            p = output_dir / name
            if p.exists():
                model_files.append(p)
    else:
        p = output_dir / "model.safetensors"
        if p.exists():
            model_files.append(p)

    for model_path in model_files:
        print(f" Quantizing {model_path.name}...")
        model = WanModel(config)
        weights = mx.load(str(model_path))
        model.load_weights(list(weights.items()), strict=False)

        # Apply quantization to targeted layers
        nn.quantize(
            model,
            group_size=group_size,
            bits=bits,
            class_predicate=_quantize_predicate,
        )

        # Save quantized weights
        weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
        mx.save_safetensors(str(model_path), weights_dict)
        n_quantized = sum(1 for k in weights_dict if ".scales" in k)
        print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")

    # Update config.json with quantization metadata
    config_path = output_dir / "config.json"
    with open(config_path) as f:
        cfg = json.load(f)
    cfg["quantization"] = {
        "group_size": group_size,
        "bits": bits,
    }
    with open(config_path, "w") as f:
        json.dump(cfg, f, indent=2)
    print(" Updated config.json with quantization metadata")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        required=True,
        help="Path to Wan checkpoint directory",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="wan_mlx_model",
        help="Output path for MLX model",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float16", "float32", "bfloat16"],
        default="bfloat16",
        help="Target dtype",
    )
    parser.add_argument(
        "--model-version",
        type=str,
        choices=["2.1", "2.2", "auto"],
        default="auto",
        help="Wan model version (auto-detect by default)",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Quantize transformer weights to reduce memory usage",
    )
    parser.add_argument(
        "--bits",
        type=int,
        choices=[4, 8],
        default=4,
        help="Quantization bits (default: 4)",
    )
    parser.add_argument(
        "--group-size",
        type=int,
        choices=[32, 64, 128],
        default=64,
        help="Quantization group size (default: 64)",
    )
    args = parser.parse_args()
    convert_wan_checkpoint(
        args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
        quantize=args.quantize, bits=args.bits, group_size=args.group_size,
    )
512
mlx_video/generate_wan.py
Normal file
@@ -0,0 +1,512 @@
"""Wan2.1/2.2 Text-to-Video generation pipeline for MLX."""

import argparse
import gc
import math
import random
import sys
import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm


class Colors:
    CYAN = "\033[96m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    MAGENTA = "\033[95m"
    BOLD = "\033[1m"
    DIM = "\033[2m"
    RESET = "\033[0m"


def load_wan_model(model_path: Path, config, quantization: dict | None = None):
    """Load and initialize WanModel, with optional quantization support.

    Args:
        model_path: Path to model safetensors file
        config: WanModelConfig
        quantization: Optional dict with 'bits' and 'group_size' keys.
            If provided, creates QuantizedLinear stubs before loading.
    """
    from mlx_video.models.wan.model import WanModel

    model = WanModel(config)

    if quantization:
        from mlx_video.convert_wan import _quantize_predicate

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=_quantize_predicate,
        )

    weights = mx.load(str(model_path))
    model.load_weights(list(weights.items()), strict=False)
    mx.eval(model.parameters())
    return model


def load_t5_encoder(model_path: Path, config):
    """Load T5 text encoder."""
    from mlx_video.models.wan.text_encoder import T5Encoder

    encoder = T5Encoder(
        vocab_size=config.t5_vocab_size,
        dim=config.t5_dim,
        dim_attn=config.t5_dim_attn,
        dim_ffn=config.t5_dim_ffn,
        num_heads=config.t5_num_heads,
        num_layers=config.t5_num_layers,
        num_buckets=config.t5_num_buckets,
        shared_pos=False,
    )
    weights = mx.load(str(model_path))
    encoder.load_weights(list(weights.items()))
    mx.eval(encoder.parameters())
    return encoder


def load_vae_decoder(model_path: Path, config=None):
    """Load VAE decoder (skips encoder weights with strict=False).

    For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
    For Wan2.1 (vae_z_dim=16), uses WanVAE.
    """
    is_wan22 = config is not None and config.vae_z_dim == 48

    if is_wan22:
        from mlx_video.models.wan.vae22 import Wan22VAEDecoder
        vae = Wan22VAEDecoder(z_dim=48)
    else:
        from mlx_video.models.wan.vae import WanVAE
        vae = WanVAE(z_dim=16)

    weights = mx.load(str(model_path))
    vae.load_weights(list(weights.items()), strict=False)
    mx.eval(vae.parameters())
    return vae


def encode_text(
    encoder,
    tokenizer,
    prompt: str,
    text_len: int = 512,
) -> mx.array:
    """Encode text prompt using T5 encoder.

    Args:
        encoder: T5Encoder model
        tokenizer: HuggingFace tokenizer
        prompt: Text prompt
        text_len: Maximum text length

    Returns:
        Text embeddings [L, dim]
    """
    tokens = tokenizer(
        prompt,
        max_length=text_len,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    ids = mx.array(tokens["input_ids"])
    mask = mx.array(tokens["attention_mask"])

    embeddings = encoder(ids, mask=mask)

    # Return only non-padding tokens
    seq_len = int(mask.sum().item())
    return embeddings[0, :seq_len]
|
||||||
|
|
||||||
|
|
||||||
|
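The padding trim at the end of `encode_text` can be illustrated without MLX. The following is a minimal sketch using plain Python lists, assuming the tokenizer pads on the right so that all real tokens come first; `trim_padding` is a hypothetical helper for illustration and is not part of `mlx_video`.

```python
def trim_padding(embeddings, attention_mask):
    """Keep only the embedding rows that belong to real (non-padding) tokens.

    Assumes right padding: every 1 in the mask precedes every 0.
    """
    seq_len = sum(attention_mask)  # number of real tokens
    return embeddings[:seq_len]


# Toy example: 5 token slots, 3 real tokens, 2-dimensional embeddings.
embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0], [0.0, 0.0]]
mask = [1, 1, 1, 0, 0]
trimmed = trim_padding(embeddings, mask)
print(len(trimmed))
```

If the tokenizer padded on the left instead, summing the mask would still give the right count but the slice would keep the wrong rows, so the right-padding assumption matters.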
def generate_video(
    model_dir: str,
    prompt: str,
    negative_prompt: str = "",
    width: int = 1280,
    height: int = 720,
    num_frames: int = 81,
    steps: int | None = None,
    guide_scale: str | float | tuple | None = None,
    shift: float | None = None,
    seed: int = -1,
    output_path: str = "output.mp4",
):
    """Generate video using Wan T2V pipeline (supports 2.1 and 2.2).

    Args:
        model_dir: Path to converted MLX model directory
        prompt: Text prompt
        negative_prompt: Negative prompt
        width: Video width
        height: Video height
        num_frames: Number of frames (must be 4n+1)
        steps: Number of diffusion steps (None = use config default)
        guide_scale: Guidance scale: a float, a "low,high" string, or a
            (low, high) tuple for dual models (None = use config default)
        shift: Noise schedule shift (None = use config default)
        seed: Random seed (-1 for random)
        output_path: Output video path
    """
    import json

    from mlx_video.models.wan.config import WanModelConfig
    from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

    model_dir = Path(model_dir)

    # Load config from model dir if available, otherwise auto-detect
    config_path = model_dir / "config.json"
    quantization = None
    if config_path.exists():
        with open(config_path) as f:
            config_dict = json.load(f)
        # Extract quantization config (not a model config field)
        quantization = config_dict.pop("quantization", None)
        # Handle tuple fields stored as lists in JSON
        for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
            if key in config_dict and isinstance(config_dict[key], list):
                config_dict[key] = tuple(config_dict[key])
        config = WanModelConfig(**{
            k: v for k, v in config_dict.items()
            if k in WanModelConfig.__dataclass_fields__
        })
    else:
        # Auto-detect: dual model files → 2.2, single model → 2.1
        if (model_dir / "low_noise_model.safetensors").exists():
            config = WanModelConfig.wan22_t2v_14b()
        else:
            # Detect 1.3B vs 14B from weight shapes
            model_path = model_dir / "model.safetensors"
            if model_path.exists():
                probe = mx.load(str(model_path), return_metadata=False)
                for k, v in probe.items():
                    if "patch_embedding_proj.weight" in k:
                        dim = v.shape[0]
                        if dim <= 2048:
                            config = WanModelConfig.wan21_t2v_1_3b()
                        else:
                            config = WanModelConfig.wan21_t2v_14b()
                        break
                else:
                    # No patch embedding weight found: fall back to 14B
                    config = WanModelConfig.wan21_t2v_14b()
                del probe
            else:
                config = WanModelConfig.wan21_t2v_14b()

    is_dual = config.dual_model

    # Validate config against actual weights (handles mismatched config.json)
    if not is_dual:
        model_path = model_dir / "model.safetensors"
        if model_path.exists():
            probe = mx.load(str(model_path), return_metadata=False)
            for k, v in probe.items():
                if "patch_embedding_proj.weight" in k:
                    actual_dim = v.shape[0]
                    if actual_dim != config.dim:
                        print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
                        if actual_dim <= 2048:
                            config = WanModelConfig.wan21_t2v_1_3b()
                        else:
                            config = WanModelConfig.wan21_t2v_14b()
                    break
            del probe

    # Auto-correct Wan2.2 VAE params from stale configs
    if config.in_dim == 48 and config.vae_z_dim != 48:
        print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
        config = WanModelConfig(**{
            **{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
            "vae_z_dim": 48,
            "vae_stride": (4, 16, 16),
            "sample_fps": 24,
        })

    # Apply defaults from config if not overridden
    if steps is None:
        steps = config.sample_steps
    if shift is None:
        shift = config.sample_shift
    if guide_scale is None:
        guide_scale = config.sample_guide_scale

    # Normalize guide_scale
    if isinstance(guide_scale, (int, float)):
        guide_scale = float(guide_scale)
    elif isinstance(guide_scale, str):
        parts = [float(x) for x in guide_scale.split(",")]
        guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
    # Dual-model runs index guide_scale as (low, high), so expand a single
    # value into a pair instead of failing with a TypeError in the loop
    if is_dual and isinstance(guide_scale, float):
        guide_scale = (guide_scale, guide_scale)

    # Validate frame count
    assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"

    version_str = f"Wan{config.model_version}"
    mode_str = "dual-model" if is_dual else "single-model"
    print(f"{Colors.CYAN}{'='*60}")
    print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
    print(f"{'='*60}{Colors.RESET}")
    print(f"{Colors.DIM} Prompt: {prompt}")
    print(f" Size: {width}x{height}, Frames: {num_frames}")
    print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}")
    print(f"{Colors.RESET}")

    # Seed
    if seed < 0:
        seed = random.randint(0, 2**32 - 1)
    mx.random.seed(seed)
    np.random.seed(seed)
    print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")

    # Compute target latent shape
    vae_stride = config.vae_stride
    z_dim = config.vae_z_dim
    t_latent = (num_frames - 1) // vae_stride[0] + 1
    h_latent = height // vae_stride[1]
    w_latent = width // vae_stride[2]
    target_shape = (z_dim, t_latent, h_latent, w_latent)

    # Sequence length for transformer
    patch_size = config.patch_size
    seq_len = math.ceil(
        (h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
    )

    print(f"{Colors.DIM} Latent shape: {target_shape}")
    print(f" Sequence length: {seq_len}{Colors.RESET}")

    # Load T5 encoder
    t1 = time.time()
    print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
    t5_path = model_dir / "t5_encoder.safetensors"
    t5_encoder = load_t5_encoder(t5_path, config)

    # Load tokenizer
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")

    # Encode prompts
    print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
    context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
    if negative_prompt:
        context_null = encode_text(t5_encoder, tokenizer, negative_prompt, config.text_len)
    else:
        context_null = encode_text(t5_encoder, tokenizer, "", config.text_len)
    mx.eval(context, context_null)

    # Free T5 from memory
    del t5_encoder
    gc.collect()
    mx.clear_cache()
    print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")

    # Load transformer models
    print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
    if quantization:
        print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
    t2 = time.time()

    if is_dual:
        low_noise_path = model_dir / "low_noise_model.safetensors"
        high_noise_path = model_dir / "high_noise_model.safetensors"
        low_noise_model = load_wan_model(low_noise_path, config, quantization)
        high_noise_model = load_wan_model(high_noise_path, config, quantization)
    else:
        single_model = load_wan_model(model_dir / "model.safetensors", config, quantization)
    print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")

    # Precompute text embeddings once (avoids redundant MLP in every step)
    ref_model = single_model if not is_dual else low_noise_model
    context_emb = ref_model.embed_text([context, context_null])
    mx.eval(context_emb)
    context_cond = context_emb[0:1]  # [1, text_len, dim]
    context_uncond = context_emb[1:2]  # [1, text_len, dim]
    # Stack for batched CFG: [2, text_len, dim]
    context_cfg = mx.concatenate([context_cond, context_uncond], axis=0)

    # Precompute cross-attention K/V caches (constant across all steps)
    if is_dual:
        cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg)
        cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg)
        mx.eval(cross_kv_low, cross_kv_high)
    else:
        cross_kv = single_model.prepare_cross_kv(context_cfg)
        mx.eval(cross_kv)

    # Setup scheduler
    scheduler = FlowMatchEulerScheduler(num_train_timesteps=config.num_train_timesteps)
    scheduler.set_timesteps(steps, shift=shift)

    # Generate initial noise
    noise = mx.random.normal(target_shape)

    # Boundary for model switching (dual model only)
    boundary = (config.boundary * config.num_train_timesteps) if is_dual else None

    # Diffusion loop
    print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
    latents = noise
    t3 = time.time()

    for i in tqdm(range(steps), desc="Diffusion"):
        timestep_val = scheduler.timesteps[i].item()

        # Select model, guide scale, and cached K/V
        if is_dual:
            if timestep_val >= boundary:
                model = high_noise_model
                gs = guide_scale[1]
                kv = cross_kv_high
            else:
                model = low_noise_model
                gs = guide_scale[0]
                kv = cross_kv_low
        else:
            model = single_model
            gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
            kv = cross_kv

        # CFG: batch cond + uncond into single B=2 forward pass
        preds = model(
            [latents, latents],
            t=mx.array([timestep_val, timestep_val]),
            context=context_cfg,
            seq_len=seq_len,
            cross_kv_caches=kv,
        )
        noise_pred_cond, noise_pred_uncond = preds[0], preds[1]

        # Classifier-free guidance + scheduler step
        noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
        latents = scheduler.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)

        # Release temporaries before eval to free memory for graph execution
        del noise_pred_cond, noise_pred_uncond, noise_pred, preds
        mx.eval(latents)

    print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")

    # Free transformer models and text embeddings
    if is_dual:
        del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
    else:
        del single_model, cross_kv
    del model, kv, context, context_null, context_cfg
    # These still reference the embedded text, so drop them as well
    del ref_model, context_emb, context_cond, context_uncond
    gc.collect()
    mx.clear_cache()

    # Load VAE and decode
    print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
    t4 = time.time()
    vae_path = model_dir / "vae.safetensors"
    vae = load_vae_decoder(vae_path, config)

    is_wan22_vae = config.vae_z_dim == 48

    if is_wan22_vae:
        from mlx_video.models.wan.vae22 import denormalize_latents

        # latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
        z = latents.transpose(1, 2, 3, 0)[None]  # [1, T, H, W, C]
        z = denormalize_latents(z)
        video = vae(z)  # [1, T', H', W', 3]
        mx.eval(video)
        print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")

        video = np.array(video[0])  # [T', H', W', 3]
        video = (video + 1.0) / 2.0
        video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
    else:
        video = vae.decode(latents[None])  # [1, 3, T, H, W]
        mx.eval(video)
        print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")

        video = np.array(video[0])  # [3, T, H, W]
        video = (video + 1.0) / 2.0
        video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
        video = video.transpose(1, 2, 3, 0)  # [T, H, W, 3]

    save_video(video, output_path, fps=config.sample_fps)
    print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
    print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")

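The latent-shape and sequence-length arithmetic in `generate_video` is easy to check by hand. The sketch below reproduces it in plain Python with the defaults shown in this diff (`vae_stride=(4, 8, 8)`, `patch_size=(1, 2, 2)`, `vae_z_dim=16`); `latent_geometry` is a hypothetical helper name used only for illustration.

```python
import math


def latent_geometry(width, height, num_frames,
                    vae_stride=(4, 8, 8), patch_size=(1, 2, 2), z_dim=16):
    """Return (latent_shape, seq_len) for a given output resolution."""
    if (num_frames - 1) % 4 != 0:
        raise ValueError(f"num_frames must be 4n+1, got {num_frames}")
    t_latent = (num_frames - 1) // vae_stride[0] + 1  # temporal compression
    h_latent = height // vae_stride[1]                # spatial compression
    w_latent = width // vae_stride[2]
    # One transformer token per (ph x pw) spatial patch in every latent frame.
    seq_len = math.ceil(
        (h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
    )
    return (z_dim, t_latent, h_latent, w_latent), seq_len


shape, seq_len = latent_geometry(1280, 720, 81)
print(shape, seq_len)  # (16, 21, 90, 160) 75600
```

For the default 1280x720, 81-frame request this gives 21 latent frames of 90x160 and 75,600 tokens, which is why attention cost dominates at this resolution.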
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
    """Save video frames to MP4.

    Args:
        frames: Video frames [T, H, W, 3] uint8
        output_path: Output file path
        fps: Frames per second
    """
    try:
        import imageio
        writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
        for frame in frames:
            writer.append_data(frame)
        writer.close()
    except ImportError:
        try:
            import cv2
            h, w = frames.shape[1], frames.shape[2]
            fourcc = cv2.VideoWriter_fourcc(*"avc1")
            writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
            for frame in frames:
                writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            writer.release()
        except Exception:  # ImportError is a subclass of Exception
            # Last resort: save as individual PNGs
            from PIL import Image
            out_dir = Path(output_path).parent / Path(output_path).stem
            out_dir.mkdir(parents=True, exist_ok=True)
            for i, frame in enumerate(frames):
                Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
            print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")

def main():
    parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
    parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
    parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
    parser.add_argument("--width", type=int, default=1280, help="Video width")
    parser.add_argument("--height", type=int, default=720, help="Video height")
    parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
    parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
    parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
    parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
    parser.add_argument("--seed", type=int, default=-1, help="Random seed")
    parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
    args = parser.parse_args()

    # Parse guide scale
    guide_scale = None
    if args.guide_scale is not None:
        parts = [float(x) for x in args.guide_scale.split(",")]
        guide_scale = tuple(parts) if len(parts) > 1 else parts[0]

    generate_video(
        model_dir=args.model_dir,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        width=args.width,
        height=args.height,
        num_frames=args.num_frames,
        steps=args.steps,
        guide_scale=guide_scale,
        shift=args.shift,
        seed=args.seed,
        output_path=args.output_path,
    )


if __name__ == "__main__":
    main()
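The `--guide-scale` handling and the per-step scale selection used for dual-model runs can be summarized in a few lines. This is a sketch rather than the module's code: `parse_guide_scale` and `select_guide_scale` are hypothetical names, and the fallback for a single float in `select_guide_scale` is an assumption of this example.

```python
def parse_guide_scale(value):
    """Parse "5.0" into 5.0 and "3.0,4.0" into (3.0, 4.0)."""
    parts = [float(x) for x in value.split(",")]
    return tuple(parts) if len(parts) > 1 else parts[0]


def select_guide_scale(guide_scale, timestep, boundary):
    """Pick the guidance scale for one denoising step.

    A (low, high) pair uses `high` at or above the boundary timestep
    (high-noise model) and `low` below it (low-noise model).
    A single float is used for every step (assumption of this sketch).
    """
    if isinstance(guide_scale, tuple):
        low, high = guide_scale
        return high if timestep >= boundary else low
    return guide_scale


# With the Wan2.2 defaults in this diff: boundary=0.875, 1000 train timesteps.
boundary = 0.875 * 1000
pair = parse_guide_scale("3.0,4.0")
print(select_guide_scale(pair, 900.0, boundary))  # 4.0
print(select_guide_scale(pair, 500.0, boundary))  # 3.0
```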
@@ -1,2 +1,3 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig
2
mlx_video/models/wan/__init__.py
Normal file
@@ -0,0 +1,2 @@
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel
201
mlx_video/models/wan/attention.py
Normal file
@@ -0,0 +1,201 @@
import mlx.core as mx
import mlx.nn as nn

from .rope import rope_apply


class WanRMSNorm(nn.Module):
    """RMS normalization with learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        return mx.fast.rms_norm(x, self.weight, self.eps)

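For reference, the operation that `WanRMSNorm` delegates to `mx.fast.rms_norm` is the standard RMS normalization: divide by the root mean square over the last axis, then multiply by the learned weight. A minimal pure-Python sketch for a single vector follows; the function name and list-based types are illustrative only, and the fused MLX kernel may differ in precision handling.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """Normalize a 1D vector by its root mean square, then scale it."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


y = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
print(y)
```

With unit weights the output has a mean square very close to 1, regardless of the input scale.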
class WanLayerNorm(nn.Module):
    """LayerNorm computed in float32, with optional affine."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = mx.ones((dim,))
            self.bias = mx.zeros((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        if self.elementwise_affine:
            return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
        else:
            return mx.fast.layer_norm(x, None, None, self.eps)

class WanSelfAttention(nn.Module):
    """Self-attention with QK normalization and 3-way factorized RoPE."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def __call__(
        self,
        x: mx.array,
        seq_lens: list,
        grid_sizes: list,
        freqs: mx.array,
    ) -> mx.array:
        b, s, _ = x.shape
        n, d = self.num_heads, self.head_dim

        q = self.q(x)
        k = self.k(x)
        if self.norm_q is not None:
            q = self.norm_q(q)
        if self.norm_k is not None:
            k = self.norm_k(k)

        q = q.reshape(b, s, n, d)
        k = k.reshape(b, s, n, d)
        v = self.v(x).reshape(b, s, n, d)

        # Apply RoPE
        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)

        # Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # Build attention mask from seq_lens
        max_len = s
        mask = None
        if any(sl < max_len for sl in seq_lens):
            mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
            for i, sl in enumerate(seq_lens):
                mask[i, :, :, sl:] = -1e9

        # Use memory-efficient scaled dot-product attention
        # mx.fast.scaled_dot_product_attention expects [B, N, L, D]
        if mask is not None:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale, mask=mask
            )
        else:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale
            )

        out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
        return self.o(out)

class WanCrossAttention(nn.Module):
    """Cross-attention: Q from hidden states, K/V from text context."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def prepare_kv(self, context: mx.array) -> tuple:
        """Pre-compute K and V projections for caching.

        Args:
            context: [B, L_ctx, dim]

        Returns:
            (k, v) each [B, N, L_ctx, D] ready for attention
        """
        b = context.shape[0]
        n, d = self.num_heads, self.head_dim
        k = self.k(context)
        if self.norm_k is not None:
            k = self.norm_k(k)
        k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        return k, v

    def __call__(
        self,
        x: mx.array,
        context: mx.array,
        context_lens: list | None = None,
        kv_cache: tuple | None = None,
    ) -> mx.array:
        b = x.shape[0]
        n, d = self.num_heads, self.head_dim

        q = self.q(x)
        if self.norm_q is not None:
            q = self.norm_q(q)
        q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)

        if kv_cache is not None:
            k, v = kv_cache
        else:
            k = self.k(context)
            if self.norm_k is not None:
                k = self.norm_k(k)
            k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
            v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)

        # Optional context masking
        mask = None
        if context_lens is not None:
            ctx_len = k.shape[2]
            mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
            for i, cl in enumerate(context_lens):
                mask[i, :, :, cl:] = -1e9

        if mask is not None:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale, mask=mask
            )
        else:
            out = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale
            )

        out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
        return self.o(out)
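Both attention classes mask padded positions additively: the mask is 0 for valid keys and -1e9 for padded ones, and it is added to the attention scores before the softmax. The pure-Python sketch below shows why this drives the padded weights to zero; `masked_softmax` is a hypothetical helper operating on a single row of scores, not a function from this module.

```python
import math


def masked_softmax(scores, valid_len, neg=-1e9):
    """Softmax over one row of scores with an additive padding mask."""
    masked = [s + (0.0 if i < valid_len else neg) for i, s in enumerate(scores)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Four key positions, of which only the first two are real tokens.
weights = masked_softmax([1.0, 2.0, 3.0, 4.0], valid_len=2)
print(weights)
```

The two padded positions receive exactly zero weight in float64 because `exp` underflows, and the remaining weights renormalize over the valid keys only.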
86
mlx_video/models/wan/config.py
Normal file
@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Tuple, Union

from mlx_video.models.ltx.config import BaseModelConfig


@dataclass
class WanModelConfig(BaseModelConfig):
    """Configuration for Wan T2V models (supports both 2.1 and 2.2)."""

    model_type: str = "t2v"
    model_version: str = "2.2"
    patch_size: Tuple[int, int, int] = (1, 2, 2)
    text_len: int = 512
    in_dim: int = 16
    dim: int = 5120
    ffn_dim: int = 13824
    freq_dim: int = 256
    text_dim: int = 4096
    out_dim: int = 16
    num_heads: int = 40
    num_layers: int = 40
    window_size: Tuple[int, int] = (-1, -1)
    qk_norm: bool = True
    cross_attn_norm: bool = True
    eps: float = 1e-6

    # VAE
    vae_stride: Tuple[int, int, int] = (4, 8, 8)
    vae_z_dim: int = 16

    # Inference
    dual_model: bool = True
    boundary: float = 0.875
    sample_shift: float = 12.0
    sample_steps: int = 40
    sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
    num_train_timesteps: int = 1000
    sample_fps: int = 16
    frame_num: int = 81

    # T5
    t5_vocab_size: int = 256384
    t5_dim: int = 4096
    t5_dim_attn: int = 4096
    t5_dim_ffn: int = 10240
    t5_num_heads: int = 64
    t5_num_layers: int = 24
    t5_num_buckets: int = 32

    @property
    def head_dim(self) -> int:
        return self.dim // self.num_heads

    @classmethod
    def wan21_t2v_14b(cls) -> "WanModelConfig":
        """Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
        return cls(
            model_version="2.1",
            dual_model=False,
            boundary=0.0,
            sample_shift=5.0,
            sample_steps=50,
            sample_guide_scale=5.0,
        )

    @classmethod
    def wan21_t2v_1_3b(cls) -> "WanModelConfig":
        """Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
        return cls(
            model_version="2.1",
            dim=1536,
            ffn_dim=8960,
            num_heads=12,
            num_layers=30,
            dual_model=False,
            boundary=0.0,
            sample_shift=5.0,
            sample_steps=50,
            sample_guide_scale=5.0,
        )

    @classmethod
    def wan22_t2v_14b(cls) -> "WanModelConfig":
        """Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
        return cls()
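The way `generate_video` builds a config from `config.json` (pop the `quantization` entry, turn JSON lists back into tuples, ignore unknown keys) can be sketched with the standard library alone. `DemoConfig` and `config_from_json` below are hypothetical stand-ins with only two fields; they are not the real `WanModelConfig`.

```python
import json
from dataclasses import dataclass, fields
from typing import Tuple


@dataclass
class DemoConfig:
    """Two-field stand-in for a model config dataclass."""
    dim: int = 5120
    patch_size: Tuple[int, int, int] = (1, 2, 2)


def config_from_json(text, tuple_keys=("patch_size",)):
    """Build a DemoConfig from JSON text, returning (config, quantization)."""
    raw = json.loads(text)
    quantization = raw.pop("quantization", None)  # not a dataclass field
    known = {f.name for f in fields(DemoConfig)}
    kwargs = {}
    for key, value in raw.items():
        if key not in known:
            continue  # silently drop keys the dataclass does not define
        if key in tuple_keys and isinstance(value, list):
            value = tuple(value)  # JSON has no tuple type
        kwargs[key] = value
    return DemoConfig(**kwargs), quantization


cfg, quant = config_from_json(
    '{"dim": 1536, "patch_size": [1, 2, 2], "extra": true,'
    ' "quantization": {"bits": 4, "group_size": 64}}'
)
print(cfg, quant)
```

Filtering against the dataclass fields keeps older or newer `config.json` files loadable, at the cost of silently ignoring misspelled keys.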
307
mlx_video/models/wan/model.py
Normal file
@@ -0,0 +1,307 @@
import math

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .attention import WanLayerNorm
from .config import WanModelConfig
from .rope import rope_params
from .transformer import WanAttentionBlock


def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
    """Compute sinusoidal positional embeddings.

    Args:
        dim: Embedding dimension (must be even).
        position: 1D tensor of positions.

    Returns:
        Embeddings of shape [len(position), dim].
    """
    assert dim % 2 == 0
    half = dim // 2
    pos = position.astype(mx.float32)
    inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
    sinusoid = pos[:, None] * inv_freq[None, :]
    return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)

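A pure-Python version of `sinusoidal_embedding_1d` makes the layout explicit: the first half of each row holds cosines and the second half holds sines, with frequencies `10000 ** (-i / half)`. This sketch mirrors the MLX code above using lists and is meant only as a reading aid; the name `sinusoidal_embedding_1d_py` is made up for this example.

```python
import math


def sinusoidal_embedding_1d_py(dim, positions):
    """Return one [cos..., sin...] row of length `dim` per position."""
    assert dim % 2 == 0
    half = dim // 2
    inv_freq = [10000.0 ** (-i / half) for i in range(half)]
    rows = []
    for pos in positions:
        angles = [float(pos) * f for f in inv_freq]
        rows.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return rows


emb = sinusoidal_embedding_1d_py(4, [0, 1])
print(emb[0])  # [1.0, 1.0, 0.0, 0.0]
```

Position 0 always maps to ones in the cosine half and zeros in the sine half, which is a quick sanity check for any port of this function.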
class Head(nn.Module):
    """Output projection head with learned modulation."""

    def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
        super().__init__()
        self.out_dim = out_dim
        self.patch_size = patch_size
        proj_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, proj_dim)
        self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5)

    def __call__(self, x: mx.array, e: mx.array) -> mx.array:
        """
        Args:
            x: [B, L, dim]
            e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
        """
        if e.ndim == 2:
            e = e[:, None, :]  # [B, 1, dim]
        e_f32 = e.astype(mx.float32)
        mod = self.modulation + e_f32  # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
        e0 = mod[:, 0:1, :]  # [B, 1, dim] shift
        e1 = mod[:, 1:2, :]  # [B, 1, dim] scale
        x_norm = self.norm(x).astype(mx.float32)
        x_mod = x_norm * (1 + e1) + e0  # broadcasts over L
        return self.head(x_mod.astype(x.dtype))

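The modulation applied in `Head.__call__` is a shift-and-scale of the normalized activations, `x_norm * (1 + scale) + shift`. A tiny worked example on one feature vector follows; `modulate` is a hypothetical helper and the numbers are arbitrary.

```python
def modulate(x_norm, shift, scale):
    """Apply x * (1 + scale) + shift element-wise to one feature vector."""
    return [v * (1.0 + s) + b for v, s, b in zip(x_norm, scale, shift)]


out = modulate([1.0, -2.0], shift=[0.5, 0.5], scale=[1.0, 0.0])
print(out)  # [2.5, -1.5]
```

Writing the scale as `1 + scale` means that a zero modulation leaves the normalized input unchanged.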
class WanModel(nn.Module):
    """Wan diffusion backbone for text-to-video generation (Wan2.1 and Wan2.2)."""

    def __init__(self, config: WanModelConfig):
        super().__init__()
        self.config = config
        dim = config.dim
        self.dim = dim
        self.num_heads = config.num_heads
        self.out_dim = config.out_dim
        self.patch_size = config.patch_size
        self.text_len = config.text_len
        self.freq_dim = config.freq_dim

        # Patch embedding: Conv3d implemented as a reshaped linear
        # For kernel (1,2,2) and stride (1,2,2): reshape input then linear
        patch_dim = config.in_dim * math.prod(config.patch_size)
        self.patch_embedding_proj = nn.Linear(patch_dim, dim)
        self._patch_size = config.patch_size

        # Text embedding MLP
        self.text_embedding_0 = nn.Linear(config.text_dim, dim)
        self.text_embedding_act = nn.GELU(approx="precise")
        self.text_embedding_1 = nn.Linear(dim, dim)

        # Time embedding MLP
        self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
        self.time_embedding_act = nn.SiLU()
        self.time_embedding_1 = nn.Linear(dim, dim)

        # Time projection for modulation (6x dim)
        self.time_projection_act = nn.SiLU()
        self.time_projection = nn.Linear(dim, dim * 6)

        # Transformer blocks
        self.blocks = [
            WanAttentionBlock(
                dim=dim,
                ffn_dim=config.ffn_dim,
                num_heads=config.num_heads,
                window_size=config.window_size,
                qk_norm=config.qk_norm,
                cross_attn_norm=config.cross_attn_norm,
                eps=config.eps,
            )
            for _ in range(config.num_layers)
        ]

        # Output head
        self.head = Head(dim, config.out_dim, config.patch_size, config.eps)

        # Precompute RoPE frequencies
        # The head dim d is split into temporal, height, and width parts
        d = dim // config.num_heads
        d_t = d - 4 * (d // 6)
        d_h = 2 * (d // 6)
        d_w = 2 * (d // 6)
        # Each rope_params returns [1024, d_x//2, 2]
        freqs_t = rope_params(1024, d_t)
        freqs_h = rope_params(1024, d_h)
        freqs_w = rope_params(1024, d_w)
        # Concatenate along the frequency dimension: [1024, d//2, 2]
        self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1)

    def _patchify(self, x: mx.array) -> tuple:
        """Convert video tensor to patch embeddings.

        Args:
            x: Video latent [C, F, H, W]

        Returns:
            (patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
        """
        c, f, h, w = x.shape
        pt, ph, pw = self._patch_size

        f_out = f // pt
        h_out = h // ph
        w_out = w // pw

        # Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
        # Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
        x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
        x = x.transpose(1, 3, 5, 0, 2, 4, 6)  # [F', H', W', C, pt, ph, pw]
        x = x.reshape(f_out * h_out * w_out, -1)  # [L, C*pt*ph*pw]

        # Project and cast to model dtype to prevent float32 cascade from input latents
        patches = self.patch_embedding_proj(x)  # [L, dim]
        patches = patches.astype(self.patch_embedding_proj.weight.dtype)
        patches = patches[None, :, :]  # [1, L, dim]

        return patches, (f_out, h_out, w_out)

    def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
        """Reconstruct video from patch embeddings.

        Args:
            x: [B, L, out_dim * prod(patch_size)]
            grid_sizes: List of (F', H', W') per batch element

        Returns:
            List of tensors [C, F, H, W]
        """
        c = self.out_dim
        pt, ph, pw = self.patch_size
        out = []
        for i, (f, h, w) in enumerate(grid_sizes):
            seq_len = f * h * w
            u = x[i, :seq_len]  # [L, out_dim * pt * ph * pw]
            u = u.reshape(f, h, w, pt, ph, pw, c)
            # Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
            u = u.transpose(6, 0, 3, 1, 4, 2, 5)  # [C, F', pt, H', ph, W', pw]
            u = u.reshape(c, f * pt, h * ph, w * pw)
            out.append(u)
        return out

    def embed_text(self, context: list) -> mx.array:
        """Precompute text embeddings (call once, reuse across steps).

        Args:
            context: List of text embeddings [L_text, text_dim]

        Returns:
            Embedded context [B, text_len, dim] in model dtype
        """
        model_dtype = self.patch_embedding_proj.weight.dtype
        context_padded = []
        for ctx in context:
            pad_len = self.text_len - ctx.shape[0]
            if pad_len > 0:
                ctx = mx.concatenate(
                    [ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
                    axis=0,
                )
            context_padded.append(ctx)
        context_batch = mx.stack(context_padded)  # [B, text_len, text_dim]
        context_batch = self.text_embedding_1(
            self.text_embedding_act(self.text_embedding_0(context_batch))
        )
        return context_batch.astype(model_dtype)

    def prepare_cross_kv(self, context: mx.array) -> list:
        """Pre-compute cross-attention K/V for all blocks.

        Call once before the diffusion loop to cache K/V projections,
        eliminating redundant computation at each denoising step.

        Args:
            context: Pre-embedded text [B, text_len, dim]

        Returns:
            List of (k, v) tuples, one per block
        """
        kv_caches = []
        for block in self.blocks:
            kv_caches.append(block.cross_attn.prepare_kv(context))
        return kv_caches

    def __call__(
        self,
        x_list: list,
        t: mx.array,
        context: list | mx.array,
        seq_len: int,
        cross_kv_caches: list | None = None,
    ) -> list:
        """Forward pass.

        Args:
            x_list: List of video latent tensors [C, F, H, W]
            t: Timestep tensor [B]
            context: List of raw text embeddings, OR pre-embedded tensor
                from embed_text() [B, text_len, dim]
            seq_len: Maximum sequence length for padding
            cross_kv_caches: Optional list of (k, v) tuples from
                prepare_cross_kv(), one per block.

        Returns:
            List of denoised tensors [C, F, H, W]
        """
        # Patchify each video
        patches = []
        grid_sizes = []
        seq_lens_list = []
        for vid in x_list:
            p, gs = self._patchify(vid)  # [1, L, dim]
            patches.append(p)
            grid_sizes.append(gs)
            seq_lens_list.append(p.shape[1])

        # Pad and batch
        batch_size = len(patches)
        x = mx.concatenate(
            [
                mx.concatenate(
                    [p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
                    axis=1,
                )
                if p.shape[1] < seq_len
                else p
                for p in patches
            ],
            axis=0,
        )  # [B, seq_len, dim]

        # Time embedding: compute once per sample, then broadcast to all tokens
        if t.ndim == 0:
            t = t[None]
        sin_emb = sinusoidal_embedding_1d(self.freq_dim, t)  # [B, freq_dim]

        model_dtype = self.patch_embedding_proj.weight.dtype
        e = self.time_embedding_1(
            self.time_embedding_act(self.time_embedding_0(sin_emb))
        )  # [B, dim]
        e0 = self.time_projection(self.time_projection_act(e))  # [B, dim*6]
        e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
        e = e.astype(model_dtype)

        # Text embedding: skip MLP if context is already embedded (mx.array)
        if isinstance(context, mx.array):
            # Pre-embedded: expand to batch size if needed
            context_batch = context
            if context_batch.shape[0] == 1 and batch_size > 1:
                context_batch = mx.broadcast_to(
                    context_batch, (batch_size,) + context_batch.shape[1:]
                )
        else:
            context_batch = self.embed_text(context)

        # Run transformer blocks
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens_list,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context_batch,
            context_lens=None,
        )

        for i, block in enumerate(self.blocks):
            kv = cross_kv_caches[i] if cross_kv_caches is not None else None
            x = block(x, cross_kv_cache=kv, **kwargs)

        # Output head
        x = self.head(x, e)

        # Unpatchify
        outputs = self.unpatchify(x, grid_sizes)
        return [u.astype(mx.float32) for u in outputs]
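The patch bookkeeping in `_patchify` and `unpatchify` above comes down to a little integer arithmetic. The following is a minimal, standard-library-only sketch of that arithmetic, not part of the commit: `patch_grid` is a hypothetical helper name, the latent shape in the usage line is made up for illustration, and the default patch size `(1, 2, 2)` is taken from the comment in `WanModel.__init__`.

```python
import math


def patch_grid(latent_shape, patch_size=(1, 2, 2)):
    """Return (grid, seq_len, patch_dim) for a [C, F, H, W] latent.

    Hypothetical helper mirroring the shape arithmetic of WanModel._patchify;
    it does not touch any tensor data.
    """
    c, f, h, w = latent_shape
    pt, ph, pw = patch_size
    if f % pt or h % ph or w % pw:
        raise ValueError("latent dims must be divisible by the patch size")
    grid = (f // pt, h // ph, w // pw)  # (F', H', W')
    seq_len = math.prod(grid)           # number of tokens L
    patch_dim = c * pt * ph * pw        # input width of the patch projection
    return grid, seq_len, patch_dim


# Illustrative shape only: 16 latent channels, 21 frames, 60x104 latent pixels.
grid, seq_len, patch_dim = patch_grid((16, 21, 60, 104))
print(grid, seq_len, patch_dim)
```

The `seq_len` argument passed to `WanModel.__call__` has to be at least the largest per-video token count computed this way, since shorter sequences are zero-padded up to it.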
100
mlx_video/models/wan/rope.py
Normal file
@@ -0,0 +1,100 @@
import math

import mlx.core as mx
import numpy as np


def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
    """Precompute RoPE rotation tables as (cos, sin) pairs.

    Returns:
        Real-valued tensor of shape [max_seq_len, dim // 2, 2], where the
        last axis holds (cos, sin) of each position/frequency angle.
    """
    assert dim % 2 == 0
    freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
        1.0
        / np.power(
            theta,
            np.arange(0, dim, 2, dtype=np.float64) / dim,
        )
    )[None, :]
    # Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
    cos_freqs = np.cos(freqs).astype(np.float32)
    sin_freqs = np.sin(freqs).astype(np.float32)
    return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))


def rope_apply(
    x: mx.array,
    grid_sizes: list,
    freqs: mx.array,
) -> mx.array:
    """Apply 3-way factorized RoPE to Q or K tensor.

    Args:
        x: Shape [B, L, num_heads, head_dim]
        grid_sizes: List of (F, H, W) tuples per batch element
        freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
    """
    b, s, n, d = x.shape
    half_d = d // 2

    # Cast freqs to input dtype to prevent float32 promotion cascade
    if freqs.dtype != x.dtype:
        freqs = freqs.astype(x.dtype)

    # Split frequency dimensions: temporal gets more capacity
    d_t = half_d - 2 * (half_d // 3)
    d_h = half_d // 3
    d_w = half_d // 3

    # Split freqs along dim axis
    freqs_t = freqs[:, :d_t]  # [1024, d_t, 2]
    freqs_h = freqs[:, d_t : d_t + d_h]  # [1024, d_h, 2]
    freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]  # [1024, d_w, 2]

    outputs = []
    for i in range(b):
        f, h, w = grid_sizes[i]
        seq_len = f * h * w

        # Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
        x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)

        # Build per-position frequencies by expanding along grid dims
        # temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
        ft = mx.broadcast_to(
            freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
        )
        # height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
        fh = mx.broadcast_to(
            freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
        )
        # width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
        fw = mx.broadcast_to(
            freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
        )

        # Concatenate: [f*h*w, half_d, 2]
        freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)

        # Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
        cos_f = freqs_i[..., 0]  # [seq_len, 1, half_d]
        sin_f = freqs_i[..., 1]  # [seq_len, 1, half_d]

        x_real = x_i[..., 0]  # [seq_len, n, half_d]
        x_imag = x_i[..., 1]  # [seq_len, n, half_d]

        out_real = x_real * cos_f - x_imag * sin_f
        out_imag = x_real * sin_f + x_imag * cos_f

        # Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
        x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)

        # Handle padding: keep non-rotated tokens after seq_len
        if seq_len < s:
            x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)

        outputs.append(x_rotated)

    return mx.stack(outputs)
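`rope_apply` recomputes the temporal/height/width split from `head_dim // 2`, while `WanModel.__init__` derives the widths it passes to `rope_params` from `head_dim // 6`. The two have to agree or the slices of `freqs` would be misaligned. A small standard-library sketch of that consistency check follows. The function names are hypothetical and only the integer formulas are taken from the code above.

```python
def split_in_rope_apply(head_dim):
    """(d_t, d_h, d_w) in rotation pairs, as computed inside rope_apply."""
    half_d = head_dim // 2
    d_t = half_d - 2 * (half_d // 3)
    d_h = half_d // 3
    d_w = half_d // 3
    return d_t, d_h, d_w


def split_in_model_init(head_dim):
    """(d_t, d_h, d_w) in rotation pairs, as implied by WanModel.__init__.

    rope_params(1024, dim) yields dim // 2 pairs, so each width is halved.
    """
    d_t = head_dim - 4 * (head_dim // 6)
    d_h = 2 * (head_dim // 6)
    d_w = 2 * (head_dim // 6)
    return d_t // 2, d_h // 2, d_w // 2


def rotate_pair(a, b, cos_v, sin_v):
    """Rotate one (real, imag) pair, matching the out_real/out_imag formulas."""
    return a * cos_v - b * sin_v, a * sin_v + b * cos_v


# head_dim = 128 is an illustrative value, not read from any config here.
print(split_in_rope_apply(128), split_in_model_init(128))
```

For even head dimensions both splits coincide, and any remainder pairs are assigned to the temporal axis, which is why the code comment says the temporal axis "gets more capacity".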
76
mlx_video/models/wan/scheduler.py
Normal file
@@ -0,0 +1,76 @@
"""Flow matching scheduler for Wan2.2 inference."""

import numpy as np

import mlx.core as mx


class FlowMatchEulerScheduler:
    """Simple Euler scheduler for flow matching diffusion.

    Implements the flow matching formulation where the model predicts
    velocity (flow) and we use Euler steps to denoise.
    """

    def __init__(self, num_train_timesteps: int = 1000):
        self.num_train_timesteps = num_train_timesteps
        self.timesteps = None
        self.sigmas = None

    def set_timesteps(self, num_steps: int, shift: float = 1.0):
        """Compute sigma schedule with shift.

        Args:
            num_steps: Number of inference steps.
            shift: Noise schedule shift factor.
        """
        # Linear spacing from sigma_max to sigma_min
        sigmas = np.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps)[::-1]
        sigmas = 1.0 - sigmas

        # Select evenly spaced subset
        indices = np.linspace(0, len(sigmas) - 1, num_steps + 1).astype(int)
        sigmas = sigmas[indices[:-1]]

        # Apply shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
        sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)

        # Convert to timesteps
        timesteps = sigmas * self.num_train_timesteps
        self.timesteps = mx.array(timesteps.astype(np.float32))

        # Append terminal sigma=0
        sigmas = np.append(sigmas, 0.0)
        self.sigmas = mx.array(sigmas.astype(np.float32))
        self._step_index = 0

    def step(
        self,
        model_output: mx.array,
        timestep,
        sample: mx.array,
    ) -> mx.array:
        """Euler step for flow matching.

        In flow matching, model predicts velocity v, and:
            x_{t-1} = sample + (sigma_{t-1} - sigma_t) * v

        Args:
            model_output: Predicted velocity [B, C, T, H, W]
            timestep: Current timestep (unused, step index is tracked internally)
            sample: Current noisy sample [B, C, T, H, W]

        Returns:
            Updated sample
        """
        # Use Python floats to avoid creating mx.array scalars that
        # could trigger type promotion (per fast-mlx guide)
        dt = float(self.sigmas[self._step_index + 1].item()) - float(self.sigmas[self._step_index].item())
        x_next = sample + dt * model_output

        self._step_index += 1
        return x_next

    def reset(self):
        """Reset step counter for new generation."""
        self._step_index = 0
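To make the sigma schedule and the Euler update concrete, here is a small standard-library re-derivation of what `set_timesteps` and `step` compute. It is a sketch under assumptions, not the scheduler itself: `shifted_sigmas` and `euler_step` are hypothetical names, plain Python lists stand in for arrays, and the index selection uses `int(j * (n - 1) / num_steps)` as an approximation of `np.linspace(...).astype(int)`.

```python
def shifted_sigmas(num_steps, shift=1.0, num_train_timesteps=1000):
    """Sigma schedule in the spirit of FlowMatchEulerScheduler.set_timesteps."""
    n = num_train_timesteps
    # Equivalent to 1 - linspace(1, 1/n, n)[::-1]: descends from 1 - 1/n to 0.
    base = [(1.0 - 1.0 / n) * (1.0 - i / (n - 1)) for i in range(n)]
    # Evenly spaced subset; the final index (sigma = 0) is dropped here
    # and re-added below as the terminal sigma.
    idx = [int(j * (n - 1) / num_steps) for j in range(num_steps)]
    # Shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
    sigmas = [shift * base[i] / (1.0 + (shift - 1.0) * base[i]) for i in idx]
    return sigmas + [0.0]


def euler_step(sample, velocity, sigma, sigma_next):
    """One Euler update: x_next = x + (sigma_next - sigma) * v."""
    dt = sigma_next - sigma
    return [x + dt * v for x, v in zip(sample, velocity)]


sigmas = shifted_sigmas(num_steps=4, shift=5.0)
x = [1.0, -1.0]
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    v = [0.5, 0.5]  # stand-in for the model's predicted velocity
    x = euler_step(x, v, s, s_next)
print(sigmas, x)
```

A shift greater than 1 pushes the intermediate sigmas toward 1, so more of the step budget is spent at high noise levels. Because `dt` is negative at every step, each update moves the sample against the predicted velocity.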
234
mlx_video/models/wan/text_encoder.py
Normal file
@@ -0,0 +1,234 @@
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""

import math

import mlx.core as mx
import mlx.nn as nn


class T5LayerNorm(nn.Module):
    """RMS-based layer normalization (T5 style)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        return mx.fast.rms_norm(x, self.weight, self.eps)


class T5RelativeEmbedding(nn.Module):
    """T5-style relative position bias with bucketing."""

    def __init__(
        self,
        num_buckets: int,
        num_heads: int,
        bidirectional: bool = True,
        max_dist: int = 128,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
            rel_pos = mx.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            # mx.zeros_like has no dtype argument, so build the zeros explicitly
            rel_buckets = mx.zeros(rel_pos.shape, dtype=mx.int32)
            rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))

        max_exact = num_buckets // 2
        is_small = rel_pos < max_exact

        rel_pos_f = rel_pos.astype(mx.float32)
        rel_pos_large = (
            max_exact
            + (
                mx.log(rel_pos_f / max_exact)
                / math.log(self.max_dist / max_exact)
                * (num_buckets - max_exact)
            ).astype(mx.int32)
        )
        rel_pos_large = mx.minimum(
            rel_pos_large,
            mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
        )

        rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
        return rel_buckets

    def __call__(self, lq: int, lk: int) -> mx.array:
        positions_k = mx.arange(lk)[None, :]  # [1, lk]
        positions_q = mx.arange(lq)[:, None]  # [lq, 1]
        rel_pos = positions_k - positions_q  # [lq, lk]

        buckets = self._relative_position_bucket(rel_pos)
        embeds = self.embedding(buckets)  # [lq, lk, num_heads]
        embeds = embeds.transpose(2, 0, 1)[None, :, :, :]  # [1, N, lq, lk]
        return embeds


class T5Attention(nn.Module):
    """T5-style multi-head attention (no scaling)."""

    def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert dim_attn % num_heads == 0
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)

    def __call__(
        self,
        x: mx.array,
        context: mx.array | None = None,
        mask: mx.array | None = None,
        pos_bias: mx.array | None = None,
    ) -> mx.array:
        context = x if context is None else context
        b, n, c = x.shape[0], self.num_heads, self.head_dim

        q = self.q(x).reshape(b, -1, n, c)  # [B, Lq, N, C]
        k = self.k(context).reshape(b, -1, n, c)  # [B, Lk, N, C]
        v = self.v(context).reshape(b, -1, n, c)

        # T5 does not use scaling
        # attn = einsum('binc,bjnc->bnij', q, k)
        q = q.transpose(0, 2, 1, 3)  # [B, N, Lq, C]
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # Combine position bias and attention mask for SDPA
        attn_mask = None
        if pos_bias is not None:
            attn_mask = pos_bias.astype(q.dtype)
        if mask is not None:
            if mask.ndim == 2:
                mask = mask[:, None, None, :]  # [B, 1, 1, Lk]
            elif mask.ndim == 3:
                mask = mask[:, None, :, :]  # [B, 1, Lq, Lk]
            additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
            attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask

        # T5 uses no scaling (scale=1.0)
        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=1.0, mask=attn_mask
        )
        out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
        return self.o(out)


class T5FeedForward(nn.Module):
    """Gated feed-forward: gate(x) * fc1(x) -> fc2."""

    def __init__(self, dim: int, dim_ffn: int):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn
        self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
        self.gate_act = nn.GELU(approx="precise")
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))


class T5SelfAttentionBlock(nn.Module):
    """T5 encoder block: self-attention + FFN."""

    def __init__(
        self,
        dim: int,
        dim_attn: int,
        dim_ffn: int,
        num_heads: int,
        num_buckets: int,
        shared_pos: bool = True,
    ):
        super().__init__()
        self.shared_pos = shared_pos
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn)
        self.pos_embedding = (
            None
            if shared_pos
            else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
        )

    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        pos_bias: mx.array | None = None,
    ) -> mx.array:
        e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
        x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
        x = x + self.ffn(self.norm2(x))
        return x


class T5Encoder(nn.Module):
    """T5 Encoder (UMT5-XXL configuration)."""

    def __init__(
        self,
        vocab_size: int = 256384,
        dim: int = 4096,
        dim_attn: int = 4096,
        dim_ffn: int = 10240,
        num_heads: int = 64,
        num_layers: int = 24,
        num_buckets: int = 32,
        shared_pos: bool = False,
    ):
        super().__init__()
        self.dim = dim

        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
            if shared_pos
            else None
        )
        self.blocks = [
            T5SelfAttentionBlock(
                dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
            )
            for _ in range(num_layers)
        ]
        self.norm = T5LayerNorm(dim)

    def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
        """
        Args:
            ids: Token IDs [B, L]
            mask: Attention mask [B, L]

        Returns:
            Hidden states [B, L, dim]
        """
        x = self.token_embedding(ids)

        # Explicit None check: do not rely on the truthiness of a module object
        e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding is not None else None
        for block in self.blocks:
            x = block(x, mask=mask, pos_bias=e)

        x = self.norm(x)
        return x
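The bucketing in `T5RelativeEmbedding._relative_position_bucket` is easier to follow on single integers. Below is a scalar, standard-library sketch of the same rule: small distances get one bucket each, larger ones share logarithmically spaced buckets, and in the bidirectional case positive offsets use the upper half of the table. `relative_position_bucket` is a hypothetical name. The defaults `num_buckets=32` and `max_dist=128` are the defaults in the code above.

```python
import math


def relative_position_bucket(rel_pos, num_buckets=32, max_dist=128, bidirectional=True):
    """Scalar version of the T5 relative-position bucketing rule."""
    if bidirectional:
        num_buckets //= 2
        # Keys to the right of the query (rel_pos > 0) use the upper half.
        offset = num_buckets if rel_pos > 0 else 0
        rel_pos = abs(rel_pos)
    else:
        offset = 0
        rel_pos = max(-rel_pos, 0)

    max_exact = num_buckets // 2
    if rel_pos < max_exact:
        # Small distances map to their own bucket.
        return offset + rel_pos

    # Larger distances are binned logarithmically up to max_dist.
    large = max_exact + int(
        math.log(rel_pos / max_exact)
        / math.log(max_dist / max_exact)
        * (num_buckets - max_exact)
    )
    return offset + min(large, num_buckets - 1)


for rel in (-1000, -100, -3, 0, 3, 1000):
    print(rel, relative_position_bucket(rel))
```

Everything at or beyond `max_dist` collapses into the last bucket of its half, so the learned bias table stays at `num_buckets` rows per head regardless of sequence length.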
89
mlx_video/models/wan/transformer.py
Normal file
@@ -0,0 +1,89 @@
import mlx.core as mx
import mlx.nn as nn

from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention


class WanAttentionBlock(nn.Module):
    """Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
    ):
        super().__init__()

        # Self-attention
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)

        # Cross-attention (with optional norm on context)
        self.norm3 = (
            WanLayerNorm(dim, eps, elementwise_affine=True)
            if cross_attn_norm
            else None
        )
        self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)

        # Feed-forward
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = WanFFN(dim, ffn_dim)

        # Learned modulation: 6 vectors for scale/shift/gate
        self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5)

    def __call__(
        self,
        x: mx.array,
        e: mx.array,
        seq_lens: list,
        grid_sizes: list,
        freqs: mx.array,
        context: mx.array,
        context_lens: list | None = None,
        cross_kv_cache: tuple | None = None,
    ) -> mx.array:
        # Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
        mod = (self.modulation + e)  # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
        # Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
        e0 = mod[:, :, 0, :]  # shift for self-attn
        e1 = mod[:, :, 1, :]  # scale for self-attn
        e2 = mod[:, :, 2, :]  # gate for self-attn
        e3 = mod[:, :, 3, :]  # shift for ffn
        e4 = mod[:, :, 4, :]  # scale for ffn
        e5 = mod[:, :, 5, :]  # gate for ffn

        # Self-attention with modulation
        x_mod = self.norm1(x) * (1 + e1) + e0
        y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
        x = x + y * e2

        # Cross-attention (no modulation, just norm)
        x_cross = self.norm3(x) if self.norm3 is not None else x
        x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)

        # FFN with modulation
        x_mod = self.norm2(x) * (1 + e4) + e3
        y = self.ffn(x_mod)
        x = x + y * e5

        return x


class WanFFN(nn.Module):
    """Two-layer feed-forward network (no gating): fc1 -> GELU(tanh) -> fc2."""

    def __init__(self, dim: int, ffn_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, ffn_dim)
        self.act = nn.GELU(approx="precise")
        self.fc2 = nn.Linear(ffn_dim, dim)

    def __call__(self, x: mx.array) -> mx.array:
        return self.fc2(self.act(self.fc1(x)))
315
mlx_video/models/wan/vae.py
Normal file
@@ -0,0 +1,315 @@
"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8).

Module structure mirrors original PyTorch checkpoint key hierarchy
so weights load directly without key sanitization.
"""

import mlx.core as mx
import mlx.nn as nn
import numpy as np


CACHE_T = 2

# Per-channel normalization statistics for z_dim=16
VAE_MEAN = [
    -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
    0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
]
VAE_STD = [
    2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
    3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
]


class CausalConv3d(nn.Module):
    """3D convolution with causal temporal padding."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple,
        stride: int | tuple = 1,
        padding: int | tuple = 0,
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)

        self.kernel_size = kernel_size
        self.stride = stride
        self._causal_pad_t = 2 * padding[0]
        self._pad_h = padding[1]
        self._pad_w = padding[2]

        # MLX Conv3d: weight shape [O, D, H, W, I]
        self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
        self.bias = mx.zeros((out_channels,))

    def __call__(self, x: mx.array) -> mx.array:
        """x: [B, C, T, H, W] (channel-first)"""
        b, c, t, h, w = x.shape

        if self._causal_pad_t > 0:
            pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype)
            x = mx.concatenate([pad_t, x], axis=2)

        if self._pad_h > 0 or self._pad_w > 0:
            x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
                           (self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])

        x = x.transpose(0, 2, 3, 4, 1)  # [B, T, H, W, C]
        out = self._conv3d(x)
        return out.transpose(0, 4, 1, 2, 3)  # [B, O, T', H', W']

    def _conv3d(self, x: mx.array) -> mx.array:
        """3D conv via sliding window + 2D conv per time step.
        x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
        """
        b, t, h, w, c_in = x.shape
        kt, kh, kw = self.kernel_size
        st, sh, sw = self.stride
        t_out = (t - kt) // st + 1

        # Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
        w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
            self.weight.shape[0], kh, kw, kt * c_in
        )
        outputs = []
        for t_i in range(t_out):
            t_start = t_i * st
            window = x[:, t_start : t_start + kt]
            window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
            out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
            outputs.append(out_2d)
        return mx.stack(outputs, axis=1)


class RMS_norm(nn.Module):
    """Channel-first L2 normalization matching original Wan VAE.

    Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
    images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
    images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
    """

    def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
        super().__init__()
        self.channel_first = channel_first
        self.scale = dim**0.5
        if channel_first:
            broadcastable = (1, 1) if images else (1, 1, 1)
            self.gamma = mx.ones((dim, *broadcastable))
        else:
            self.gamma = mx.ones((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        norm_dim = 1 if self.channel_first else -1
        # L2 normalize along channel dim (matches F.normalize)
        norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
        return (x / norm) * self.scale * self.gamma


class ResidualBlock(nn.Module):
    """Residual block with causal 3D convolutions.

    Uses `residual` list with None gaps to match original PyTorch
    nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
    [4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.residual = [
            RMS_norm(in_dim, images=False),  # [0]
            None,  # [1] SiLU
            CausalConv3d(in_dim, out_dim, 3, padding=1),  # [2]
            RMS_norm(out_dim, images=False),  # [3]
            None,  # [4] SiLU
            None,  # [5] Dropout
            CausalConv3d(out_dim, out_dim, 3, padding=1),  # [6]
        ]
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None

    def __call__(self, x: mx.array) -> mx.array:
        h = x if self.shortcut is None else self.shortcut(x)
        x = nn.silu(self.residual[0](x))
        x = self.residual[2](x)
        x = nn.silu(self.residual[3](x))
        x = self.residual[6](x)
        return x + h


class AttentionBlock(nn.Module):
    """Single-head spatial self-attention."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = RMS_norm(dim, images=True)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def __call__(self, x: mx.array) -> mx.array:
        """x: [B, C, T, H, W]"""
        identity = x
        b, c, t, h, w = x.shape

        # [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
        x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.norm(x)
        x = x.transpose(0, 2, 3, 1)  # [BT, H, W, C]

        qkv = self.to_qkv(x)  # [BT, H, W, 3C]
        qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q[:, None, :, :]  # [BT, 1, HW, C]
        k = k[:, None, :, :]
        v = v[:, None, :, :]
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
        out = out.squeeze(1).reshape(b * t, h, w, c)  # [BT, H, W, C]

        out = self.proj(out)  # [BT, H, W, C]
        out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3)  # [B, C, T, H, W]
        return out + identity
|
||||||
|
|
||||||
|
|
||||||
|
class Resample(nn.Module):
    """Upsample block matching original Wan VAE structure.

    Uses `resample` list with [None, Conv2d] to match original
    nn.Sequential(Upsample, Conv2d) where index 1 has the conv params.
    """

    def __init__(self, dim: int, mode: str):
        super().__init__()
        assert mode in ("upsample2d", "upsample3d")
        self.mode = mode
        self.dim = dim
        # resample.0 = Upsample (no params), resample.1 = Conv2d
        self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
        if mode == "upsample3d":
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

    def __call__(self, x: mx.array) -> mx.array:
        """x: [B, C, T, H, W]"""
        b, c, t, h, w = x.shape

        if self.mode == "upsample3d":
            # Temporal upsample via learned conv
            x_t = self.time_conv(x)  # [B, 2C, T, H, W]
            x_t = x_t.reshape(b, 2, c, t, h, w)
            # Interleave along time: [B, C, 2T, H, W]
            x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
            t = t * 2

        # Per-frame spatial upsample: nearest 2x + Conv2d
        x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c)  # [BT, H, W, C]
        x = mx.repeat(x, 2, axis=1)
        x = mx.repeat(x, 2, axis=2)
        x = self.resample[1](x)  # Conv2d [BT, 2H, 2W, C//2]
        c_out = x.shape[-1]
        return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder3d(nn.Module):
    """3D VAE Decoder matching Wan2.1 architecture.

    Uses flat `middle` and `upsamples` lists to match original
    PyTorch nn.Sequential weight key hierarchy.
    """

    def __init__(
        self,
        dim: int = 96,
        z_dim: int = 16,
        dim_mult: list = None,
        num_res_blocks: int = 2,
        temporal_upsample: list = None,
    ):
        super().__init__()
        if dim_mult is None:
            dim_mult = [1, 2, 4, 4]
        if temporal_upsample is None:
            temporal_upsample = [True, True, False]

        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]

        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # Middle: [ResBlock, AttentionBlock, ResBlock]
        self.middle = [
            ResidualBlock(dims[0], dims[0]),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0]),
        ]

        # Flat upsample list matching original nn.Sequential indexing
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            if i in (1, 2, 3):
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim))
                in_dim = out_dim
            if i != len(dim_mult) - 1:
                mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
                upsamples.append(Resample(out_dim, mode=mode))
        self.upsamples = upsamples

        # Output head: [RMS_norm, SiLU (no params), CausalConv3d]
        self.head = [
            RMS_norm(dims[-1], images=False),  # [0]
            None,  # [1] SiLU
            CausalConv3d(dims[-1], 3, 3, padding=1),  # [2]
        ]

    def __call__(self, x: mx.array) -> mx.array:
        """x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
        x = self.conv1(x)

        for layer in self.middle:
            x = layer(x)

        for layer in self.upsamples:
            x = layer(x)

        x = nn.silu(self.head[0](x))
        x = self.head[2](x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class WanVAE(nn.Module):
    """Wan2.1 VAE wrapper with per-channel normalization."""

    def __init__(self, z_dim: int = 16):
        super().__init__()
        self.z_dim = z_dim
        self.mean = mx.array(VAE_MEAN)
        self.std = mx.array(VAE_STD)
        self.inv_std = 1.0 / self.std

        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim=96, z_dim=z_dim)

    def decode(self, z: mx.array) -> mx.array:
        """Decode latent to video.

        Args:
            z: Normalized latent [B, z_dim, T, H, W]

        Returns:
            Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
        """
        mean = self.mean.reshape(1, -1, 1, 1, 1)
        inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
        z = z / inv_std + mean

        x = self.conv2(z)
        out = self.decoder(x)
        return mx.clip(out, -1, 1)
|
||||||
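The `RMS_norm` docstring in this file treats L2 normalization scaled by sqrt(dim) as equivalent to RMS normalization. A minimal NumPy sketch of that identity follows, assuming no epsilon clamp and gamma = 1; the helper names `l2_scaled` and `rms` are illustrative and not part of the package.

```python
import numpy as np


def l2_scaled(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """x / ||x||_2 * sqrt(dim): the form used by RMS_norm (eps and gamma omitted)."""
    dim = x.shape[axis]
    norm = np.sqrt(np.sum(x * x, axis=axis, keepdims=True))
    return x / norm * np.sqrt(dim)


def rms(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Textbook RMS norm: x / sqrt(mean(x^2))."""
    return x / np.sqrt(np.mean(x * x, axis=axis, keepdims=True))


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))

# ||x||_2 = sqrt(dim) * sqrt(mean(x^2)), so the two forms agree up to rounding.
max_diff = float(np.max(np.abs(l2_scaled(x) - rms(x))))
print(max_diff)
```

The two forms differ only in where the epsilon clamp is applied, which matters solely for near-zero inputs.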
584
mlx_video/models/wan/vae22.py
Normal file
@@ -0,0 +1,584 @@
|
|||||||
|
"""Wan2.2 VAE Decoder (compression 4×16×16, z_dim=48).

Architecture differs from Wan2.1 VAE: uses RMS_norm, DupUp3D shortcuts,
spatial patchify (2×2), and different temporal upsampling pattern.

Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""

import math

import mlx.core as mx
import mlx.nn as nn
import numpy as np

CACHE_T = 2

# Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([
    -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
    -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
    -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
    -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
    -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
    0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])

VAE22_STD = mx.array([
    0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
    0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
    0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
    0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
    0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
    0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv3d(nn.Module):
    """3D causal convolution. Input/output: [B, T, H, W, C] (channels-last).

    Decomposes the 3D conv into per-frame 2D convolutions to avoid
    excessive memory usage from MLX's conv3d implementation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)

        self.kernel_size = kernel_size
        self.stride = stride
        self._causal_pad_t = 2 * padding[0]
        self._pad_h = padding[1]
        self._pad_w = padding[2]

        # Weight: [O, D, H, W, I] for MLX
        self.weight = mx.zeros((
            out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
        ))
        self.bias = mx.zeros((out_channels,))

    def __call__(self, x):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape
        kd, kh, kw = self.kernel_size

        # Pointwise (1x1x1) kernel: one 2D conv over all frames at once
        if kd == 1 and kh == 1 and kw == 1:
            # Merge batch and time: [B, T, H, W, C] → [B*T, H, W, C]
            x_flat = x.reshape(B * T, H, W, C)
            w2d = self.weight[:, 0, :, :, :]  # [O, 1, 1, I]
            y = mx.conv_general(x_flat, w2d) + self.bias
            return y.reshape(B, T, y.shape[1], y.shape[2], -1)

        # Causal temporal padding (left only)
        if self._causal_pad_t > 0:
            pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
            x = mx.concatenate([pad_t, x], axis=1)

        # Spatial padding
        if self._pad_h > 0 or self._pad_w > 0:
            x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
                           (self._pad_w, self._pad_w), (0, 0)])

        T_padded = x.shape[1]
        H_padded, W_padded = x.shape[2], x.shape[3]
        T_out = (T_padded - kd) // self.stride[0] + 1

        # Decompose 3D conv into sum of 2D convolutions over temporal kernel
        # weight shape: [O, kd, kh, kw, I] → split into kd 2D kernels [O, kh, kw, I]
        outputs = []
        for t in range(T_out):
            t_start = t * self.stride[0]
            # Sum 2D convs for each temporal kernel position
            accum = None
            for d in range(kd):
                frame = x[:, t_start + d]  # [B, H_padded, W_padded, C]
                w2d = self.weight[:, d, :, :, :]  # [O, kh, kw, I]
                conv_out = mx.conv_general(frame, w2d,
                                           stride=(self.stride[1], self.stride[2]))
                accum = conv_out if accum is None else accum + conv_out
            outputs.append(accum + self.bias)

        return mx.stack(outputs, axis=1)  # [B, T_out, H_out, W_out, O]
|
||||||
|
|
||||||
|
|
||||||
|
class RMS_norm(nn.Module):
    """RMS normalization along channel dimension."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        # Weight stored as (dim,); PyTorch stores (dim, 1, 1, 1), which is squeezed on load
        self.gamma = mx.ones((dim,))

    def __call__(self, x):
        # x: [..., C] (channels-last)
        # PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
        # with eps = 1e-12, hence the 1e-24 clamp on the squared norm.
        l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
        return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
    """Residual block: RMS_norm → SiLU → CausalConv3d × 2 + shortcut."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
        # We store as named layers matching PyTorch's indices
        self.residual = ResidualBlockLayers(in_dim, out_dim)
        self.shortcut = (
            CausalConv3d(in_dim, out_dim, 1)
            if in_dim != out_dim
            else None
        )

    def __call__(self, x):
        h = self.shortcut(x) if self.shortcut is not None else x
        return self.residual(x) + h
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlockLayers(nn.Module):
    """The sequential layers inside a ResidualBlock.

    PyTorch stores these as nn.Sequential with indices 0-6:
    [0] RMS_norm, [1] SiLU, [2] CausalConv3d, [3] RMS_norm, [4] SiLU, [5] Dropout, [6] CausalConv3d
    We use matching attribute names for weight compatibility.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Indices match PyTorch nn.Sequential indices for weight key compat
        # Index 0: RMS_norm
        self.layer_0 = RMS_norm(in_dim)
        # Index 2: CausalConv3d
        self.layer_2 = CausalConv3d(in_dim, out_dim, 3, padding=1)
        # Index 3: RMS_norm
        self.layer_3 = RMS_norm(out_dim)
        # Index 6: CausalConv3d
        self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)

    def __call__(self, x):
        x = self.layer_0(x)
        x = nn.silu(x)
        x = self.layer_2(x)
        mx.eval(x)  # Eval between convolutions to limit graph size
        x = self.layer_3(x)
        x = nn.silu(x)
        x = self.layer_6(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
    """2D self-attention applied per frame. Input: [B, T, H, W, C]."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = RMS_norm(dim)
        # Conv2d as linear per spatial position, weight [O, H, W, I] for MLX
        # to_qkv: dim -> 3*dim, proj: dim -> dim (1x1 conv2d)
        self.to_qkv_weight = mx.zeros((3 * dim, 1, 1, dim))
        self.to_qkv_bias = mx.zeros((3 * dim,))
        self.proj_weight = mx.zeros((dim, 1, 1, dim))
        self.proj_bias = mx.zeros((dim,))

    def __call__(self, x):
        # x: [B, T, H, W, C]
        identity = x
        B, T, H, W, C = x.shape

        # Apply per frame: merge B and T
        x = x.reshape(B * T, H, W, C)
        x = self.norm(x)

        # QKV via 1x1 conv2d (equivalent to linear on last dim)
        qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias  # [BT, H, W, 3C]
        qkv = qkv.reshape(B * T, H * W, 3 * C)
        q, k, v = mx.split(qkv, 3, axis=-1)  # each [BT, HW, C]

        # Single-head attention
        q = q[:, None, :, :]  # [BT, 1, HW, C]
        k = k[:, None, :, :]
        v = v[:, None, :, :]

        scale = C ** -0.5
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)  # [BT, 1, HW, C]
        out = out.squeeze(1).reshape(B * T, H, W, C)

        # Project output
        out = mx.conv_general(out, self.proj_weight) + self.proj_bias  # [BT, H, W, C]
        out = out.reshape(B, T, H, W, C)
        return out + identity
|
||||||
|
|
||||||
|
|
||||||
|
class DupUp3D(nn.Module):
    """Upsample by duplicating channels and reshaping. No learnable parameters."""

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s
        self.repeats = out_channels * self.factor // in_channels

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape

        # Repeat channels
        x = mx.repeat(x, self.repeats, axis=-1)  # [B, T, H, W, C*repeats]

        # Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
        x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)

        # Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
        x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)

        # Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
        x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)

        if first_chunk:
            x = x[:, self.factor_t - 1:, :, :, :]
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Resample(nn.Module):
    """Spatial 2x upsampling with optional temporal 2x upsampling.

    Only the decoder modes ("upsample2d", "upsample3d") are implemented.
    """

    def __init__(self, dim, mode):
        super().__init__()
        self.dim = dim
        self.mode = mode

        if mode == "upsample2d":
            # resample.0 = Upsample (no params), resample.1 = Conv2d
            self.resample_weight = mx.zeros((dim, 3, 3, dim))  # Conv2d [O, H, W, I]
            self.resample_bias = mx.zeros((dim,))
        elif mode == "upsample3d":
            self.resample_weight = mx.zeros((dim, 3, 3, dim))
            self.resample_bias = mx.zeros((dim,))
            # time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        else:
            raise ValueError(f"Unsupported mode: {mode}")

    def _upsample2x(self, x):
        """Nearest-neighbor 2x spatial upsample. x: [N, H, W, C]."""
        N, H, W, C = x.shape
        # Repeat along H and W axes separately
        x = mx.repeat(x, repeats=2, axis=1)  # [N, 2H, W, C]
        x = mx.repeat(x, repeats=2, axis=2)  # [N, 2H, 2W, C]
        return x

    def _conv2d(self, x):
        """Apply the Conv2d with padding=1. x: [N, H, W, C]."""
        x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
        return mx.conv_general(x, self.resample_weight) + self.resample_bias

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape

        if self.mode == "upsample3d":
            # Temporal upsample via time_conv
            tc_out = self.time_conv(x)  # [B, T, H, W, 2C]
            # Split into two interleaved temporal streams
            tc_out = tc_out.reshape(B, T, H, W, 2, C)
            # Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C]
            stream0 = tc_out[:, :, :, :, 0, :]  # [B, T, H, W, C]
            stream1 = tc_out[:, :, :, :, 1, :]  # [B, T, H, W, C]
            x = mx.stack([stream0, stream1], axis=2)  # [B, T, 2, H, W, C]
            x = x.reshape(B, T * 2, H, W, C)

            if first_chunk:
                # PyTorch skips time_conv for first chunk entirely. In all-at-once
                # mode, we trim the first frame to match (the first interleaved
                # frame is from zero-padded causal context and shouldn't be kept).
                x = x[:, 1:, :, :, :]

            mx.eval(x)
            T = x.shape[1]

        # Spatial upsample in temporal chunks to limit peak memory
        chunk_size = 8
        chunks = []
        for t_start in range(0, T, chunk_size):
            t_end = min(t_start + chunk_size, T)
            x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
            x_chunk = self._upsample2x(x_chunk)
            x_chunk = self._conv2d(x_chunk)
            mx.eval(x_chunk)
            chunks.append(x_chunk)

        x = mx.concatenate(chunks, axis=0)
        H2, W2 = x.shape[1], x.shape[2]
        x = x.reshape(B, T, H2, W2, C)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Up_ResidualBlock(nn.Module):
    """Upsampling residual block with optional DupUp3D shortcut."""

    def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
        super().__init__()
        self.up_flag = up_flag

        # DupUp3D shortcut (no learnable params)
        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim, out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2 if up_flag else 1,
            )
        else:
            self.avg_shortcut = None

        # Main path: ResidualBlocks + optional Resample
        blocks = []
        dim_in = in_dim
        for _ in range(num_res_blocks):
            blocks.append(ResidualBlock(dim_in, out_dim))
            dim_in = out_dim

        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            blocks.append(Resample(out_dim, mode=mode))

        self.upsamples = blocks

    def __call__(self, x, first_chunk=False):
        x_main = x
        for module in self.upsamples:
            if isinstance(module, Resample):
                x_main = module(x_main, first_chunk)
            else:
                x_main = module(x_main)
            mx.eval(x_main)  # Limit graph size per sub-block

        if self.avg_shortcut is not None:
            x_shortcut = self.avg_shortcut(x, first_chunk)
            mx.eval(x_shortcut)
            return x_main + x_shortcut
        return x_main
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder3d(nn.Module):
    """Wan2.2 3D VAE Decoder."""

    def __init__(
        self,
        dim=256,
        z_dim=48,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temperal_upsample=(True, True, False),
    ):
        super().__init__()
        # Compute layer dimensions
        dims = [dim * dim_mult[-1]] + [dim * m for m in reversed(dim_mult)]

        # Initial conv
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # Middle blocks
        self.middle = [
            ResidualBlock(dims[0], dims[0]),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0]),
        ]

        # Upsample blocks
        self.upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
            self.upsamples.append(Up_ResidualBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks + 1,
                temperal_upsample=t_up,
                up_flag=(i != len(dim_mult) - 1),
            ))

        # Output head: [RMS_norm, SiLU, CausalConv3d]
        self.head = Head22(dims[-1])

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C=z_dim]
        x = self.conv1(x)

        for layer in self.middle:
            x = layer(x)
            mx.eval(x)  # Evaluate to limit graph size

        for i, layer in enumerate(self.upsamples):
            x = layer(x, first_chunk)
            mx.eval(x)  # Evaluate after each upsample block

        x = self.head(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Head22(nn.Module):
    """Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).

    PyTorch key mapping: head.0 = RMS_norm, head.2 = CausalConv3d
    (index 1 = SiLU has no params)
    """

    def __init__(self, dim, out_channels=12):
        super().__init__()
        # Index 0: RMS_norm
        self.layer_0 = RMS_norm(dim)
        # Index 2: CausalConv3d
        self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)

    def __call__(self, x):
        x = self.layer_0(x)
        x = nn.silu(x)
        x = self.layer_2(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Wan22VAEDecoder(nn.Module):
    """Full Wan2.2 VAE decoder with unpatchify.

    Expects latents that are already denormalized (see denormalize_latents).
    """

    def __init__(self, z_dim=48, dim=160, dec_dim=256):
        super().__init__()
        self.z_dim = z_dim
        # conv2: 1x1x1 conv before decoder
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dim=dec_dim,
            z_dim=z_dim,
            dim_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            temperal_upsample=(True, True, False),
        )

    def __call__(self, z):
        """Decode latents to video.

        Args:
            z: [B, T, H, W, C=48] latent tensor (already denormalized)

        Returns:
            video: [B, T', H', W', 3] decoded RGB in [-1, 1]
        """
        x = self.conv2(z)

        # All-at-once decode with first_chunk=True to trim extra temporal
        # frames from causal padding (matches PyTorch's chunked behavior)
        out = self.decoder(x, first_chunk=True)

        # Unpatchify: 12 channels → 3 RGB (spatial 2×2)
        out = _unpatchify(out, patch_size=2)

        return mx.clip(out, -1.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize_latents(z, mean=None, std=None):
    """Denormalize channels-last latents: z * std + mean (same as z / (1/std) + mean)."""
    if mean is None:
        mean = VAE22_MEAN
    if std is None:
        std = VAE22_STD
    inv_scale = std  # scale was 1/std, so divide by scale = multiply by std
    return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def _unpatchify(x, patch_size=2):
    """Unpack spatial patches from channels: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C].

    For the Wan2.2 decoder (p=2): [B, T, H, W, 12] → [B, T, H*2, W*2, 3].
    PyTorch: b (c r q) f h w -> b c f (h q) (w r) with q=p, r=p
    In channels-last: [B, T, H, W, C*r*q] -> [B, T, H*q, W*r, C]
    """
    if patch_size == 1:
        return x
    B, T, H, W, Cpacked = x.shape
    C = Cpacked // (patch_size * patch_size)
    # PyTorch patchify: "b c f (h q) (w r) -> b (c r q) f h w", so channels are packed as (c, r, q)
    # Split the packed channel axis: [B, T, H, W, (c r q)] -> [B, T, H, W, C, r, q]
    x = x.reshape(B, T, H, W, C, patch_size, patch_size)
    # Rearrange: put q next to H, r next to W
    x = x.transpose(0, 1, 2, 6, 3, 5, 4)  # [B, T, H, q, W, r, C]
    x = x.reshape(B, T, H * patch_size, W * patch_size, C)
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_wan22_vae_weights(weights: dict) -> dict:
    """Convert PyTorch Wan2.2 VAE weights to MLX format.

    Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
    Transposes conv weights from channels-first to channels-last.
    Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
    Maps PyTorch nn.Sequential indices to our named layers.
    """
    sanitized = {}

    for key, value in weights.items():
        # Skip encoder and conv1 (encoder-only)
        if key.startswith("encoder.") or key.startswith("conv1."):
            continue

        new_key = key

        # Map nn.Sequential indexed layers to our named attributes
        # ResidualBlockLayers: indices 0, 2, 3, 6 → layer_0, layer_2, layer_3, layer_6
        # Head22: indices 0, 2 → layer_0, layer_2
        for idx in ["0", "2", "3", "6"]:
            # Match patterns like "residual.0.gamma" → "residual.layer_0.gamma"
            old_pattern = f".residual.{idx}."
            new_pattern = f".residual.layer_{idx}."
            new_key = new_key.replace(old_pattern, new_pattern)

        # Head layer mapping: head.0.gamma → head.layer_0.gamma, head.2.weight → head.layer_2.weight
        for idx in ["0", "2"]:
            old_pattern = f".head.{idx}."
            new_pattern = f".head.layer_{idx}."
            new_key = new_key.replace(old_pattern, new_pattern)

        # Map Resample Conv2d: resample.1.weight → resample_weight, resample.1.bias → resample_bias
        if ".resample.1.weight" in new_key:
            new_key = new_key.replace(".resample.1.weight", ".resample_weight")
        elif ".resample.1.bias" in new_key:
            new_key = new_key.replace(".resample.1.bias", ".resample_bias")

        # Map AttentionBlock Conv2d weights
        if ".to_qkv.weight" in new_key:
            new_key = new_key.replace(".to_qkv.weight", ".to_qkv_weight")
        elif ".to_qkv.bias" in new_key:
            new_key = new_key.replace(".to_qkv.bias", ".to_qkv_bias")
        elif ".proj.weight" in new_key and "time_projection" not in new_key:
            new_key = new_key.replace(".proj.weight", ".proj_weight")
        elif ".proj.bias" in new_key and "time_projection" not in new_key:
            new_key = new_key.replace(".proj.bias", ".proj_bias")

        # Transpose conv weights to channels-last
        is_weight = new_key.endswith(".weight") or new_key.endswith("_weight")
        if is_weight:
            if value.ndim == 5:
                # Conv3d: [O, I, D, H, W] → [O, D, H, W, I]
                value = np.transpose(np.array(value), (0, 2, 3, 4, 1))
                value = mx.array(value)
            elif value.ndim == 4:
                # Conv2d: [O, I, H, W] → [O, H, W, I]
                value = np.transpose(np.array(value), (0, 2, 3, 1))
                value = mx.array(value)

        # Squeeze RMS_norm gamma: (dim, 1, 1, 1) or (dim, 1, 1) → (dim,)
        if "gamma" in new_key:
            value = mx.array(np.array(value).squeeze())

        sanitized[new_key] = value

    return sanitized
|
||||||
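The channel packing that `_unpatchify` reverses can be checked with a round trip. The sketch below uses NumPy rather than MLX so it runs anywhere; `patchify` is a hypothetical inverse written only for this check (the commit shown here does not include one), and it assumes the `(c, r, q)` channel order described in the `_unpatchify` comments.

```python
import numpy as np


def patchify(video: np.ndarray, p: int = 2) -> np.ndarray:
    """[B, T, H*p, W*p, C] -> [B, T, H, W, C*p*p], channels packed as (c, r, q)."""
    B, T, Hp, Wp, C = video.shape
    H, W = Hp // p, Wp // p
    x = video.reshape(B, T, H, p, W, p, C)  # [B, T, H, q, W, r, C]
    x = x.transpose(0, 1, 2, 4, 6, 5, 3)    # [B, T, H, W, C, r, q]
    return x.reshape(B, T, H, W, C * p * p)


def unpatchify(x: np.ndarray, p: int = 2) -> np.ndarray:
    """Same reshape/transpose sequence as _unpatchify, on NumPy arrays."""
    B, T, H, W, Cp = x.shape
    C = Cp // (p * p)
    x = x.reshape(B, T, H, W, C, p, p)      # [B, T, H, W, C, r, q]
    x = x.transpose(0, 1, 2, 6, 3, 5, 4)    # [B, T, H, q, W, r, C]
    return x.reshape(B, T, H * p, W * p, C)


# Small RGB "video": 1 batch, 2 frames, 4x6 pixels.
video = np.arange(1 * 2 * 4 * 6 * 3).reshape(1, 2, 4, 6, 3)
packed = patchify(video)
restored = unpatchify(packed)
print(packed.shape, restored.shape)
```

If the transpose axes in either direction were swapped (for example `r` and `q`), the shapes would still match but the round trip would no longer reproduce the input, which makes this a cheap regression check.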
@@ -19,7 +19,9 @@ dependencies = [
     "tqdm",
     "opencv-python>=4.12.0.88",
     "Pillow>=10.3.0",
-    "mlx-vlm"
+    "mlx-vlm",
+    "imageio>=2.37.2",
+    "imageio-ffmpeg>=0.6.0",
 ]
 license = {text="MIT"}
 authors = [
@@ -42,6 +44,7 @@ Issues = "https://github.com/Blaizzy/mlx-video/issues"
 
 [project.scripts]
 "mlx_video.generate" = "mlx_video.generate:main"
+"mlx_video.generate_wan" = "mlx_video.generate_wan:main"
 
 [tool.setuptools.packages.find]
 include = ["mlx_video*"]
@@ -52,4 +55,4 @@ version = {attr = "mlx_video.version.__version__"}
 [project.optional-dependencies]
 dev = [
     "pytest",
 ]
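The `[project.scripts]` entry added above uses the `module:function` form, so the `mlx_video.generate_wan` command refers to `main` in the `mlx_video.generate_wan` module. As a rough illustration of how such a reference is resolved, here is a small sketch; `resolve_entry_point` is a hypothetical helper, and the standard-library reference `json:dumps` stands in for the package's own entry so the example runs without MLX-Video installed.

```python
from importlib import import_module


def resolve_entry_point(spec):
    """Resolve a 'package.module:function' string to the object it names."""
    module_name, _, attr_path = spec.partition(":")
    obj = import_module(module_name)
    # The part after the colon may be dotted (for example 'Class.method').
    for part in attr_path.split("."):
        obj = getattr(obj, part)
    return obj


# Stand-in reference; the project's entry would be "mlx_video.generate_wan:main".
dumps = resolve_entry_point("json:dumps")
print(dumps({"ok": True}))  # {"ok": true}
```

Installers typically generate a small wrapper script that performs an equivalent import and then calls the resolved function.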
1453
tests/test_wan.py
Normal file
File diff suppressed because it is too large
34
uv.lock
generated
@@ -2,7 +2,8 @@ version = 1
 revision = 3
 requires-python = ">=3.11"
 resolution-markers = [
-    "python_full_version >= '3.12'",
+    "python_full_version >= '3.13'",
+    "python_full_version == '3.12.*'",
     "python_full_version < '3.12'",
 ]
 
@@ -614,6 +615,33 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
 ]
 
+[[package]]
+name = "imageio"
+version = "2.37.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "numpy" },
+    { name = "pillow" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a3/6f/606be632e37bf8d05b253e8626c2291d74c691ddc7bcdf7d6aaf33b32f6a/imageio-2.37.2.tar.gz", hash = "sha256:0212ef2727ac9caa5ca4b2c75ae89454312f440a756fcfc8ef1993e718f50f8a", size = 389600, upload-time = "2025-11-04T14:29:39.898Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl", hash = "sha256:ad9adfb20335d718c03de457358ed69f141021a333c40a53e57273d8a5bd0b9b", size = 317646, upload-time = "2025-11-04T14:29:37.948Z" },
+]
+
+[[package]]
+name = "imageio-ffmpeg"
+version = "0.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/44/bd/c3343c721f2a1b0c9fc71c1aebf1966a3b7f08c2eea8ed5437a2865611d6/imageio_ffmpeg-0.6.0.tar.gz", hash = "sha256:e2556bed8e005564a9f925bb7afa4002d82770d6b08825078b7697ab88ba1755", size = 25210, upload-time = "2025-01-16T21:34:32.747Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/da/58/87ef68ac83f4c7690961bce288fd8e382bc5f1513860fc7f90a9c1c1c6bf/imageio_ffmpeg-0.6.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.whl", hash = "sha256:9d2baaf867088508d4a3458e61eeb30e945c4ad8016025545f66c4b5aaef0a61", size = 24932969, upload-time = "2025-01-16T21:34:20.464Z" },
+    { url = "https://files.pythonhosted.org/packages/40/5c/f3d8a657d362cc93b81aab8feda487317da5b5d31c0e1fdfd5e986e55d17/imageio_ffmpeg-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b1ae3173414b5fc5f538a726c4e48ea97edc0d2cdc11f103afee655c463fa742", size = 21113891, upload-time = "2025-01-16T21:34:00.277Z" },
+    { url = "https://files.pythonhosted.org/packages/33/e7/1925bfbc563c39c1d2e82501d8372734a5c725e53ac3b31b4c2d081e895b/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1d47bebd83d2c5fc770720d211855f208af8a596c82d17730aa51e815cdee6dc", size = 25632706, upload-time = "2025-01-16T21:33:53.475Z" },
+    { url = "https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c7e46fcec401dd990405049d2e2f475e2b397779df2519b544b8aab515195282", size = 29498237, upload-time = "2025-01-16T21:34:13.726Z" },
+    { url = "https://files.pythonhosted.org/packages/a0/13/59da54728351883c3c1d9fca1710ab8eee82c7beba585df8f25ca925f08f/imageio_ffmpeg-0.6.0-py3-none-win32.whl", hash = "sha256:196faa79366b4a82f95c0f4053191d2013f4714a715780f0ad2a68ff37483cc2", size = 19652251, upload-time = "2025-01-16T21:34:06.812Z" },
+    { url = "https://files.pythonhosted.org/packages/2c/c6/fa760e12a2483469e2bf5058c5faff664acf66cadb4df2ad6205b016a73d/imageio_ffmpeg-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02fa47c83703c37df6bfe4896aab339013f62bf02c5ebf2dce6da56af04ffc0a", size = 31246824, upload-time = "2025-01-16T21:34:28.6Z" },
+]
+
 [[package]]
 name = "iniconfig"
 version = "2.3.0"
@@ -772,6 +800,8 @@ name = "mlx-video"
 source = { editable = "." }
 dependencies = [
     { name = "huggingface-hub" },
+    { name = "imageio" },
+    { name = "imageio-ffmpeg" },
     { name = "mlx" },
     { name = "mlx-vlm" },
     { name = "numpy" },
@@ -790,6 +820,8 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "huggingface-hub" },
+    { name = "imageio", specifier = ">=2.37.2" },
+    { name = "imageio-ffmpeg", specifier = ">=0.6.0" },
     { name = "mlx", specifier = ">=0.22.0" },
     { name = "mlx-vlm" },
     { name = "numpy" },