chore: Cleanup -- reorganize README and docs
158
.github/copilot-instructions.md
vendored
@@ -1,158 +0,0 @@
|
||||
# MLX-Video Copilot Instructions
|
||||
|
||||
## Overview
|
||||
|
||||
MLX-Video is a video/audio generation package built on Apple's MLX framework. It implements the LTX-2 model (19B parameter DiT) for text-to-video, image-to-video, and audio-video generation, optimized for Apple Silicon.
|
||||
|
||||
## Build, Test, and Lint
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Install test dependencies first (pytest not in main deps)
|
||||
pip install pytest
|
||||
|
||||
# Run all tests
|
||||
python -m pytest tests/
|
||||
|
||||
# Run specific test file
|
||||
python -m pytest tests/test_generate_dev.py
|
||||
|
||||
# Run specific test
|
||||
python -m pytest tests/test_generate_dev.py::TestLTX2Scheduler::test_scheduler_output_shape
|
||||
```
|
||||
|
||||
### Linting
|
||||
Pre-commit hooks configured with:
|
||||
- **black**: Code formatting
|
||||
- **isort**: Import sorting (profile: black)
|
||||
- **autoflake**: Remove unused imports
|
||||
|
||||
```bash
|
||||
# Run pre-commit manually
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Running Generation
|
||||
```bash
|
||||
# Quick test - distilled model (two-stage pipeline)
|
||||
python -m mlx_video.generate --prompt "test video" --num-frames 33
|
||||
|
||||
# Dev model with CFG (single-stage, higher quality)
|
||||
python -m mlx_video.generate_dev --prompt "test video" --steps 40 --cfg-scale 4.0
|
||||
|
||||
# Audio-video generation
|
||||
python -m mlx_video.generate_av --prompt "test video" --output-path out.mp4 --output-audio out.wav
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two-Stage Pipeline (Distilled Model)
|
||||
The distilled model (`generate.py`) uses a two-stage approach for efficiency:
|
||||
1. **Stage 1**: Generate at half resolution with 8 denoising steps using STAGE_1_SIGMAS
|
||||
2. **Upsampler**: 2x spatial upsampling via LatentUpsampler
|
||||
3. **Stage 2**: Refine at full resolution with 3 steps using STAGE_2_SIGMAS
|
||||
4. **VAE Decoder**: Convert latents to RGB video (tiled decoding for memory efficiency)
|
||||
|
||||
### Single-Stage Pipeline (Dev Model)
|
||||
The dev model (`generate_dev.py`) uses classifier-free guidance (CFG):
|
||||
- Full resolution generation with configurable steps (typically 40)
|
||||
- CFG guidance scale controls prompt adherence vs. diversity
|
||||
- More flexible but slower than distilled model
|
||||
|
||||
### Core Components
|
||||
|
||||
**DiT Transformer** (`models/ltx/ltx.py`):
|
||||
- 48 layers, 32 attention heads, 128 dim per head
|
||||
- Dual modality support: video (3840-dim) and audio (2048-dim) embeddings
|
||||
- Uses RoPE (Rotary Position Embeddings) in SPLIT mode with double precision
|
||||
- AdaLN-Zero conditioning blocks inject timestep/text embeddings
|
||||
|
||||
**VAE Architecture**:
|
||||
- **Video VAE**: 128 latent channels, 8x temporal + 32x spatial compression
  - Encoder: `models/ltx/video_vae/encoder.py`
  - Decoder: `models/ltx/video_vae/decoder.py` (supports tiled decoding)
- **Audio VAE**: 8 latent channels, mel-spectrogram intermediate
  - Decoder: `models/ltx/audio_vae/decoder.py`
  - HiFi-GAN vocoder: `models/ltx/audio_vae/vocoder.py`
|
||||
|
||||
**Text Encoder** (`models/ltx/text_encoder.py`):
|
||||
- Based on Gemma 3 model
|
||||
- Returns separate embeddings for video (3840-dim) and audio (2048-dim)
|
||||
- Supports prompt enhancement via `enhance_t2v()` method
|
||||
|
||||
**Tiling System** (`models/ltx/video_vae/tiling.py`):
|
||||
- Memory-efficient decoding for large videos
|
||||
- Modes: auto, default (512px/64f), aggressive (256px/32f), conservative (768px/96f)
|
||||
- Supports streaming via `on_frames_ready` callback
|
||||
|
||||
### Key Patterns
|
||||
|
||||
**Position Grids**:
|
||||
- Created in pixel space, then converted to latent space internally
|
||||
- Video: (B, 3, num_patches, 2) with [start, end) bounds for temporal/spatial dims
|
||||
- Audio: (B, 1, num_patches, 2) for temporal dimension only
|
||||
- See `create_position_grid()` in generate modules
|
||||
|
||||
**Latent Conditioning** (`conditioning/latent.py`):
|
||||
- `LatentState` tracks clean latents, noise, and sigma values
|
||||
- `VideoConditionByLatentIndex` enables I2V by conditioning specific frames
|
||||
- `apply_denoise_mask()` protects conditioned regions during denoising
|
||||
|
||||
**Weight Loading**:
|
||||
- `convert.py`: Downloads from HuggingFace, converts PyTorch → MLX format
|
||||
- Sanitization functions (`sanitize_transformer_weights`, `sanitize_vae_encoder_weights`) adapt keys
|
||||
- Uses safetensors for efficient loading
|
||||
|
||||
## Key Conventions
|
||||
|
||||
### Model Configuration
|
||||
- Always use `LTXModelConfig` to instantiate models
|
||||
- `model_type` determines modality: `VideoOnly`, `AudioOnly`, or `AudioVideo`
|
||||
- `rope_type=LTXRopeType.SPLIT` and `double_precision_rope=True` are standard
|
||||
|
||||
### Frame Count Requirements
|
||||
- **Distilled model**: `num_frames = 1 + 8*k` format (e.g., 33, 65, 97)
|
||||
- **Dev model**: No strict requirement, but odd numbers work better
|
||||
- Audio frames auto-computed from video duration via `AUDIO_LATENTS_PER_SECOND`
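The `1 + 8*k` rule for the distilled model can be checked before launching a generation. The helper names below are hypothetical and do not come from the package; this is a minimal sketch that assumes any non-negative integer `k` is acceptable.

```python
def is_valid_distilled_frame_count(num_frames: int) -> bool:
    """Return True when num_frames has the form 1 + 8*k for an integer k >= 0."""
    return num_frames >= 1 and (num_frames - 1) % 8 == 0


def nearest_valid_frame_count(num_frames: int) -> int:
    """Round to a nearby value of the form 1 + 8*k.

    Exact ties follow Python's round-half-to-even behaviour.
    """
    k = max(0, round((num_frames - 1) / 8))
    return 1 + 8 * k
```

For example, a request for 30 frames would be adjusted to 33, which matches the first example value listed above.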
|
||||
|
||||
### Dimension Constraints
|
||||
- Video height/width must be divisible by 64 (VAE spatial compression)
|
||||
- Latent dimensions are pixel dimensions divided by 32
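The constraints above can be turned into a small shape calculation. This is a sketch under stated assumptions: the function name is hypothetical, the temporal formula `1 + (num_frames - 1) // 8` is inferred from the `1 + 8*k` frame rule rather than taken from the code, and the `half_resolution` flag models stage 1 of the distilled pipeline. One plausible reading of the divisible-by-64 rule is that the half-resolution frame must still divide evenly by the 32x spatial compression, but the file does not state this.

```python
SPATIAL_COMPRESSION = 32  # Video VAE spatial compression, as listed under Core Components
TEMPORAL_COMPRESSION = 8  # Video VAE temporal compression


def video_latent_shape(num_frames, height, width, half_resolution=False):
    """Return (latent_frames, latent_height, latent_width) for a pixel-space request."""
    if height % 64 != 0 or width % 64 != 0:
        raise ValueError("height and width must be divisible by 64")
    # Stage 1 of the two-stage pipeline is assumed to work on half-size frames.
    divisor = SPATIAL_COMPRESSION * (2 if half_resolution else 1)
    # Assumption: the first frame maps to one latent frame, then one per 8 frames.
    latent_frames = 1 + (num_frames - 1) // TEMPORAL_COMPRESSION
    return latent_frames, height // divisor, width // divisor
```

Under these assumptions a 33-frame, 512x768 request maps to a 5x16x24 latent at full resolution and 5x8x12 at half resolution.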
|
||||
|
||||
### Audio Constants
|
||||
```python
|
||||
AUDIO_SAMPLE_RATE = 24000 # Output sample rate
|
||||
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal rate
|
||||
AUDIO_HOP_LENGTH = 160 # Mel hop length
|
||||
AUDIO_LATENT_CHANNELS = 8 # Audio latent channels
|
||||
AUDIO_MEL_BINS = 16 # Mel frequency bins
|
||||
```
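A few derived quantities follow from these constants by simple arithmetic. The sketch below is illustrative only: the function names are hypothetical, and how the package derives `AUDIO_LATENTS_PER_SECOND` from these values is not stated here, so that constant is deliberately left out.

```python
AUDIO_SAMPLE_RATE = 24000         # Output sample rate
AUDIO_LATENT_SAMPLE_RATE = 16000  # VAE internal rate
AUDIO_HOP_LENGTH = 160            # Mel hop length


def mel_frames_per_second() -> int:
    """Mel frames per second, assuming one frame per hop at the VAE internal rate."""
    return AUDIO_LATENT_SAMPLE_RATE // AUDIO_HOP_LENGTH


def output_sample_count(duration_seconds: float) -> int:
    """Number of waveform samples written for a clip of the given duration."""
    return round(duration_seconds * AUDIO_SAMPLE_RATE)
```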
|
||||
|
||||
### Sigma Schedules
|
||||
Distilled model uses predefined schedules (no scheduler class):
|
||||
```python
|
||||
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||
```
|
||||
|
||||
Dev model computes schedules via `ltx2_scheduler(steps)` function.
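Each schedule lists the sigma boundaries, so a list of N values describes N - 1 denoising steps. The short sketch below pairs consecutive values to make that explicit; `sigma_steps` is a hypothetical helper, not a function from the package.

```python
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]


def sigma_steps(sigmas):
    """Return (sigma_current, sigma_next) pairs, one per denoising step."""
    return list(zip(sigmas[:-1], sigmas[1:]))


stage_1_steps = sigma_steps(STAGE_1_SIGMAS)
stage_2_steps = sigma_steps(STAGE_2_SIGMAS)
```

This gives 8 steps for stage 1 and 3 steps for stage 2, matching the pipeline description. Note that the stage 2 schedule equals the last four values of the stage 1 schedule.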
|
||||
|
||||
### Code Style
|
||||
- Follow black formatting (configured in pre-commit)
|
||||
- Import sorting: isort with black profile
|
||||
- Remove unused imports (autoflake)
|
||||
- Type hints encouraged but not enforced
|
||||
|
||||
### Modality Enum
|
||||
Use `Modality.VIDEO` and `Modality.AUDIO` from `models/ltx/transformer.py` for multi-modal operations.
|
||||
|
||||
### Video Post-Processing
|
||||
- `postprocess.py`: Contains utilities for frame normalization and video saving
|
||||
- Always denormalize latents from [-1, 1] to [0, 255] before saving
|
||||
- Use opencv-python for video I/O
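The denormalization rule can be illustrated on a single value. The function below is a hypothetical scalar sketch, not the code in `postprocess.py`, which presumably operates on whole arrays; clamping out-of-range values is an added assumption.

```python
def denormalize_value(value: float) -> int:
    """Map a value in [-1, 1] to an 8-bit intensity in [0, 255]."""
    clamped = max(-1.0, min(1.0, value))  # assumption: out-of-range values are clipped
    return round((clamped + 1.0) * 127.5)
```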
|
||||
|
||||
## Python Requirements
|
||||
- Python >= 3.11
|
||||
- MLX >= 0.22.0
|
||||
- Primary dependencies: numpy, safetensors, transformers, opencv-python, Pillow, mlx-vlm, scipy, librosa
|
||||
- Package manager: uv recommended for faster installs, pip also supported
|
||||
26
.github/skills/fast-mlx/SKILL.md
vendored
@@ -1,26 +0,0 @@
|
||||
---
|
||||
name: fast-mlx
|
||||
description: Optimize MLX code for performance and memory. Use when asked to implement or speed up MLX models or algorithms, reduce latency/throughput bottlenecks, tune lazy evaluation, type promotion, fast ops, compilation, memory use, or profiling.
|
||||
---
|
||||
|
||||
# Fast MLX
|
||||
|
||||
## Workflow
|
||||
|
||||
- Look for opportunities to compile functions of mostly elementwise operations.
- For models with fixed-shape inputs, or where the shapes don't change much, compile the entire graph.
|
||||
- Replace slow implementations with MLX fast ops
|
||||
- Identify evaluation boundaries and unintended sync points (`mx.eval`, `item()`, NumPy conversions).
|
||||
- Check dtype promotion and scalar usage; keep precision consistent with intent.
|
||||
- Review compilation strategy; avoid unnecessary recompiles and closure captures.
|
||||
- Reduce peak memory via lazy loading order and releasing temporaries before `mx.eval`.
|
||||
- Suggest profiling steps if the bottleneck is unclear.
|
||||
|
||||
## References
|
||||
|
||||
- Read `references/fast-mlx-guide.md` for detailed tips and examples. Use it as the source of truth.
|
||||
|
||||
## Output expectations
|
||||
|
||||
- Provide concrete code changes with brief rationale
|
||||
- Call out changes that need user confirmation (e.g., enabling async eval or shapeless compile).
|
||||
350
.github/skills/fast-mlx/references/fast-mlx-guide.md
vendored
@@ -1,350 +0,0 @@
|
||||
# Making MLX Go Fast
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Graph Evaluation](#graph-evaluation)
|
||||
- [Type Promotion](#type-promotion)
|
||||
- [Operations](#operations)
|
||||
- [Compile](#compile)
|
||||
- [Memory Use](#memory-use)
|
||||
- [Profiling](#profiling)
|
||||
|
||||
This guide assumes you have some familiarity with MLX and want to make your MLX
|
||||
model or algorithm as efficient as possible.
|
||||
|
||||
### Graph Evaluation
|
||||
|
||||
Recall, MLX is lazy. When you call an MLX op, no computation actually happens.
|
||||
You are simply building a graph. The computation happens when you explicitly or
|
||||
implicitly evaluate an array. Read more about how this works in the
|
||||
documentation:
|
||||
https://ml-explore.github.io/mlx/build/html/usage/lazy_evaluation.html
|
||||
|
||||
Evaluating the graph incurs some overhead, so don't do it too frequently.
|
||||
Conversely you don't want the graph to get too big before evaluating it as this
|
||||
can also be expensive. Most numerical and machine learning algorithms are
|
||||
iterative. A good place to evaluate the graph is at the end of each iteration.
|
||||
Some examples:
|
||||
|
||||
- After an iteration of gradient descent
|
||||
- After producing one token with a language model
|
||||
- After taking one denoising step in a diffusion model
|
||||
|
||||
Overly frequent evaluations sometimes happen by accident. For example:
|
||||
|
||||
```python
# output is an mx.array
for x in output:
    do_something(x.item())
```
|
||||
|
||||
The same thing can be written more explicitly with operations and `mx.eval` as:
|
||||
|
||||
```python
for i in range(len(output)):
    x = output[i]
    mx.eval(x)
    do_something(x.item())
```
|
||||
|
||||
Two better options are:
|
||||
|
||||
1. When possible avoid calling `item()` and do everything in MLX.
|
||||
2. Move the entire output to Python or NumPy first.
|
||||
|
||||
An example of the second approach:
|
||||
|
||||
```python
for x in output.tolist():
    do_something(x)
```
|
||||
|
||||
#### Asynchronous Evaluation
|
||||
|
||||
For a latency sensitive computation which is run many times, `mx.async_eval`
|
||||
can be useful. Normally `mx.eval` is synchronous. It returns only when the
|
||||
computation is complete. Instead `mx.async_eval` asynchronously evaluates the
|
||||
graph and returns to the main thread immediately. You can use this to pipeline
|
||||
graph construction with computation like so:
|
||||
|
||||
```python
def generator():
    # mx.async_eval returns None, so keep a reference to the array itself.
    out = my_function()
    mx.async_eval(out)

    while True:
        out_next = my_function()
        mx.async_eval(out_next)
        mx.eval(out)
        yield out
        out = out_next
```
|
||||
|
||||
For this to work `my_function()` cannot do any synchronous evaluations (e.g.
|
||||
calling `mx.eval`, converting to NumPy, etc.). Furthermore, any work done on
|
||||
`out` that is synchronous and on the same stream can stall the pipeline:
|
||||
|
||||
```python
for out in generator():
    out = out * 2
    # Stalls the pipeline!
    mx.eval(out)
```
|
||||
|
||||
An easy fix for this is to put the pipeline in a separate stream:
|
||||
|
||||
```python
stream = mx.new_stream(mx.gpu)

def generator():
    # Scope the stream to graph construction so it is not left active
    # while the consumer of the generator runs.
    with mx.stream(stream):
        out = my_function()
    mx.async_eval(out)

    while True:
        with mx.stream(stream):
            out_next = my_function()
        mx.async_eval(out_next)
        mx.eval(out)
        yield out
        out = out_next
```
|
||||
|
||||
### Type Promotion
|
||||
|
||||
One of the most common performance issues comes from accidental up-casting.
|
||||
Make sure you understand how type promotion works in MLX. The inputs to an MLX
|
||||
operation are typically promoted to a common type which doesn't lose precision.
|
||||
For example:
|
||||
|
||||
```python
|
||||
x = mx.array(1.0, mx.float32) * mx.array(2.0, mx.float16)
|
||||
```
|
||||
|
||||
will result in `x` with type `mx.float32`. Similarly:
|
||||
|
||||
```python
|
||||
x = mx.array(1.0, mx.bfloat16) * mx.array(2.0, mx.float16)
|
||||
```
|
||||
|
||||
will result in `x` with type `mx.float32`. A common mistake is to multiply a
|
||||
half-precision array by a default-typed scalar array which up-casts everything
|
||||
to `mx.float32`:
|
||||
|
||||
```python
|
||||
# Warning: x has type mx.float32
|
||||
x = my_fp16_array * mx.array(2.0)
|
||||
```
|
||||
|
||||
To multiply by a scalar while preserving the input type, use Python scalars.
|
||||
Python scalars are weakly typed and have more relaxed promotion rules when
|
||||
used with MLX arrays.
|
||||
|
||||
```python
|
||||
# Ok, x has type mx.float16
|
||||
x = my_fp16_array * 2.0
|
||||
```
|
||||
|
||||
### Operations
|
||||
|
||||
#### Use Fast Ops
|
||||
|
||||
Use `mx.fast` ops when possible:
|
||||
|
||||
- `mx.fast.rms_norm`
|
||||
- `mx.fast.layer_norm`
|
||||
- `mx.fast.rope`
|
||||
- `mx.fast.scaled_dot_product_attention`
|
||||
|
||||
A lot of these operations take a variety of parameters so they can be used for
|
||||
different variations of the function. For example, the weight and bias
|
||||
parameters are optional in `mx.fast.layer_norm` so it can be used with
|
||||
different permutations of inputs.
|
||||
|
||||
#### Precision
|
||||
|
||||
For operations which typically use higher precision there is usually no
|
||||
need to explicitly upcast. For example, `mx.fast.rms_norm` and
|
||||
`mx.fast.layer_norm` accumulate in higher precision so it's
|
||||
wasteful to upcast and downcast into and out of these operations:
|
||||
|
||||
```python
# No need for this!
mx.fast.rms_norm(x.astype(mx.float32), w, eps).astype(x.dtype)

# This is just as good:
mx.fast.rms_norm(x, w, eps)
```
|
||||
|
||||
Similarly, for `mx.softmax` use `precise=True` if you want to do the softmax in
|
||||
higher precision rather than explicitly casting the input and output.
|
||||
|
||||
#### Misc
|
||||
|
||||
- For vector-matrix multiplication `x @ W.T` is faster than `x @ W`, for
|
||||
matrix-vector multiplication `W @ x` is faster than `W.T @ x`
|
||||
- Use `mx.addmm` for `a @ b + c` (e.g. a linear layer with a bias).
|
||||
- Where it makes sense, use `mx.take_along_axis` and `mx.put_along_axis`
|
||||
instead of fancy indexing
|
||||
- Use broadcasting instead of concatenation. For example, prefer `mx.repeat(a,
|
||||
n)` over `mx.concatenate([a] * n)`
|
||||
|
||||
### Compile
|
||||
|
||||
Compiling graphs with `mx.compile` can make them run a lot faster. But there
are some sharp edges that are good to be aware of.
|
||||
|
||||
First, be aware of when a function will be recompiled. Recompilation is
|
||||
relatively expensive and should only be done if there is sufficient work over
|
||||
which to amortize the cost.
|
||||
|
||||
The default behavior of `mx.compile` is to do a shape-dependent compilation.
|
||||
This means the function will be recompiled if the shape of any input changes.
|
||||
|
||||
MLX supports a shapeless compilation by passing `shapeless=True` to
|
||||
`mx.compile`. It's easy to make hard-to-detect mistakes with shapeless
|
||||
compilation. Make sure to read and understand the documentation and use it
|
||||
with care:
|
||||
https://ml-explore.github.io/mlx/build/html/usage/compile.html#shapeless-compilation
|
||||
|
||||
A function will also be recompiled if any constant inputs change:
|
||||
|
||||
```python
@mx.compile
def fun(x, scale):
    return scale * x

fun(x, 3)

# Recompiles!
fun(x, 4)
```
|
||||
|
||||
In this case a simple fix is to make `scale` an `mx.array`.
|
||||
|
||||
#### Compiling Closures
|
||||
|
||||
Be careful when compiling a closure where the function encloses any
|
||||
`mx.array`.
|
||||
|
||||
```python
y = some_function()

@mx.compile
def fun(x):
    return x + y
```
|
||||
|
||||
Since `y` is not an input to `fun`, the compiled graph will include the entire
|
||||
computation which produces `y`. Usually you only want to compute `y` one time
|
||||
and re-use it in the compiled function. Either explicitly pass it as an input
|
||||
to `fun` or pass it as an implicit input to `mx.compile` like so:
|
||||
|
||||
```python
from functools import partial

y = some_function()

@partial(mx.compile, inputs=[y])
def fun(x):
    return x + y
```
|
||||
|
||||
### Memory Use
|
||||
|
||||
#### Lazy Loading
|
||||
|
||||
Loading arrays from a file is lazy in MLX:
|
||||
|
||||
```python
|
||||
weights = mx.load("model.safetensors")
|
||||
```
|
||||
|
||||
The above function returns instantly, regardless of the file size. To actually
|
||||
load the weights into memory, you can do `mx.eval(weights)`.
|
||||
|
||||
Assume the weights are stored on disk in 32-bit precision (i.e. `mx.float32`).
|
||||
But for your model you only need 16-bit precision:
|
||||
|
||||
```python
|
||||
weights = mx.load("model.safetensors")
|
||||
mx.eval(weights)
|
||||
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
|
||||
```
|
||||
|
||||
In the above, the weights will be loaded into memory in full precision and then
|
||||
cast to 16-bit. This requires memory for all the weights in 32-bit plus memory
|
||||
for the weights in 16-bit.
|
||||
|
||||
This is much better:
|
||||
|
||||
```python
|
||||
weights = mx.load("model.safetensors")
|
||||
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
|
||||
mx.eval(weights)
|
||||
```
|
||||
|
||||
Evaluating after the cast to `mx.float16` cuts peak memory to roughly a third
of the previous approach: about half the 32-bit footprint instead of one and a
half times it. That's because the weights are never all materialized in 32-bit
at once. Right after each weight is loaded in 32-bit precision it is cast to
16-bit, and the memory for the 32-bit weight can be reused when loading the
next weight.
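The accounting behind this comparison can be worked through with plain arithmetic. The sketch below is a back-of-the-envelope model, not a measurement: the function names are hypothetical, it follows the guide's description that the first approach holds all 32-bit and all 16-bit weights at once, and it ignores allocator caching.

```python
BYTES_FP32 = 4
BYTES_FP16 = 2


def peak_bytes_cast_after_eval(num_params: int) -> int:
    """All weights resident in 32-bit, plus a full 16-bit copy."""
    return num_params * BYTES_FP32 + num_params * BYTES_FP16


def peak_bytes_cast_before_eval(num_params: int, largest_tensor_params: int) -> int:
    """All weights in 16-bit, plus one tensor temporarily held in 32-bit."""
    return num_params * BYTES_FP16 + largest_tensor_params * BYTES_FP32


# Hypothetical model: 1B parameters, largest single tensor has 50M parameters.
eager_peak = peak_bytes_cast_after_eval(1_000_000_000)
lazy_peak = peak_bytes_cast_before_eval(1_000_000_000, 50_000_000)
```

With these illustrative numbers the first approach peaks at 6.0 GB and the second at 2.2 GB.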
|
||||
|
||||
Note, MLX is only able to lazy load from a file when it is given to `mx.load`
|
||||
as a string path. Due to lifetime management issues, lazy loading from file
|
||||
handles is not supported. So avoid this:
|
||||
|
||||
```python
|
||||
weights = mx.load(open("model.safetensors", 'rb'))
|
||||
```
|
||||
|
||||
#### Release Temporaries
|
||||
|
||||
One way to reduce memory consumption is to avoid holding
|
||||
temporaries you don't need. This is a typical training loop:
|
||||
|
||||
```python
for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model, optimizer.state)
```
|
||||
|
||||
It's suboptimal since a reference to `grads` is held during the call to
|
||||
`mx.eval` which keeps the respective memory from being used for any other part
|
||||
of the computation.
|
||||
|
||||
This is better:
|
||||
|
||||
```python
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(x, y)
    mx.eval(model, optimizer.state)
```
|
||||
|
||||
In this case the reference to `grads` is released before `mx.eval` and the
|
||||
memory can be reused. You can achieve the same goal using `del` as long as it's
|
||||
before the call to `mx.eval`:
|
||||
|
||||
```python
for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    del grads
    mx.eval(model, optimizer.state)
```
|
||||
|
||||
#### Misc
|
||||
|
||||
- MLX will cache memory buffers of recently released arrays rather than
|
||||
returning them to the system. In some cases, especially for variable shape
|
||||
computations, the cache can get large. To help with this, MLX has some
|
||||
functions for logging and customizing the behavior of memory allocation:
|
||||
https://ml-explore.github.io/mlx/build/html/python/metal.html
|
||||
|
||||
### Profiling
|
||||
|
||||
A good first step is to check GPU utilization using, for example,
|
||||
mactop: https://github.com/context-labs/mactop. If it's not pegged at close
|
||||
to 100% then there is likely a non-MLX bottleneck somewhere in the program. A
|
||||
common culprit is data loading or preprocessing.
|
||||
|
||||
If GPU utilization is good, a good next step is to figure out which operations
|
||||
are taking up so much time. One way to do this is with the Metal debugger. For
|
||||
that, see the documentation on profiling MLX with the Metal debugger:
|
||||
https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html
|
||||
134
README.md
@@ -80,94 +80,13 @@ python -m mlx_video.generate \
|
||||
|
||||
## Wan2.1 / Wan2.2
|
||||
|
||||
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE. They share the same model architecture — the difference is in the inference pipeline:
|
||||
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
|
||||
|
||||
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|
||||
|---|--------|--------|--------|--------|
|
||||
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
|
||||
| **Pipeline** | Single model | Dual model | Dual model | Single model |
|
||||
| **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
|
||||
| **Steps** | 50 | 40 | 40 | 40 |
|
||||
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
|
||||
| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
|
||||
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
|
||||
### Step 0: Download and Convert Weights
|
||||
|
||||
### Step 1: Download Weights
|
||||
See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan/README.md) for details.
|
||||
|
||||
Download the original PyTorch checkpoints:
|
||||
|
||||
**Wan2.1 (14B)**
|
||||
```bash
|
||||
# From https://github.com/Wan-Video/Wan2.1 or HuggingFace
|
||||
# Expected directory structure:
|
||||
# wan21_checkpoint/
|
||||
# ├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
# ├── Wan2.1_VAE.pth
|
||||
# └── diffusion_pytorch_model*.safetensors # single model
|
||||
```
|
||||
|
||||
**Wan2.1 (1.3B)** — same structure, smaller transformer weights.
|
||||
|
||||
**Wan2.2 (14B)**
|
||||
```bash
|
||||
# From https://github.com/Wan-Video/Wan2.2 or HuggingFace
|
||||
# Expected directory structure:
|
||||
# wan22_checkpoint/
|
||||
# ├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
# ├── Wan2.1_VAE.pth
|
||||
# ├── low_noise_model/ # safetensors
|
||||
# └── high_noise_model/ # safetensors
|
||||
```
|
||||
|
||||
**Wan2.2 I2V-14B** — same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V-14B from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`).
|
||||
|
||||
### Step 2: Convert to MLX Format
|
||||
|
||||
The conversion script auto-detects the model version based on the directory structure (presence of `low_noise_model/` subdirectory) and model type (`model_type` in source config.json for I2V vs T2V).
|
||||
|
||||
```bash
|
||||
# Auto-detect version
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/wan_checkpoint \
|
||||
--output-dir wan_mlx
|
||||
|
||||
# Explicit version
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/wan21_checkpoint \
|
||||
--output-dir wan21_mlx \
|
||||
--model-version 2.1
|
||||
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/wan22_checkpoint \
|
||||
--output-dir wan22_mlx \
|
||||
--model-version 2.2
|
||||
```
|
||||
|
||||
#### Conversion Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
|
||||
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
|
||||
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
|
||||
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
|
||||
| `--quantize` | off | Quantize transformer weights for reduced memory |
|
||||
| `--bits` | `4` | Quantization bits: `4` or `8` |
|
||||
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
|
||||
|
||||
The converter produces:
|
||||
```
|
||||
wan_mlx/
|
||||
├── config.json # Model configuration
|
||||
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
|
||||
├── vae.safetensors # 3D VAE decoder
|
||||
├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only)
|
||||
├── model.safetensors # (Wan2.1) Single transformer
|
||||
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
|
||||
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
|
||||
```
|
||||
|
||||
### Step 3: Generate Video
|
||||
### Step 1: Generate Video
|
||||
|
||||
```bash
|
||||
# Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0)
|
||||
@@ -231,52 +150,7 @@ The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and use
|
||||
| `--seed` | -1 (random) | Random seed for reproducibility |
|
||||
| `--output-path` | `output.mp4` | Output video path |
|
||||
|
||||
### Quantization (Reduced Memory)
|
||||
|
||||
Quantize the transformer weights to reduce memory usage by ~3.4x. This is especially useful for the 14B model or memory-constrained devices:
|
||||
|
||||
```bash
|
||||
# Convert with 4-bit quantization
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/Wan2.1-T2V-1.3B \
|
||||
--output-dir wan21_mlx_q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
# Generate with quantized model (auto-detected from config.json)
|
||||
python -m mlx_video.generate_wan \
|
||||
--model-dir wan21_mlx_q4 \
|
||||
--prompt "A cat playing piano"
|
||||
```
|
||||
|
||||
**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2) — 10 layers × N blocks = ~95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.
|
||||
|
||||
| Model | BF16 Size | 4-bit Size | Notes |
|
||||
|-------|-----------|------------|-------|
|
||||
| 1.3B | 2.7 GB | 799 MB | ~3.4x smaller |
|
||||
| 14B | ~28 GB | ~8 GB | Enables running on 16GB devices |
|
||||
|
||||
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
|
||||
|
||||
|
||||
### Wan Model Specifications
|
||||
|
||||
**Transformer (14B)**
|
||||
- 40 layers, 40 attention heads, dim 5120, head dim 128
|
||||
- 3-way factorized RoPE (temporal + spatial)
|
||||
- 14.29B parameters
|
||||
|
||||
**Transformer (1.3B, Wan2.1 only)**
|
||||
- 30 layers, 12 attention heads, dim 1536, head dim 128
|
||||
- Same architecture, smaller scale
|
||||
|
||||
**Text Encoder** — UMT5-XXL (5.68B parameters)
|
||||
- 24 layers, 64 heads, dim 4096, vocab 256K
|
||||
|
||||
**VAE** — 3D causal convolution decoder (72.6M parameters)
|
||||
- Latent channels: 16
|
||||
- Compression: 4× temporal, 8× spatial
|
||||
|
||||
---
|
||||
|
||||
## Requirements
|
||||
|
||||
|
||||
157
mlx_video/models/wan/README.md
Normal file
@@ -0,0 +1,157 @@
|
||||
|
||||
## Wan2.1 / Wan2.2
|
||||
|
||||
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
|
||||
|
||||
They share the same model architecture — the difference is in the inference pipeline:
|
||||
|
||||
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|
||||
|---|--------|--------|--------|--------|
|
||||
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
|
||||
| **Pipeline** | Single model | Dual model | Dual model | Single model |
|
||||
| **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
|
||||
| **Steps** | 50 | 40 | 40 | 40 |
|
||||
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
|
||||
| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
|
||||
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
|
||||
|
||||
### Step 1: Download Weights
|
||||
|
||||
Download the original PyTorch checkpoints:
|
||||
|
||||
**Wan2.1 (14B)**
|
||||
```bash
|
||||
# From https://github.com/Wan-Video/Wan2.1 or HuggingFace
|
||||
# Expected directory structure:
|
||||
# wan21_checkpoint/
|
||||
# ├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
# ├── Wan2.1_VAE.pth
|
||||
# └── diffusion_pytorch_model*.safetensors # single model
|
||||
```
|
||||
|
||||
**Wan2.1 (1.3B)** — same structure, smaller transformer weights.
|
||||
|
||||
**Wan2.2 (14B)**
|
||||
```bash
|
||||
# From https://github.com/Wan-Video/Wan2.2 or HuggingFace
|
||||
# Expected directory structure:
|
||||
# wan22_checkpoint/
|
||||
# ├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
# ├── Wan2.1_VAE.pth
|
||||
# ├── low_noise_model/ # safetensors
|
||||
# └── high_noise_model/ # safetensors
|
||||
```
|
||||
|
||||
**Wan2.2 I2V-14B** — same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V-14B from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`).
|
||||
|
||||
### Step 2: Convert to MLX Format
|
||||
|
||||
The conversion script auto-detects the model version based on the directory structure (presence of `low_noise_model/` subdirectory) and model type (`model_type` in source config.json for I2V vs T2V).
|
||||
|
||||
```bash
|
||||
# Auto-detect version
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/wan_checkpoint \
|
||||
--output-dir wan_mlx
|
||||
|
||||
# Explicit version
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/wan21_checkpoint \
|
||||
--output-dir wan21_mlx \
|
||||
--model-version 2.1
|
||||
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/wan22_checkpoint \
|
||||
--output-dir wan22_mlx \
|
||||
--model-version 2.2
|
||||
```
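The auto-detection described above can be sketched roughly as follows. This is not the code in `convert_wan`: the function name is hypothetical, and the location of `config.json` (inside `low_noise_model/` for Wan2.2, at the top level otherwise) and the `"t2v"` fallback are assumptions made for illustration.

```python
import json
from pathlib import Path


def detect_wan_checkpoint(checkpoint_dir):
    """Guess (model_version, model_type) from a checkpoint directory layout."""
    root = Path(checkpoint_dir)

    # A low_noise_model/ subdirectory signals the Wan2.2 dual-model layout.
    version = "2.2" if (root / "low_noise_model").is_dir() else "2.1"

    # Assumption: config.json sits next to the transformer weights.
    config_dir = root / "low_noise_model" if version == "2.2" else root
    config_path = config_dir / "config.json"

    model_type = "t2v"  # assumed default when no config is present
    if config_path.is_file():
        config = json.loads(config_path.read_text())
        model_type = config.get("model_type", "t2v")
    return version, model_type
```

Passing `--model-version` explicitly, as in the commands above, avoids relying on this kind of layout inspection.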
|
||||
|
||||
#### Conversion Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
|
||||
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
|
||||
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
|
||||
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
|
||||
| `--quantize` | off | Quantize transformer weights for reduced memory |
|
||||
| `--bits` | `4` | Quantization bits: `4` or `8` |
|
||||
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
|
||||
|
||||
The converter produces:
|
||||
```
|
||||
wan_mlx/
|
||||
├── config.json # Model configuration
|
||||
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
|
||||
├── vae.safetensors # 3D VAE decoder
|
||||
├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only)
|
||||
├── model.safetensors # (Wan2.1) Single transformer
|
||||
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
|
||||
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
|
||||
```
|
||||
|
||||
### Quantization (Reduced Memory)
|
||||
|
||||
Quantize the transformer weights to reduce memory usage by ~3.4x. This is especially useful for the 14B model or memory-constrained devices:
|
||||
|
||||
```bash
|
||||
# Convert with 4-bit quantization
|
||||
python -m mlx_video.convert_wan \
|
||||
--checkpoint-dir /path/to/Wan2.1-T2V-1.3B \
|
||||
--output-dir wan21_mlx_q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
# Generate with quantized model (auto-detected from config.json)
|
||||
python -m mlx_video.generate_wan \
|
||||
--model-dir wan21_mlx_q4 \
|
||||
--prompt "A cat playing piano"
|
||||
```
|
||||
|
||||
**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2) — 10 layers × N blocks = ~95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.
|
||||
|
||||
| Model | BF16 Size | 4-bit Size | Notes |
|
||||
|-------|-----------|------------|-------|
|
||||
| 1.3B | 2.7 GB | 799 MB | ~3.4x smaller |
|
||||
| 14B | ~28 GB | ~8 GB | Enables running on 16GB devices |
|
||||
|
||||
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
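The size figures in the table can be sanity-checked with a rough estimate. The sketch below assumes group-wise affine quantization that stores one 16-bit scale and one 16-bit bias per group, and uses the ~95% quantized fraction quoted above; the function names are hypothetical and the result is an approximation, not the converter's actual output size.

```python
def quantized_bits_per_weight(bits: int, group_size: int,
                              scale_bits: int = 16, bias_bits: int = 16) -> float:
    """Effective bits per weight, including per-group scale and bias overhead."""
    return bits + (scale_bits + bias_bits) / group_size


def estimate_quantized_size_gb(bf16_size_gb: float, quantized_fraction: float = 0.95,
                               bits: int = 4, group_size: int = 64) -> float:
    """Estimate model size when only a fraction of the bf16 weights is quantized."""
    bpw = quantized_bits_per_weight(bits, group_size)
    quantized_part = bf16_size_gb * quantized_fraction * bpw / 16
    unquantized_part = bf16_size_gb * (1 - quantized_fraction)
    return quantized_part + unquantized_part
```

For the 1.3B model (2.7 GB in bf16) this estimate lands near 0.86 GB, in the same range as the 799 MB reported in the table. The gap suggests the quantized share is slightly above 95%.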
|
||||
|
||||
### Wan Model Specifications
|
||||
|
||||
**Transformer (14B)**
|
||||
- 40 layers, 40 attention heads, dim 5120, head dim 128
|
||||
- 3-way factorized RoPE (temporal + spatial)
|
||||
- 14.29B parameters
|
||||
|
||||
**Transformer (1.3B, Wan2.1 only)**
|
||||
- 30 layers, 12 attention heads, dim 1536, head dim 128
|
||||
- Same architecture, smaller scale
|
||||
|
||||
**Text Encoder** — UMT5-XXL (5.68B parameters)
|
||||
- 24 layers, 64 heads, dim 4096, vocab 256K
|
||||
|
||||
**VAE** — 3D causal convolution decoder (72.6M parameters)
|
||||
- Latent channels: 16
|
||||
- Compression: 4× temporal, 8× spatial
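The compression factors translate into latent shapes as sketched below. The function name is hypothetical, and the temporal formula (first frame kept on its own, then one latent frame per 4 input frames) is an assumption based on how causal video VAEs are commonly built, not something this README states.

```python
def wan_latent_shape(num_frames: int, height: int, width: int,
                     latent_channels: int = 16,
                     temporal: int = 4, spatial: int = 8):
    """Return (channels, latent_frames, latent_height, latent_width)."""
    # Assumption: causal VAE, so the first frame maps to its own latent frame.
    latent_frames = (num_frames - 1) // temporal + 1
    return latent_channels, latent_frames, height // spatial, width // spatial
```

Under these assumptions a 121-frame 480x480 clip corresponds to a 16x31x60x60 latent.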
|
||||
|
||||
---
|
||||
|
||||
## LoRA Support
|
||||
|
||||
LoRAs can be used with the `--lora-high` and `--lora-low` command-line switches.

For example, to use the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, run the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
|
||||
|
||||
```bash
|
||||
python -m mlx_video.generate_wan \
|
||||
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
|
||||
--width 480 \
|
||||
--height 480 \
|
||||
--num-frames 121 \
|
||||
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, close up, cinematic, sunset" \
|
||||
--steps 4 \
|
||||
--guide-scale 1 \
|
||||
--trim-first-frames 1 \
|
||||
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
|
||||
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
|
||||
```
|
||||