chore: Cleanup -- reorganize README and docs
This commit is contained in:
158
.github/copilot-instructions.md
vendored
158
.github/copilot-instructions.md
vendored
@@ -1,158 +0,0 @@
|
||||
# MLX-Video Copilot Instructions
|
||||
|
||||
## Overview
|
||||
|
||||
MLX-Video is a video/audio generation package using Apple MLX framework. It implements the LTX-2 model (19B parameter DiT) for text-to-video, image-to-video, and audio-video generation, optimized for Apple Silicon.
|
||||
|
||||
## Build, Test, and Lint
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Install test dependencies first (pytest not in main deps)
|
||||
pip install pytest
|
||||
|
||||
# Run all tests
|
||||
python -m pytest tests/
|
||||
|
||||
# Run specific test file
|
||||
python -m pytest tests/test_generate_dev.py
|
||||
|
||||
# Run specific test
|
||||
python -m pytest tests/test_generate_dev.py::TestLTX2Scheduler::test_scheduler_output_shape
|
||||
```
|
||||
|
||||
### Linting
|
||||
Pre-commit hooks configured with:
|
||||
- **black**: Code formatting
|
||||
- **isort**: Import sorting (profile: black)
|
||||
- **autoflake**: Remove unused imports
|
||||
|
||||
```bash
|
||||
# Run pre-commit manually
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Running Generation
|
||||
```bash
|
||||
# Quick test - distilled model (two-stage pipeline)
|
||||
python -m mlx_video.generate --prompt "test video" --num-frames 33
|
||||
|
||||
# Dev model with CFG (single-stage, higher quality)
|
||||
python -m mlx_video.generate_dev --prompt "test video" --steps 40 --cfg-scale 4.0
|
||||
|
||||
# Audio-video generation
|
||||
python -m mlx_video.generate_av --prompt "test video" --output-path out.mp4 --output-audio out.wav
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two-Stage Pipeline (Distilled Model)
|
||||
The distilled model (`generate.py`) uses a two-stage approach for efficiency:
|
||||
1. **Stage 1**: Generate at half resolution with 8 denoising steps using STAGE_1_SIGMAS
|
||||
2. **Upsampler**: 2x spatial upsampling via LatentUpsampler
|
||||
3. **Stage 2**: Refine at full resolution with 3 steps using STAGE_2_SIGMAS
|
||||
4. **VAE Decoder**: Convert latents to RGB video (tiled decoding for memory efficiency)
|
||||
|
||||
### Single-Stage Pipeline (Dev Model)
|
||||
The dev model (`generate_dev.py`) uses classifier-free guidance (CFG):
|
||||
- Full resolution generation with configurable steps (typically 40)
|
||||
- CFG guidance scale controls prompt adherence vs. diversity
|
||||
- More flexible but slower than distilled model
|
||||
|
||||
### Core Components
|
||||
|
||||
**DiT Transformer** (`models/ltx/ltx.py`):
|
||||
- 48 layers, 32 attention heads, 128 dim per head
|
||||
- Dual modality support: video (3840-dim) and audio (2048-dim) embeddings
|
||||
- Uses RoPE (Rotary Position Embeddings) in SPLIT mode with double precision
|
||||
- AdaLN-Zero conditioning blocks inject timestep/text embeddings
|
||||
|
||||
**VAE Architecture**:
|
||||
- **Video VAE**: 128 latent channels, 8x temporal + 32x spatial compression
|
||||
- Encoder: `models/ltx/video_vae/encoder.py`
|
||||
- Decoder: `models/ltx/video_vae/decoder.py` (supports tiled decoding)
|
||||
- **Audio VAE**: 8 latent channels, mel-spectrogram intermediate
|
||||
- Decoder: `models/ltx/audio_vae/decoder.py`
|
||||
- HiFi-GAN vocoder: `models/ltx/audio_vae/vocoder.py`
|
||||
|
||||
**Text Encoder** (`models/ltx/text_encoder.py`):
|
||||
- Based on Gemma 3 model
|
||||
- Returns separate embeddings for video (3840-dim) and audio (2048-dim)
|
||||
- Supports prompt enhancement via `enhance_t2v()` method
|
||||
|
||||
**Tiling System** (`models/ltx/video_vae/tiling.py`):
|
||||
- Memory-efficient decoding for large videos
|
||||
- Modes: auto, default (512px/64f), aggressive (256px/32f), conservative (768px/96f)
|
||||
- Supports streaming via `on_frames_ready` callback
|
||||
|
||||
### Key Patterns
|
||||
|
||||
**Position Grids**:
|
||||
- Created in pixel space, then converted to latent space internally
|
||||
- Video: (B, 3, num_patches, 2) with [start, end) bounds for temporal/spatial dims
|
||||
- Audio: (B, 1, num_patches, 2) for temporal dimension only
|
||||
- See `create_position_grid()` in generate modules
|
||||
|
||||
**Latent Conditioning** (`conditioning/latent.py`):
|
||||
- `LatentState` tracks clean latents, noise, and sigma values
|
||||
- `VideoConditionByLatentIndex` enables I2V by conditioning specific frames
|
||||
- `apply_denoise_mask()` protects conditioned regions during denoising
|
||||
|
||||
**Weight Loading**:
|
||||
- `convert.py`: Downloads from HuggingFace, converts PyTorch → MLX format
|
||||
- Sanitization functions (`sanitize_transformer_weights`, `sanitize_vae_encoder_weights`) adapt keys
|
||||
- Uses safetensors for efficient loading
|
||||
|
||||
## Key Conventions
|
||||
|
||||
### Model Configuration
|
||||
- Always use `LTXModelConfig` to instantiate models
|
||||
- `model_type` determines modality: `VideoOnly`, `AudioOnly`, or `AudioVideo`
|
||||
- `rope_type=LTXRopeType.SPLIT` and `double_precision_rope=True` are standard
|
||||
|
||||
### Frame Count Requirements
|
||||
- **Distilled model**: `num_frames = 1 + 8*k` format (e.g., 33, 65, 97)
|
||||
- **Dev model**: No strict requirement, but odd numbers work better
|
||||
- Audio frames auto-computed from video duration via `AUDIO_LATENTS_PER_SECOND`
|
||||
|
||||
### Dimension Constraints
|
||||
- Video height/width must be divisible by 64 (VAE spatial compression)
|
||||
- Latent dimensions are pixel dimensions divided by 32
|
||||
|
||||
### Audio Constants
|
||||
```python
|
||||
AUDIO_SAMPLE_RATE = 24000 # Output sample rate
|
||||
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal rate
|
||||
AUDIO_HOP_LENGTH = 160 # Mel hop length
|
||||
AUDIO_LATENT_CHANNELS = 8 # Audio latent channels
|
||||
AUDIO_MEL_BINS = 16 # Mel frequency bins
|
||||
```
|
||||
|
||||
### Sigma Schedules
|
||||
Distilled model uses predefined schedules (no scheduler class):
|
||||
```python
|
||||
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||
```
|
||||
|
||||
Dev model computes schedules via `ltx2_scheduler(steps)` function.
|
||||
|
||||
### Code Style
|
||||
- Follow black formatting (configured in pre-commit)
|
||||
- Import sorting: isort with black profile
|
||||
- Remove unused imports (autoflake)
|
||||
- Type hints encouraged but not enforced
|
||||
|
||||
### Modality Enum
|
||||
Use `Modality.VIDEO` and `Modality.AUDIO` from `models/ltx/transformer.py` for multi-modal operations.
|
||||
|
||||
### Video Post-Processing
|
||||
- `postprocess.py`: Contains utilities for frame normalization and video saving
|
||||
- Always denormalize latents from [-1, 1] to [0, 255] before saving
|
||||
- Use opencv-python for video I/O
|
||||
|
||||
## Python Requirements
|
||||
- Python >= 3.11
|
||||
- MLX >= 0.22.0
|
||||
- Primary dependencies: numpy, safetensors, transformers, opencv-python, Pillow, mlx-vlm, scipy, librosa
|
||||
- Package manager: uv recommended for faster installs, pip also supported
|
||||
26
.github/skills/fast-mlx/SKILL.md
vendored
26
.github/skills/fast-mlx/SKILL.md
vendored
@@ -1,26 +0,0 @@
|
||||
---
|
||||
name: fast-mlx
|
||||
description: Optimize MLX code for performance and memory. Use when asked to implement or speed up MLX models or algorithms, reduce latency/throughput bottlenecks, tune lazy evaluation, type promotion, fast ops, compilation, memory use, or profiling.
|
||||
---
|
||||
|
||||
# Fast MLX
|
||||
|
||||
## Workflow
|
||||
|
||||
- Looks for opportunities to compile functions of mostly elementwise operations.
|
||||
- For models with fixed shape inputs or where the shapes don't change much, compile the entire graph
|
||||
- Replace slow implementations with MLX fast ops
|
||||
- Identify evaluation boundaries and unintended sync points (`mx.eval`, `item()`, NumPy conversions).
|
||||
- Check dtype promotion and scalar usage; keep precision consistent with intent.
|
||||
- Review compilation strategy; avoid unnecessary recompiles and closure captures.
|
||||
- Reduce peak memory via lazy loading order and releasing temporaries before `mx.eval`.
|
||||
- Suggest profiling steps if the bottleneck is unclear.
|
||||
|
||||
## References
|
||||
|
||||
- Read `references/fast-mlx-guide.md` for detailed tips and examples. Use it as the source of truth.
|
||||
|
||||
## Output expectations
|
||||
|
||||
- Provide concrete code changes with brief rationale
|
||||
- Call out changes that need user confirmation (e.g., enabling async eval or shapeless compile).
|
||||
350
.github/skills/fast-mlx/references/fast-mlx-guide.md
vendored
350
.github/skills/fast-mlx/references/fast-mlx-guide.md
vendored
@@ -1,350 +0,0 @@
|
||||
# Making MLX Go Fast
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Graph Evaluation](#graph-evaluation)
|
||||
- [Type Promotion](#type-promotion)
|
||||
- [Operations](#operations)
|
||||
- [Compile](#compile)
|
||||
- [Memory Use](#memory-use)
|
||||
- [Profiling](#profiling)
|
||||
|
||||
This guide assumes you have some familiarity with MLX and want to make your MLX
|
||||
model or algorithm as efficient as possible.
|
||||
|
||||
### Graph Evaluation
|
||||
|
||||
Recall, MLX is lazy. When you call an MLX op, no computation actually happens.
|
||||
You are simply building a graph. The computation happens when you explicitly or
|
||||
implicitly evaluate an array. Read more about how this works in the
|
||||
documentation:
|
||||
https://ml-explore.github.io/mlx/build/html/usage/lazy_evaluation.html
|
||||
|
||||
Evaluating the graph incurs some overhead, so don't do it too frequently.
|
||||
Conversely you don't want the graph to get too big before evaluating it as this
|
||||
can also be expensive. Most numerical and machine learning algorithms are
|
||||
iterative. A good place to evaluate the graph is at the end of each iteration.
|
||||
Some examples:
|
||||
|
||||
- After an iteration of gradient descent
|
||||
- After producing one token with a language model
|
||||
- After taking one denoising step in a diffusion model
|
||||
|
||||
Overly frequent evaluations sometimes happen by accident. For example:
|
||||
|
||||
```python
|
||||
# output is an mx.array
|
||||
for x in output:
|
||||
do_something(x.item())
|
||||
```
|
||||
|
||||
The same thing can be written more explicitly with operations and `mx.eval` as:
|
||||
|
||||
```python
|
||||
for i in range(len(output)):
|
||||
x = output[i]
|
||||
mx.eval(x)
|
||||
do_something(x.item())
|
||||
```
|
||||
|
||||
Two better options are:
|
||||
|
||||
1. When possible avoid calling `item()` and do everything in MLX.
|
||||
2. Move the entire output to Python or NumPy first.
|
||||
|
||||
An example of the second approach:
|
||||
|
||||
```python
|
||||
for x in output.tolist():
|
||||
do_something(x)
|
||||
```
|
||||
|
||||
#### Asynchronous Evaluation
|
||||
|
||||
For a latency sensitive computation which is run many times, `mx.async_eval`
|
||||
can be useful. Normally `mx.eval` is synchronous. It returns only when the
|
||||
computation is complete. Instead `mx.async_eval` asynchronously evaluates the
|
||||
graph and returns to the main thread immediately. You can use this to pipeline
|
||||
graph construction with computation like so:
|
||||
|
||||
```python
|
||||
def generator():
|
||||
out = mx.async_eval(my_function())
|
||||
|
||||
while True:
|
||||
out_next = mx.async_eval(my_function())
|
||||
mx.eval(out)
|
||||
yield out
|
||||
out = out_next
|
||||
```
|
||||
|
||||
For this to work `my_function()` cannot do any synchronous evaluations (e.g.
|
||||
calling `mx.eval`, converting to NumPy, etc.). Furthermore, any work done on
|
||||
`out` that is synchronous and on the same stream can stall the pipeline:
|
||||
|
||||
```python
|
||||
for out in generator():
|
||||
out = out * 2
|
||||
# Stalls the pipeline!
|
||||
mx.eval(out)
|
||||
```
|
||||
|
||||
An easy fix for this is to put the pipeline in a separate stream:
|
||||
|
||||
```python
|
||||
def generator():
|
||||
with mx.stream(mx.new_stream(mx.gpu)):
|
||||
out = mx.async_eval(my_function())
|
||||
|
||||
while True:
|
||||
out_next = mx.async_eval(my_function())
|
||||
mx.eval(out)
|
||||
yield out
|
||||
out = out_next
|
||||
```
|
||||
|
||||
### Type Promotion
|
||||
|
||||
One of the most common performance issues comes from accidental up-casting.
|
||||
Make sure you understand how type promotion works in MLX. The inputs to an MLX
|
||||
operation are typically promoted to a common type which doesn't lose precision.
|
||||
For example:
|
||||
|
||||
```python
|
||||
x = mx.array(1.0, mx.float32) * mx.array(2.0, mx.float16)
|
||||
```
|
||||
|
||||
will result in `x` with type `mx.float32`. Similarly:
|
||||
|
||||
```python
|
||||
x = mx.array(1.0, mx.bfloat16) * mx.array(2.0, mx.float16)
|
||||
```
|
||||
|
||||
will result in `x` with type `mx.float32`. A common mistake is to multiply a
|
||||
half-precision array by a default-typed scalar array which up-casts everything
|
||||
to `mx.float32`:
|
||||
|
||||
```python
|
||||
# Warning: x has type mx.float32
|
||||
x = my_fp16_array * mx.array(2.0)
|
||||
```
|
||||
|
||||
To multiply by a scalar while preserving the input type, use Python scalars.
|
||||
Python scalars are weakly typed and have more relaxed promotion rules when
|
||||
used with MLX arrays.
|
||||
|
||||
```python
|
||||
# Ok, x has type mx.float16
|
||||
x = my_fp16_array * 2.0
|
||||
```
|
||||
|
||||
### Operations
|
||||
|
||||
#### Use Fast Ops
|
||||
|
||||
Use `mx.fast` ops when possible:
|
||||
|
||||
- `mx.fast.rms_norm`
|
||||
- `mx.fast.layer_norm`
|
||||
- `mx.fast.rope`
|
||||
- `mx.fast.scaled_dot_product_attention`
|
||||
|
||||
A lot of these operations take a variety of parameters so they can be used for
|
||||
different variations of the function. For example, the weight and bias
|
||||
parameters are optional in `mx.fast.layer_norm` so it can be used with
|
||||
different permutations of inputs.
|
||||
|
||||
#### Precision
|
||||
|
||||
For operations which typically use higher precision there is usually no
|
||||
need to explicitly upcast. For example, `mx.fast.rms_norm` and
|
||||
`mx.fast.layer_norm` accumulate in higher precision so it's
|
||||
wasteful to upcast and downcast into and out of these operations:
|
||||
|
||||
```python
|
||||
# No need for this!
|
||||
mx.fast.rms_norm(x.astype(mx.float32), w, b, eps).astype(x.dtype)
|
||||
|
||||
# This is just as good:
|
||||
mx.fast.rms_norm(x, w, b, eps)
|
||||
```
|
||||
|
||||
Similarly, for `mx.softmax` use `precise=True` if you want to do the softmax in
|
||||
higher precision rather than explicitly casting the input and output.
|
||||
|
||||
#### Misc
|
||||
|
||||
- For vector-matrix multiplication `x @ W.T` is faster than `x @ W`, for
|
||||
matrix-vector multiplication `W @ x` is faster than `W.T @ x`
|
||||
- Use `mx.addmm` for `a @ b + c` (e.g. a linear layer with a bias).
|
||||
- Where it makes sense, use `mx.take_along_axis` and `mx.put_along_axis`
|
||||
instead of fancy indexing
|
||||
- Use broadcasting instead of concatenation. For example, prefer `mx.repeat(a,
|
||||
n)` over `mx.concatenate([a] * n)`
|
||||
|
||||
### Compile
|
||||
|
||||
Compiling graphs with `mx.compile` can make them run a lot faster. But there
|
||||
are some sharp-edges that are good to be aware of.
|
||||
|
||||
First, be aware of when a function will be recompiled. Recompilation is
|
||||
relatively expensive and should only be done if there is sufficient work over
|
||||
which to amortize the cost.
|
||||
|
||||
The default behavior of `mx.compile` is to do a shape-dependent compilation.
|
||||
This means the function will be recompiled if the shape of any input changes.
|
||||
|
||||
MLX supports a shapeless compilation by passing `shapeless=True` to
|
||||
`mx.compile`. It's easy to make hard-to-detect mistakes with shapeless
|
||||
compilation. Make sure to read and understand the documentation and use it
|
||||
with care:
|
||||
https://ml-explore.github.io/mlx/build/html/usage/compile.html#shapeless-compilation
|
||||
|
||||
A function will also be recompiled if any constant inputs change:
|
||||
|
||||
```python
|
||||
@mx.compile
|
||||
def fun(x, scale):
|
||||
return scale * x
|
||||
|
||||
fun(x, 3)
|
||||
|
||||
# Recompiles!
|
||||
fun(x, 4)
|
||||
```
|
||||
|
||||
In this case a simple fix is to make `scale` an `mx.array`.
|
||||
|
||||
#### Compiling Closures
|
||||
|
||||
Be careful when compiling a closure where the function encloses any
|
||||
`mx.array`.
|
||||
|
||||
```python
|
||||
y = some_function()
|
||||
|
||||
@mx.compile
|
||||
def fun(x):
|
||||
return x + y
|
||||
```
|
||||
|
||||
Since `y` is not an input to `fun`, the compiled graph will include the entire
|
||||
computation which produces `y`. Usually you only want to compute `y` one time
|
||||
and re-use it in the compiled function. Either explicitly pass it as an input
|
||||
to `fun` or pass it as an implicit input to `mx.compile` like so:
|
||||
|
||||
```python
|
||||
y = some_function()
|
||||
|
||||
@partial(mx.compile, inputs=[y])
|
||||
def fun(x):
|
||||
return x + y
|
||||
```
|
||||
|
||||
### Memory Use
|
||||
|
||||
#### Lazy Loading
|
||||
|
||||
Loading arrays from a file is lazy in MLX:
|
||||
|
||||
```python
|
||||
weights = mx.load("model.safetensors")
|
||||
```
|
||||
|
||||
The above function returns instantly, regardless of the file size. To actually
|
||||
load the weights into memory, you can do `mx.eval(weights)`.
|
||||
|
||||
Assume the weights are stored on disk in 32-bit precision (i.e. `mx.float32`).
|
||||
But for your model you only need 16-bit precision:
|
||||
|
||||
```python
|
||||
weights = mx.load("model.safetensors")
|
||||
mx.eval(weights)
|
||||
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
|
||||
```
|
||||
|
||||
In the above, the weights will be loaded into memory in full precision and then
|
||||
cast to 16-bit. This requires memory for all the weights in 32-bit plus memory
|
||||
for the weights in 16-bit.
|
||||
|
||||
This is much better:
|
||||
|
||||
```python
|
||||
weights = mx.load("model.safetensors")
|
||||
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
|
||||
mx.eval(weights)
|
||||
```
|
||||
|
||||
Evaluating after the cast to `mx.float16` reduces peak memory by nearly a
|
||||
third. That's because all the weights are never fully materialized in 32-bit.
|
||||
Right after each weight is loaded in 32-bit precision it is cast to 16-bit.
|
||||
The memory for the 32-bit weight can be reused when loading the next weight.
|
||||
|
||||
Note, MLX is only able to lazy load from a file when it is given to `mx.load`
|
||||
as a string path. Due to lifetime management issues, lazy loading from file
|
||||
handles is not supported. So avoid this:
|
||||
|
||||
```python
|
||||
weights = mx.load(open("model.safetensors", 'rb'))
|
||||
```
|
||||
|
||||
#### Release Temporaries
|
||||
|
||||
One way to reduce memory consumption is to avoid holding
|
||||
temporaries you don't need. This is a typical training loop:
|
||||
|
||||
```python
|
||||
for x, y in dataset:
|
||||
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model, optimizer.state)
|
||||
```
|
||||
|
||||
It's suboptimal since a reference to `grads` is held during the call to
|
||||
`mx.eval` which keeps the respective memory from being used for any other part
|
||||
of the computation.
|
||||
|
||||
This is better:
|
||||
|
||||
```python
|
||||
def step(x, y):
|
||||
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
for x, y in dataset:
|
||||
loss = step(x, y)
|
||||
mx.eval(model, optimizer.state)
|
||||
```
|
||||
|
||||
In this case the reference to `grads` is released before `mx.eval` and the
|
||||
memory can be reused. You can achieve the same goal using `del` as long as it's
|
||||
before the call to `mx.eval`:
|
||||
|
||||
```python
|
||||
for x, y in dataset:
|
||||
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
del grads
|
||||
mx.eval(model, optimizer.state)
|
||||
```
|
||||
|
||||
#### Misc
|
||||
|
||||
- MLX will cache memory buffers of recently released arrays rather than
|
||||
returning them to the system. In some cases, especially for variable shape
|
||||
computations, the cache can get large. To help with this, MLX has some
|
||||
functions for logging and customizing the behavior of memory allocation:
|
||||
https://ml-explore.github.io/mlx/build/html/python/metal.html
|
||||
|
||||
### Profiling
|
||||
|
||||
A good first step is to check GPU utilization using, for example,
|
||||
mactop: https://github.com/context-labs/mactop. If it's not pegged at close
|
||||
to 100% then there is likely a non-MLX bottleneck somewhere in the program. A
|
||||
common culprit is data loading or preprocessing.
|
||||
|
||||
If GPU utilization is good, a good next step is to figure out which operations
|
||||
are taking up so much time. One way to do this is with the Metal debugger. For
|
||||
that, see the documentation on profiling MLX with the Metal debugger:
|
||||
https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html
|
||||
Reference in New Issue
Block a user