Compare commits
31 Commits
f5e311a77c
...
9ab4826d20
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ab4826d20 | ||
|
|
996a542011 | ||
|
|
b029668cd2 | ||
|
|
6c63163671 | ||
|
|
17397da70c | ||
|
|
78bcfba31b | ||
|
|
3e33172c12 | ||
|
|
95d7c81b20 | ||
|
|
7b9d0a5e44 | ||
|
|
fea0f87df9 | ||
|
|
3618966625 | ||
|
|
33dd3c2edd | ||
|
|
281750f0a9 | ||
|
|
ae410f3121 | ||
|
|
c144c8817c | ||
|
|
1cf878f5e0 | ||
|
|
d207275fea | ||
|
|
afd15018b7 | ||
|
|
061ae4407c | ||
|
|
967218b7c1 | ||
|
|
9bdda9f22e | ||
|
|
9597b7c9c5 | ||
|
|
849cc45d84 | ||
|
|
dbab95ec45 | ||
|
|
f4195f0118 | ||
|
|
2bb95c61ed | ||
|
|
93da550f65 | ||
|
|
e64483a66a | ||
|
|
7a74946c57 | ||
|
|
ffdeec72a6 | ||
|
|
7ad14e18ca |
154
README.md
154
README.md
@@ -16,35 +16,49 @@ uv pip install git+https://github.com/Blaizzy/mlx-video.git
|
||||
|
||||
## Supported Models
|
||||
|
||||
### LTX-2
|
||||
- [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks
|
||||
- [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline)
|
||||
- [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — T2V-14B, TI2V-5B, and I2V-14B models (dual-model pipeline)
|
||||
|
||||
[LTX-2](https://huggingface.co/Lightricks/LTX-2) is a 19B parameter video generation model from Lightricks. See the full [LTX-2 model card](mlx_video/models/ltx_2/README.md) for detailed usage, CLI options, pipeline descriptions, and architecture.
|
||||
## Features
|
||||
|
||||
**Features:**
|
||||
- Text-to-Video (T2V), Image-to-Video (I2V), and Audio-to-Video (A2V)
|
||||
- Four pipelines: Distilled (fast), Dev (CFG), Dev Two-Stage (LoRA), Dev Two-Stage HQ (highest quality)
|
||||
- Synchronized audio-video generation (experimental)
|
||||
- LoRA support (local files or HuggingFace repos)
|
||||
- Prompt enhancement via Gemma
|
||||
**LTX-2 / LTX-2.3**
|
||||
- Text-to-Video (T2V), Image-to-Video (I2V), Audio-to-Video (A2V)
|
||||
- Audio-Video joint generation
|
||||
- Multi-pipeline: distilled, dev, dev-two-stage, dev-two-stage-hq
|
||||
- 2x spatial upscaling for images and videos
|
||||
- Prompt enhancement via Gemma
|
||||
|
||||
**Quick start:**
|
||||
**Wan2.1 / Wan2.2**
|
||||
- Text-to-Video (T2V) — 1.3B and 14B models
|
||||
- Image-to-Video (I2V) — 14B model
|
||||
- Flow-matching diffusion with classifier-free guidance
|
||||
- LoRA support (e.g. Wan2.2-Lightning for 4-step generation)
|
||||
|
||||
**General**
|
||||
- Optimized for Apple Silicon using MLX
|
||||
|
||||
---
|
||||
|
||||
## LTX-2
|
||||
|
||||
### Text-to-Video Generation
|
||||
|
||||
```bash
|
||||
# Text-to-Video (distilled, fastest)
|
||||
uv run mlx_video.generate --prompt "Two dogs wearing sunglasses, cinematic, sunset" -n 97 --width 768
|
||||
uv run mlx_video.ltx_2.generate --prompt "Two dogs wearing sunglasses, cinematic, sunset" -n 97 --width 768
|
||||
|
||||
# Image-to-Video
|
||||
uv run mlx_video.generate --prompt "A person dancing" --image photo.jpg
|
||||
uv run mlx_video.ltx_2.generate --prompt "A person dancing" --image photo.jpg
|
||||
|
||||
# Audio-to-Video
|
||||
uv run mlx_video.generate --audio-file music.wav --prompt "A band playing music"
|
||||
uv run mlx_video.ltx_2.generate --audio-file music.wav --prompt "A band playing music"
|
||||
|
||||
# Dev pipeline with CFG (higher quality)
|
||||
uv run mlx_video.generate --pipeline dev --prompt "A cinematic scene" --cfg-scale 3.0
|
||||
uv run mlx_video.ltx_2.generate --pipeline dev --prompt "A cinematic scene" --cfg-scale 3.0
|
||||
|
||||
# Dev two-stage HQ (highest quality)
|
||||
uv run mlx_video.generate --pipeline dev-two-stage-hq \
|
||||
uv run mlx_video.ltx_2.generate --pipeline dev-two-stage-hq \
|
||||
--prompt "A cinematic scene of ocean waves at golden hour" \
|
||||
--model-repo prince-canuma/LTX-2-dev
|
||||
```
|
||||
@@ -55,16 +69,124 @@ uv run mlx_video.generate --pipeline dev-two-stage-hq \
|
||||
|
||||
Pre-converted weights are available on HuggingFace ([LTX-2-distilled](https://huggingface.co/prince-canuma/LTX-2-distilled), [LTX-2-dev](https://huggingface.co/prince-canuma/LTX-2-dev), [LTX-2.3-distilled](https://huggingface.co/prince-canuma/LTX-2.3-distilled), [LTX-2.3-dev](https://huggingface.co/prince-canuma/LTX-2.3-dev)), or convert from the original Lightricks checkpoint:
|
||||
|
||||
|
||||
### LTX-2 CLI Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--prompt`, `-p` | (required) | Text description of the video |
|
||||
| `--height`, `-H` | 512 | Output height (must be divisible by 64) |
|
||||
| `--width`, `-W` | 512 | Output width (must be divisible by 64) |
|
||||
| `--num-frames`, `-n` | 100 | Number of frames |
|
||||
| `--seed`, `-s` | 42 | Random seed for reproducibility |
|
||||
| `--fps` | 24 | Frames per second |
|
||||
| `--output`, `-o` | output.mp4 | Output video path |
|
||||
| `--save-frames` | false | Save individual frames as images |
|
||||
| `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository |
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Wan2.1 / Wan2.2
|
||||
|
||||
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
|
||||
|
||||
### Step 0: Download and Convert Weights
|
||||
|
||||
See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan_2/README.md) for details.
|
||||
|
||||
### Step 1: Generate Video
|
||||
|
||||
```bash
|
||||
uv run python -m mlx_video.models.ltx_2.convert \
|
||||
--source Lightricks/LTX-2 --output ./LTX-2-distilled --variant distilled
|
||||
# Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0)
|
||||
python -m mlx_video.wan_2.generate \
|
||||
--model-dir wan21_mlx \
|
||||
--prompt "A cat playing piano in a cozy room"
|
||||
|
||||
# Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0)
|
||||
python -m mlx_video.wan_2.generate \
|
||||
--model-dir wan22_mlx \
|
||||
--prompt "A cat playing piano in a cozy room"
|
||||
```
|
||||
|
||||
With custom settings:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan_2.generate \
|
||||
--model-dir wan21_mlx \
|
||||
--prompt "Ocean waves at sunset, cinematic, 4K" \
|
||||
--negative-prompt "blurry, low quality" \
|
||||
--width 1280 \
|
||||
--height 720 \
|
||||
--num-frames 81 \
|
||||
--steps 50 \
|
||||
--guide-scale 5.0 \
|
||||
--shift 5.0 \
|
||||
--seed 42 \
|
||||
--output-path my_video.mp4
|
||||
```
|
||||
|
||||
The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model).
|
||||
|
||||
### Image-to-Video (I2V-14B)
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan_2.generate \
|
||||
--model-dir wan22_i2v_mlx \
|
||||
--prompt "The camera slowly zooms in as the subject begins to move" \
|
||||
--image start.png \
|
||||
--num-frames 81 \
|
||||
--output-path my_video.mp4
|
||||
```
|
||||
|
||||
### LoRA Support
|
||||
|
||||
LoRAs can be used with the `--lora-high` and `--lora-low` command line switches.
|
||||
|
||||
For example, using the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA for 4-step generation:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan_2.generate \
|
||||
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
|
||||
--width 480 \
|
||||
--height 704 \
|
||||
--num-frames 41 \
|
||||
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
|
||||
--steps 4 \
|
||||
--guide-scale 1 \
|
||||
--trim-first-frames 1 \
|
||||
--seed 2391784614 \
|
||||
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
|
||||
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
|
||||
```
|
||||
|
||||

|
||||
|
||||
### Wan CLI Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--model-dir` | (required) | Path to converted MLX model directory |
|
||||
| `--prompt` | (required) | Text description of the video |
|
||||
| `--image` | `None` | Input image path (for I2V models) |
|
||||
| `--negative-prompt` | `""` | Negative prompt for guidance |
|
||||
| `--width` | 1280 | Video width |
|
||||
| `--height` | 720 | Video height |
|
||||
| `--num-frames` | 81 | Number of frames (must be 4n+1) |
|
||||
| `--steps` | from config | Number of diffusion steps |
|
||||
| `--guide-scale` | from config | Guidance scale: float or `low,high` pair |
|
||||
| `--shift` | from config | Noise schedule shift |
|
||||
| `--seed` | -1 (random) | Random seed for reproducibility |
|
||||
| `--output-path` | `output.mp4` | Output video path |
|
||||
|
||||
---
|
||||
|
||||
## Requirements
|
||||
|
||||
- macOS with Apple Silicon
|
||||
- Python >= 3.11
|
||||
- MLX >= 0.22.0
|
||||
- For weight conversion: PyTorch (`pip install torch`)
|
||||
|
||||
## License
|
||||
|
||||
|
||||
911
docs/PORTING-GUIDE.md
Normal file
911
docs/PORTING-GUIDE.md
Normal file
@@ -0,0 +1,911 @@
|
||||
# Porting Diffusion Video Models to MLX: Lessons Learned
|
||||
|
||||
A practical guide distilled from porting Wan2.1/2.2 (1.3B–14B) and Helios 14B DiT
|
||||
video generation models from PyTorch to MLX on Apple Silicon. These lessons apply
|
||||
broadly to any diffusion-based video (or image) model port.
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Debugging Methodology](#1-debugging-methodology)
|
||||
2. [Precision & Dtype Pitfalls](#2-precision--dtype-pitfalls)
|
||||
3. [MLX-Specific Gotchas](#3-mlx-specific-gotchas)
|
||||
4. [Autoregressive Chunk Boundaries](#4-autoregressive-chunk-boundaries)
|
||||
5. [VAE Decoder Artifacts](#5-vae-decoder-artifacts)
|
||||
6. [Scheduler & Timestep Issues](#6-scheduler--timestep-issues)
|
||||
7. [Weight Conversion](#7-weight-conversion)
|
||||
8. [Text Conditioning Failures](#8-text-conditioning-failures)
|
||||
9. [Position Encodings (RoPE)](#9-position-encodings-rope)
|
||||
10. [Multi-Stage / Pyramid Pipelines](#10-multi-stage--pyramid-pipelines)
|
||||
11. [Common Symptoms → Root Causes](#11-common-symptoms--root-causes)
|
||||
12. [Verification Checklist](#12-verification-checklist)
|
||||
13. [Diagnostic Tools](#13-diagnostic-tools)
|
||||
|
||||
---
|
||||
|
||||
## 1. Debugging Methodology
|
||||
|
||||
### Component isolation first
|
||||
|
||||
Never debug the full pipeline. Test each component in isolation:
|
||||
|
||||
1. **Text encoder** — Does it produce embeddings with reasonable statistics? (std > 0.01)
|
||||
2. **Scheduler** — Do sigma/timestep values match the reference exactly?
|
||||
3. **Transformer** — Does a single forward pass match the reference? (cosine similarity > 0.999)
|
||||
4. **VAE decoder** — Feed reference latents into your VAE. Does the output look correct?
|
||||
|
||||
If every component matches individually but the pipeline fails, the bug is in
|
||||
**orchestration** — how components are wired together.
|
||||
|
||||
### Statistical fingerprinting
|
||||
|
||||
Track per-step statistics through the diffusion loop:
|
||||
|
||||
```python
|
||||
# After each denoising step
|
||||
print(f"step {i}: mean={latent.mean():.6f} std={latent.std():.6f} "
|
||||
f"min={latent.min():.4f} max={latent.max():.4f}")
|
||||
```
|
||||
|
||||
**What to look for:**
|
||||
- **Progressive mean drift** (e.g., -0.002 → -0.040 → -0.123) signals accumulating errors
|
||||
- **Collapsing std** (std dropping toward 0) signals broken conditioning or wrong noise schedule
|
||||
- **Exploding values** signal wrong sigma scaling or scheduler formula
|
||||
|
||||
### Cross-framework numerical comparison
|
||||
|
||||
The most powerful debugging tool: save intermediate tensors from your MLX pipeline,
|
||||
feed them to the PyTorch reference, compare outputs.
|
||||
|
||||
```python
|
||||
# In MLX pipeline, save inputs before transformer call
|
||||
mx.save("debug_inputs.npz", {"latent": latent, "timestep": t, "text_emb": text_emb})
|
||||
|
||||
# In PyTorch script, load and compare
|
||||
inputs = np.load("debug_inputs.npz")
|
||||
mlx_out = np.load("debug_output.npz")["flow"]
|
||||
pt_out = reference_model(torch.from_numpy(inputs["latent"]), ...)
|
||||
cos_sim = F.cosine_similarity(pt_out.flatten(), torch.from_numpy(mlx_out).flatten(), dim=0)
|
||||
# cos_sim > 0.999 = model is correct; bug is elsewhere
|
||||
# cos_sim < 0.99 = model has a bug; compare per-layer
|
||||
```
|
||||
|
||||
### Ablation testing
|
||||
|
||||
When a pipeline has multiple "fixes" or features, disable them one at a time:
|
||||
|
||||
- **Frozen history**: Fix history to the same value for all chunks → proves whether
|
||||
history propagation is the source of drift/zoom
|
||||
- **Single chunk**: Generate only 1 chunk → isolates per-chunk quality from
|
||||
multi-chunk interaction bugs
|
||||
- **Disable post-processing**: Remove cross-fade, blending, corrections → reveals
|
||||
what the raw model output looks like
|
||||
|
||||
### Use reference on same hardware
|
||||
|
||||
Run the PyTorch reference on the same device (MPS for Apple Silicon). CUDA and MPS
|
||||
produce different numerical results due to different float handling. Comparing your
|
||||
MLX output against a CUDA reference adds noise to the comparison.
|
||||
|
||||
```python
|
||||
# MPS may not support float64 — patch the reference:
|
||||
original_linspace = torch.linspace
|
||||
def patched_linspace(*args, **kwargs):
|
||||
kwargs.pop("dtype", None)
|
||||
return original_linspace(*args, dtype=torch.float32, **kwargs)
|
||||
torch.linspace = patched_linspace
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. Precision & Dtype Pitfalls
|
||||
|
||||
### The #1 source of subtle bugs
|
||||
|
||||
Precision issues caused the most insidious bugs in our port. They don't cause
|
||||
crashes — they cause progressive quality degradation that's hard to attribute.
|
||||
|
||||
### Residual connections MUST be float32
|
||||
|
||||
**Bug**: Progressive zoom/shrinking across autoregressive chunks.
|
||||
|
||||
**Root cause**: Residual additions (`x = x + attn_out`) in bfloat16. With 7-bit
|
||||
mantissa, high-frequency spatial detail is systematically truncated. Over 144
|
||||
residual ops × 6+ model calls per chunk, detail is progressively smoothed away.
|
||||
|
||||
**Fix**: Promote to float32 for the addition:
|
||||
```python
|
||||
# BAD — bfloat16 accumulation
|
||||
x = x + attn_out
|
||||
|
||||
# GOOD — match reference's .float() pattern
|
||||
x = (x.astype(mx.float32) + attn_out).astype(weight_dtype)
|
||||
```
|
||||
|
||||
**Rule**: If the reference uses `.float()` anywhere, copy that pattern exactly. It's
|
||||
there for a reason, even if a quick test seems to work without it.
|
||||
|
||||
### Scheduler computations need high precision
|
||||
|
||||
Diffusion schedulers involve:
|
||||
- `x0 = xt - sigma * flow` — catastrophic cancellation near sigma ≈ 1
|
||||
- `log(sigma)` and `exp()` — sensitive to small precision differences
|
||||
|
||||
Some references use float64 for these computations. MLX GPU doesn't support float64,
|
||||
so use float32 and accept small numerical differences, but **never** use bfloat16
|
||||
for scheduler math.
|
||||
|
||||
### Dtype propagation is invisible
|
||||
|
||||
Track dtype through your pipeline. A single bfloat16 intermediate can silently
|
||||
downcast everything downstream:
|
||||
|
||||
```python
|
||||
# This looks harmless but if model output is bfloat16:
|
||||
result = noise - sigma * model_output # result is bfloat16!
|
||||
|
||||
# Fix: explicit cast
|
||||
result = (noise.astype(mx.float32) - sigma * model_output.astype(mx.float32))
|
||||
```
|
||||
|
||||
### Type promotion rules differ across frameworks
|
||||
|
||||
- PyTorch: bfloat16 + float32 → float32
|
||||
- MLX: bfloat16 + float32 → float32 (same, but verify)
|
||||
- NumPy: no bfloat16 support
|
||||
|
||||
Always check what your framework does and match the reference's implicit promotations.
|
||||
|
||||
### Float32 for VAE decoding
|
||||
|
||||
**Bug** (Wan2.2): VAE decode in bfloat16 produced visibly worse quality than reference.
|
||||
|
||||
Official Wan2.2 runs VAE decode in `torch.float` (float32), but our converted weights
|
||||
were bfloat16. The VAE has many sequential layers where precision loss compounds.
|
||||
|
||||
**Fix**: Upcast VAE weights to float32 at load time. The VAE runs once per generation,
|
||||
so the performance impact is negligible compared to the transformer.
|
||||
|
||||
### Modulation/gate vectors need float32
|
||||
|
||||
**Bug** (Wan2.2): Quality degradation from bfloat16 modulation across 30 blocks × 50 steps.
|
||||
|
||||
The official Wan2.2 explicitly uses `torch.amp.autocast('cuda', dtype=torch.float32)`
|
||||
for time embeddings, modulation parameters, norm outputs before modulation, and gate ops.
|
||||
|
||||
**Fix**: Keep modulation in float32, cast to working dtype only when applying to the
|
||||
hidden state:
|
||||
```python
|
||||
# Modulation computed in float32
|
||||
e0 = self.modulation(time_emb) # float32
|
||||
scale, shift, gate = e0.split(3, axis=-1)
|
||||
|
||||
# Cast to bfloat16 only for the matmul with hidden state
|
||||
x = (x * (1 + scale.astype(x.dtype)) + shift.astype(x.dtype))
|
||||
```
|
||||
|
||||
### Map PyTorch autocast zones precisely
|
||||
|
||||
PyTorch models use nested `torch.amp.autocast` scopes to switch precision. Map these
|
||||
exactly:
|
||||
- **Outer scope** (`bfloat16`): attention QKV projections, FFN matmuls
|
||||
- **Inner scope** (`float32`): modulation, gates, norms, RoPE
|
||||
- **Residual stream**: float32 (the "backbone" between blocks)
|
||||
|
||||
```python
|
||||
# Wan2.2 dtype flow (matches official):
|
||||
# Modulation/gates: float32 (explicit)
|
||||
# QKV/FFN linear projections: bfloat16 (weight dtype)
|
||||
# RoPE: float32 (official uses float64, MLX lacks float64)
|
||||
# Attention Q/K: cast back to bfloat16 after RoPE
|
||||
# Residual stream: float32
|
||||
```
|
||||
|
||||
### Float32 promotion cascades kill performance
|
||||
|
||||
**Bug** (Wan2.2): ~2x slowdown from accidental float32 promotion.
|
||||
|
||||
A single float32 tensor (e.g., time embedding) flowing into bfloat16 operations
|
||||
promotes the entire computation graph to float32. In Wan2.2:
|
||||
- Time embedding MLP output (float32) fed into transformer → all layers float32
|
||||
- RoPE frequencies (float32) applied to Q/K → all attention float32
|
||||
|
||||
**Fix**: Cast intermediate results to model dtype at promotion boundaries:
|
||||
```python
|
||||
# After time embedding MLP (float32), cast before feeding to transformer
|
||||
time_emb = time_mlp(t).astype(model_dtype)
|
||||
|
||||
# After RoPE (float32), cast Q/K back to attention dtype
|
||||
q = rope_apply(q, freqs).astype(v.dtype)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. MLX-Specific Gotchas
|
||||
|
||||
### Underscore-prefixed attributes are invisible
|
||||
|
||||
**Bug** (Wan2.2): 87 of 110 VAE weights silently dropped during loading.
|
||||
|
||||
MLX's `nn.Module.parameters()` and `nn.Module.load_weights()` **skip** attributes
|
||||
whose names start with underscore. If you name a layer `self._layer_0`, its weights
|
||||
will never be loaded or saved.
|
||||
|
||||
```python
|
||||
# BAD — weights silently ignored
|
||||
self._layer_0 = nn.Linear(...) # nn.Module skips _prefixed attrs
|
||||
|
||||
# GOOD
|
||||
self.layer_0 = nn.Linear(...)
|
||||
```
|
||||
|
||||
This is especially insidious because there's no error — the model loads, runs, and
|
||||
produces output. The output is just garbage because most weights are random.
|
||||
|
||||
### nn.Sequential indexing vs named children
|
||||
|
||||
PyTorch's `nn.Sequential` uses integer indices (`sequential.0.weight`), while MLX's
|
||||
module hierarchy uses named attributes. When mirroring a PyTorch module structure,
|
||||
you need explicit key sanitization:
|
||||
|
||||
```python
|
||||
def sanitize_key(key):
|
||||
# PyTorch: "decoder.middle.0.residual.1.weight"
|
||||
# MLX: "decoder.middle.layer_0.residual.layer_1.weight"
|
||||
key = re.sub(r'\.(\d+)', lambda m: f'.layer_{m.group(1)}', key)
|
||||
return key
|
||||
```
|
||||
|
||||
### Reshape axis ordering differs from PyTorch
|
||||
|
||||
**Bug** (Wan2.2): Green checkerboard pattern from VAE attention.
|
||||
|
||||
`[B,C,T,H,W]` cannot be directly reshaped to `[BT,C,H,W]` because in memory C
|
||||
comes before T. PyTorch's `reshape` works because it handles non-contiguous tensors.
|
||||
MLX requires explicit transpose first:
|
||||
|
||||
```python
|
||||
# BAD — mixes channels with time
|
||||
x = x.reshape(B*T, C, H, W) # Corrupts spatial layout
|
||||
|
||||
# GOOD — make B,T adjacent first
|
||||
x = x.transpose(0, 2, 1, 3, 4) # [B,T,C,H,W]
|
||||
x = x.reshape(B*T, C, H, W) # Now correct
|
||||
```
|
||||
|
||||
### Patchify channel ordering
|
||||
|
||||
**Bug** (Wan2.2): Solid green video output from wrong patchify order.
|
||||
|
||||
When converting a Conv3d patchify to a manual reshape+linear, the dimension ordering
|
||||
in the reshape must match the Conv3d weight layout. Conv3d expects `[C, pt, ph, pw]`
|
||||
(channels slowest), but a naive reshape produces `[pt, ph, pw, C]` (channels fastest):
|
||||
|
||||
```python
|
||||
# BAD — channel scrambling
|
||||
patches = x.reshape(B, F', H', W', pt, ph, pw, C)
|
||||
|
||||
# GOOD — match Conv3d weight layout
|
||||
patches = x.reshape(B, F', pt, H', ph, W', pw, C)
|
||||
patches = patches.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, F', H', W', C, pt, ph, pw]
|
||||
```
|
||||
|
||||
Verify numerically: the fixed version should match Conv3d output to ~1e-6.
|
||||
|
||||
### mx.zeros / padding inherits dtype
|
||||
|
||||
Use dtype-aware `mx.zeros` for padding and concatenation to avoid promotion:
|
||||
|
||||
```python
|
||||
# BAD — default float32 padding promotes bfloat16 input
|
||||
pad = mx.zeros((B, pad_len, C)) # float32!
|
||||
x = mx.concatenate([pad, x], axis=1) # x promoted to float32
|
||||
|
||||
# GOOD — match input dtype
|
||||
pad = mx.zeros((B, pad_len, C), dtype=x.dtype)
|
||||
x = mx.concatenate([pad, x], axis=1) # stays bfloat16
|
||||
```
|
||||
|
||||
### Use mx.fast kernels
|
||||
|
||||
Replace manual implementations with fused MLX kernels where possible:
|
||||
|
||||
```python
|
||||
# Manual RMS norm → mx.fast.rms_norm
|
||||
# Manual LayerNorm → mx.fast.layer_norm
|
||||
# Manual attention → mx.fast.scaled_dot_product_attention
|
||||
```
|
||||
|
||||
These are faster and handle precision internally.
|
||||
|
||||
---
|
||||
|
||||
## 4. Autoregressive Chunk Boundaries
|
||||
|
||||
For models that generate long videos by autoregressively extending chunks (Helios,
|
||||
CogVideoX, etc.), chunk boundaries are the primary source of visual artifacts.
|
||||
|
||||
### Don't add post-processing the reference doesn't have
|
||||
|
||||
**Bug**: Added pixel cross-fade to smooth boundaries → caused 40% sharpness drop.
|
||||
|
||||
The reference pipeline used **no cross-fade at all**. The first frame of each new
|
||||
chunk is intentionally a sharp reconstruction conditioned on history. Blending it with
|
||||
the previous chunk's tail (which has different content) creates blur.
|
||||
|
||||
**Rule**: Before adding smoothing/blending, verify the reference doesn't do it.
|
||||
Reference simplicity is usually correct.
|
||||
|
||||
### First-frame artifacts are common
|
||||
|
||||
The first pixel frame of each non-first chunk is typically a distorted reconstruction
|
||||
of the conditioning frame. In many models, this is expected behavior:
|
||||
|
||||
- **Fix**: Drop the first frame from each chunk
|
||||
- **Verify frame math**: If 33 raw frames at 16fps → drop 1 → 32 frames = exactly 2 seconds
|
||||
|
||||
### History conditioning errors compound
|
||||
|
||||
Small errors in how history is prepared, sliced, patchified, or position-encoded
|
||||
will compound across chunks. The error is invisible in chunk 1, small in chunk 2,
|
||||
and catastrophic by chunk 5.
|
||||
|
||||
**Debug strategy**: Generate with frozen history (same history for every chunk).
|
||||
If the artifact disappears, the bug is in history handling.
|
||||
|
||||
---
|
||||
|
||||
## 5. VAE Decoder Artifacts
|
||||
|
||||
### Causal temporal convolutions cause boundary warmup
|
||||
|
||||
Video VAEs (WanVAE, CogVideoX-VAE) use causal temporal convolutions. When decoding
|
||||
each chunk independently, the first few frames lack temporal context (only zero
|
||||
padding), causing:
|
||||
|
||||
- **~7% contrast drop** in first frames of each chunk
|
||||
- **Spatial brightness redistribution** (face darkens, background brightens)
|
||||
|
||||
This is inherent to the architecture. The reference has the same effect but at
|
||||
lower magnitude.
|
||||
|
||||
### Post-processing to fix VAE warmup
|
||||
|
||||
Two-stage correction applied to first N frames of each non-first chunk:
|
||||
|
||||
```python
|
||||
# Stage 1: Spatially-varying brightness correction
|
||||
# Downsample reference (previous chunk's last frame) and current frame
|
||||
ref_small = cv2.resize(ref_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
|
||||
cur_small = cv2.resize(cur_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
|
||||
diff_small = ref_small - cur_small
|
||||
diff_full = cv2.resize(diff_small, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
corrected = cur_frame + ramp * diff_full # ramp: 1.0 → 0.0 over N frames
|
||||
|
||||
# Stage 2: Per-channel contrast matching
|
||||
for c in range(3):
|
||||
ref_std = np.std(ref_frame[:,:,c])
|
||||
cur_std = np.std(corrected[:,:,c])
|
||||
scale = 1.0 + ramp * (ref_std / (cur_std + 1e-6) - 1.0)
|
||||
corrected[:,:,c] = (corrected[:,:,c] - mean) * scale + mean
|
||||
```
|
||||
|
||||
### VAE overlap decode does NOT work
|
||||
|
||||
**Attempted**: Prepend previous chunk's last latent frames to give the decoder
|
||||
temporal context.
|
||||
|
||||
**Result**: Made things **worse** (22% contrast drop vs 7%). The causal convolutions
|
||||
see conflicting content from different chunks and create larger artifacts than
|
||||
zero-padding.
|
||||
|
||||
**Lesson**: Overlap only works when tiles contain the same content from the same
|
||||
denoising process (e.g., spatial tiling). It fails for temporal chunks with
|
||||
different content.
|
||||
|
||||
### Per-chunk VAE decoding is correct
|
||||
|
||||
Decode each chunk's latents independently, not concatenated. Concatenating all chunks
|
||||
and decoding together lets boundary discontinuities propagate through temporal
|
||||
convolutions, creating worse artifacts.
|
||||
|
||||
### First-frame quality: causal padding strategies
|
||||
|
||||
Multiple approaches were tried for the first-frame quality issue in Wan VAE:
|
||||
|
||||
| Approach | Result |
|
||||
|----------|--------|
|
||||
| Zero padding (default) | First ~4 frames degraded, but matches training |
|
||||
| Replicate padding | Fixes artifacts but causes color intensity bias (conv applies all kernel weights to same value) |
|
||||
| Warmup frame prepend | Helps motion but warmup frame itself has artifacts |
|
||||
| Mirror-reflect warmup | Best compromise — varied context without zeros, no intensity bias |
|
||||
|
||||
**Lesson**: Don't assume "replicate padding is better than zero padding." The model
|
||||
was trained with zero padding; changing it shifts the gain. Instead, prepend warmup
|
||||
frames and trim them after decoding.
|
||||
|
||||
### RMS_norm vs L2 normalization
|
||||
|
||||
**Bug** (Wan2.2): Garbled output from incorrect normalization.
|
||||
|
||||
A PyTorch class named `RMS_norm` actually uses `F.normalize` (L2 norm: `x / ||x||_2`),
|
||||
not RMS normalization (`x / sqrt(mean(x²))`). The difference is a factor of `sqrt(C)`,
|
||||
causing values to explode through the decoder.
|
||||
|
||||
**Lesson**: Don't trust class names — read the actual implementation.
|
||||
|
||||
### Temporal frame count: causal boundary effects
|
||||
|
||||
**Bug** (Wan2.2): VAE produced 12 frames instead of 9 for a 9-frame input.
|
||||
|
||||
PyTorch reference processes frames one-by-one with caching, skipping temporal conv for
|
||||
the first chunk. All-at-once decoding produces extra frames from zero-padded causal
|
||||
context.
|
||||
|
||||
**Fix**: Use `first_chunk=True` flag to trim causal boundary frames, matching the
|
||||
reference's chunked behavior.
|
||||
|
||||
### Chunked VAE encoding for I2V
|
||||
|
||||
**Bug** (Wan2.2 I2V-14B): Incorrect latents from non-chunked encoding.
|
||||
|
||||
Non-chunked encoding with causal zero-padding produces incorrect latents because
|
||||
temporal features don't propagate correctly without caching. The reference uses chunked
|
||||
encoding (1+4+4+... frames) with persistent temporal cache.
|
||||
|
||||
**Fix**: Implement chunked encoding with `feat_cache` propagation through CausalConv3d,
|
||||
ResidualBlock, and Resample layers.
|
||||
|
||||
---
|
||||
|
||||
## 6. Scheduler & Timestep Issues
|
||||
|
||||
### Copy formulas exactly
|
||||
|
||||
Even small differences in scheduler formulas compound over many steps:
|
||||
|
||||
```python
|
||||
# Dynamic time shifting — reference uses specific formula
|
||||
mu = 0.5 + shift * 0.5 # NOT shift * 0.6 or any other constant
|
||||
|
||||
# Euler step
|
||||
x_next = x + (sigma_next - sigma) * flow # order matters: next - current
|
||||
```
|
||||
|
||||
### Verify sigma schedules numerically
|
||||
|
||||
Print and compare sigma values at each step:
|
||||
|
||||
```python
|
||||
# Reference
|
||||
sigmas_ref = [1.0, 0.99375, 0.9875, ...]
|
||||
|
||||
# Your implementation
|
||||
sigmas = scheduler.get_sigmas(steps)
|
||||
for i, (r, m) in enumerate(zip(sigmas_ref, sigmas)):
|
||||
assert abs(r - m) < 1e-6, f"Step {i}: ref={r}, mlx={m}"
|
||||
```
|
||||
|
||||
### Timestep embedding precision
|
||||
|
||||
Integer vs float timesteps matter. Some models expect `timestep=999` (int), others
|
||||
expect `timestep=0.999` (float). Wrong type can silently produce wrong embeddings
|
||||
with reasonable-looking but incorrect statistics.
|
||||
|
||||
### Boundary conditions: ±inf at sigma endpoints
|
||||
|
||||
**Bug** (Wan2.2): Greenish/yellow constant output from DPM++/UniPC schedulers.
|
||||
|
||||
The `lambda(sigma)` function must return `-inf` at `sigma=1.0` (pure noise) and `+inf`
|
||||
at `sigma=0.0` (clean signal). Our implementation returned `0.0`, causing massive x0
|
||||
overscaling on the first denoising step.
|
||||
|
||||
PyTorch naturally computes `torch.log(0) = -inf`, and `math.expm1(-inf) = -1.0`
|
||||
handles the formulas correctly. Reproduce this behavior explicitly:
|
||||
|
||||
```python
|
||||
def _lambda(self, sigma):
|
||||
if sigma >= 1.0:
|
||||
return float('-inf')
|
||||
if sigma <= 0.0:
|
||||
return float('inf')
|
||||
return -math.log(sigma / (1 - sigma))
|
||||
```
|
||||
|
||||
### UniPC corrector coefficients
|
||||
|
||||
**Bug** (Wan2.2): Accumulated artifacts across 47+ steps from wrong polynomial weights.
|
||||
|
||||
The UniPC corrector must compute `rhos_c` via `linalg.solve` for order ≥ 2. Hardcoded
|
||||
`0.5` was 7× too large for the history weight (actual: ~0.08), causing massive
|
||||
overweighting of history corrections.
|
||||
|
||||
---
|
||||
|
||||
## 7. Weight Conversion
|
||||
|
||||
### Always verify statistically
|
||||
|
||||
After converting weights from PyTorch to MLX format:
|
||||
|
||||
```python
|
||||
for name in mlx_weights:
|
||||
pt = pytorch_weights[map_name(name)]
|
||||
mx_val = np.array(mlx_weights[name])
|
||||
pt_val = pt.numpy()
|
||||
cos_sim = np.dot(mx_val.flat, pt_val.flat) / (
|
||||
np.linalg.norm(mx_val) * np.linalg.norm(pt_val) + 1e-10
|
||||
)
|
||||
if cos_sim < 0.9999:
|
||||
print(f"MISMATCH: {name} cos_sim={cos_sim:.6f}")
|
||||
```
|
||||
|
||||
### Conv3d → Linear reshaping
|
||||
|
||||
When converting 3D convolutions to linear layers (common for MLX which prefers
|
||||
linear ops), the flattening order must match:
|
||||
|
||||
```python
|
||||
# PyTorch Conv3d weight: (out_ch, in_ch, kT, kH, kW)
|
||||
# Flatten to Linear: (out_ch, in_ch * kT * kH * kW)
|
||||
# The reshape order MUST match how the input is patchified
|
||||
```
|
||||
|
||||
### Sanitization functions
|
||||
|
||||
Write explicit weight sanitization that maps reference key names to your key names.
|
||||
Don't rely on automatic matching — key naming conventions differ between frameworks.
|
||||
|
||||
### Module structure must mirror reference for direct loading
|
||||
|
||||
**Bug** (Wan2.2): Rewrote entire VAE module hierarchy to match PyTorch `nn.Sequential`
|
||||
structure. ResidualBlock needed `None` gaps at specific indices to match the original
|
||||
`nn.Sequential(RMSNorm, SiLU, Conv3d, ...)` indexing.
|
||||
|
||||
When possible, structure your modules to accept reference weights directly without
|
||||
sanitization. This eliminates an entire class of bugs.
|
||||
|
||||
### Save VAE weights in float32
|
||||
|
||||
Even if the model uses bfloat16 for the transformer, save VAE weights in float32.
|
||||
bfloat16 → float32 roundtrip loses precision that cannot be recovered by load-time
|
||||
upcast.
|
||||
|
||||
### Temporal downsample/upsample order
|
||||
|
||||
**Bug** (Wan2.2): `temporal_downsample=[True, True, False]` but reference uses
|
||||
`[False, True, True]`. Stage 0 created a `time_conv` with random weights (no matching
|
||||
file key), and Stage 2 missed its `time_conv` (weights silently dropped).
|
||||
|
||||
Always verify boolean flags for each stage by inspecting the actual weight file keys.
|
||||
|
||||
### Silent weight drops are the worst bugs
|
||||
|
||||
When `load_weights()` with `strict=False` silently skips keys that don't match, you
|
||||
get a model with random weights for those layers. This produces output that looks
|
||||
"almost right" but is subtly wrong. Always log which keys were loaded vs skipped:
|
||||
|
||||
```python
|
||||
loaded_keys = set()
|
||||
for key, value in weights:
|
||||
if key in model_params:
|
||||
loaded_keys.add(key)
|
||||
# Check for missing
|
||||
expected = set(model_params.keys())
|
||||
missing = expected - loaded_keys
|
||||
if missing:
|
||||
print(f"WARNING: {len(missing)} weights not loaded: {list(missing)[:5]}...")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Text Conditioning Failures
|
||||
|
||||
### Symptom: model predicts noise back to itself
|
||||
|
||||
If the model output correlates > 0.8 with its input noise, text conditioning is
|
||||
likely broken. The model has learned nothing from the prompt and is just returning
|
||||
its input.
|
||||
|
||||
### Check embedding statistics
|
||||
|
||||
```python
|
||||
text_emb = text_encoder(prompt)
|
||||
print(f"text_emb: mean={text_emb.mean():.4f} std={text_emb.std():.4f}")
|
||||
# std < 0.01 → embeddings are collapsed → broken encoder or wrong weights
|
||||
# std > 10.0 → embeddings are exploding → wrong normalization
|
||||
```
|
||||
|
||||
### Verify with ablation
|
||||
|
||||
```python
|
||||
# Generate with real text
|
||||
output_text = denoise(latent, text_emb=real_embeddings)
|
||||
# Generate with zeros
|
||||
output_zero = denoise(latent, text_emb=mx.zeros_like(real_embeddings))
|
||||
# Compare
|
||||
text_influence = np.mean(np.abs(output_text - output_zero))
|
||||
print(f"Text influence: {text_influence:.4f}") # Should be > 0 (typically 30-60% of output)
|
||||
```
|
||||
|
||||
### Text preprocessing must match exactly
|
||||
|
||||
**Bug** (Wan2.2): Patchy-blurry output from wrong negative prompt tokenization.
|
||||
|
||||
The official Wan2.2 tokenizer applies `ftfy.fix_text` + `html.unescape` + whitespace
|
||||
normalization before tokenization. Without this, fullwidth Chinese commas (U+FF0C)
|
||||
tokenize differently from ASCII commas (U+002C), causing **27 different token IDs**
|
||||
in the negative prompt. This made CFG's unconditional prediction wrong.
|
||||
|
||||
**Fix**: Apply the same text cleaning pipeline as the reference:
|
||||
```python
|
||||
import ftfy
|
||||
import html
|
||||
import re
|
||||
|
||||
def clean_text(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(text)
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
```
|
||||
|
||||
### T5 encoder precision
|
||||
|
||||
**Bug** (Wan2.2): Quality degradation from bfloat16 T5 attention.
|
||||
|
||||
T5 uses **no scaling** in attention (no `1/sqrt(d)` factor), so attention logits can
|
||||
be very large. bfloat16 softmax loses significant precision across 24 encoder layers.
|
||||
|
||||
**Fix**: Compute T5 QK^T and softmax in float32. The T5 encoder only runs once per
|
||||
generation, so the performance impact is negligible.
|
||||
|
||||
### Dual-model text embeddings
|
||||
|
||||
**Bug** (Wan2.2 I2V-14B): Low/high noise models have different `text_embedding` weights
|
||||
(~42% relative difference). Using one model's embeddings for both caused incorrect
|
||||
text conditioning for the high-noise model that handles critical early denoising steps.
|
||||
|
||||
**Fix**: Compute separate text embeddings for each model in dual-model setups.
|
||||
|
||||
---
|
||||
|
||||
## 9. Position Encodings (RoPE)
|
||||
|
||||
### Multi-scale consistency
|
||||
|
||||
In pyramid/multi-resolution models, RoPE must be computed consistently across scales.
|
||||
If the model operates at 1/4 resolution in an early stage, the position grid must
|
||||
reflect the actual spatial dimensions, not the final target dimensions.
|
||||
|
||||
### History vs current chunk
|
||||
|
||||
When conditioning on history from a previous chunk, the position encoding for
|
||||
history frames must match what the model saw during training. Mismatches between
|
||||
history and current-chunk position encodings can cause subtle spatial distortions
|
||||
that compound across chunks.
|
||||
|
||||
### Factorized RoPE
|
||||
|
||||
3D video models often use factorized RoPE (separate temporal, height, width
|
||||
frequencies). Verify each axis independently:
|
||||
|
||||
```python
|
||||
# Compare temporal frequencies
|
||||
assert np.allclose(mlx_rope_t, ref_rope_t, atol=1e-5)
|
||||
# Compare spatial frequencies
|
||||
assert np.allclose(mlx_rope_h, ref_rope_h, atol=1e-5)
|
||||
assert np.allclose(mlx_rope_w, ref_rope_w, atol=1e-5)
|
||||
```
|
||||
|
||||
### Per-axis frequency construction
|
||||
|
||||
**Bug** (Wan2.2): Grey/artifact-filled output from wrong frequency distribution.
|
||||
|
||||
The reference uses three separate `rope_params()` calls with different dimension
|
||||
normalizations (e.g., 44, 42, 42 for Wan) so each axis gets its own full frequency
|
||||
range. Consolidating into a single `rope_params(head_dim)` call and splitting gave
|
||||
height frequencies starting at 0.042 and width at 0.002 (should be 1.0 for both).
|
||||
|
||||
**Fix** (and subsequent revert): This bug was introduced as a "fix" for a previous
|
||||
RoPE issue, then had to be reverted. The lesson: RoPE changes have far-reaching effects.
|
||||
Always verify with actual generation, not just numerical comparison of frequencies.
|
||||
|
||||
**Lesson**: Read the reference's frequency construction very carefully. Don't
|
||||
"simplify" three separate calls into one unless you verify the frequency distribution
|
||||
matches exactly.
|
||||
|
||||
---
|
||||
|
||||
## 10. Multi-Stage / Pyramid Pipelines
|
||||
|
||||
### Each stage is a potential failure point
|
||||
|
||||
Pyramid pipelines (generate at low res, upsample, refine at high res) multiply the
|
||||
number of things that can go wrong:
|
||||
|
||||
- Downsampling method (bilinear vs area) must match reference
|
||||
- Energy compensation factors (e.g., ×2 after bilinear downsample) must be present
|
||||
- Alpha/beta noise mixing coefficients are stage-dependent
|
||||
- Frame indices and history resolution change per stage
|
||||
|
||||
### Test single-stage first
|
||||
|
||||
If the model works at full resolution for a single stage but fails in the pyramid,
|
||||
the bug is in stage orchestration — typically in how latents are passed between
|
||||
stages or how position encodings adapt to different resolutions.
|
||||
|
||||
### Integration bugs are the hardest
|
||||
|
||||
We verified every Helios component matched the reference individually, but the
|
||||
pyramid still produced uniform color. The bug was in dtype handling during stage
|
||||
transitions. Integration bugs only appear when components interact.
|
||||
|
||||
---
|
||||
|
||||
## 11. Common Symptoms → Root Causes
|
||||
|
||||
| Symptom | Likely Root Causes |
|
||||
|---------|-------------------|
|
||||
| **Pure noise output** | Wrong sigma schedule, broken text conditioning, incorrect weight mapping |
|
||||
| **Uniform color** | Model predicting noise back; text embeddings collapsed; wrong timestep format |
|
||||
| **Progressive zoom/shrink** | bfloat16 residuals truncating high-freq detail; RoPE mismatch across chunks |
|
||||
| **Brightness jumps at boundaries** | VAE causal warmup; cross-fade blending misaligned content |
|
||||
| **Color drift across chunks** | Dtype in scheduler step; history normalization missing |
|
||||
| **Blur at boundaries** | Cross-fade enabled; latent blending; wrong VAE decode order |
|
||||
| **Grid/checker patterns** | Patchify channel ordering bug; latent blend artifacts; reshape axis error |
|
||||
| **Green/magenta tint** | VAE weight key mismatch; wrong denormalization constants; cv2 YUV color matrix |
|
||||
| **Mean drift across steps** | bfloat16 accumulation; wrong scheduler formula; missing energy compensation |
|
||||
| **Garbled/scrambled output** | Silent weight drops (underscore prefix, wrong key mapping); RMS vs L2 norm |
|
||||
| **Greenish-yellow constant** | Scheduler boundary condition (log(0) not returning -inf); x0 overscaling |
|
||||
| **~2x slower than expected** | Float32 promotion cascade from single mistyped intermediate |
|
||||
| **Extra output frames** | Causal padding producing extra temporal frames; missing `first_chunk` trim |
|
||||
| **Grey/artifact output** | RoPE frequency construction wrong (per-axis vs single-call) |
|
||||
| **Patchy-blurry with CFG** | Text preprocessing mismatch (fullwidth vs ASCII chars → wrong tokenization) |
|
||||
| **I2V temporal mismatch** | Non-chunked VAE encoding vs reference's chunked encoding with temporal cache |
|
||||
|
||||
---
|
||||
|
||||
## 12. Verification Checklist
|
||||
|
||||
Use this checklist when porting a new diffusion video model:
|
||||
|
||||
### Model
|
||||
- [ ] Weight conversion: all keys mapped, cosine similarity > 0.9999
|
||||
- [ ] No silent weight drops (log loaded vs expected keys)
|
||||
- [ ] Single forward pass matches reference (cos_sim > 0.999)
|
||||
- [ ] Residual connections use float32 accumulation
|
||||
- [ ] Attention computation matches reference precision
|
||||
- [ ] Modulation/gate vectors in float32 (if reference uses autocast)
|
||||
- [ ] No underscore-prefixed module attributes (MLX ignores them)
|
||||
|
||||
### Scheduler
|
||||
- [ ] Sigma values match reference at every step (diff < 1e-6)
|
||||
- [ ] Timestep format correct (int vs float, scale factor)
|
||||
- [ ] Dynamic shifting formula copied exactly
|
||||
- [ ] Step function returns correct dtype (float32)
|
||||
- [ ] Boundary conditions: lambda(-inf) at sigma=1, lambda(+inf) at sigma=0
|
||||
- [ ] Higher-order coefficients computed (not hardcoded) for UniPC/DPM++
|
||||
|
||||
### Text Encoder
|
||||
- [ ] Embedding statistics reasonable (0.01 < std < 10)
|
||||
- [ ] Text influence > 0 (ablation test)
|
||||
- [ ] Tokenization matches (special tokens, padding, max length)
|
||||
- [ ] Text preprocessing matches (ftfy, html unescape, whitespace normalization)
|
||||
- [ ] T5/CLIP attention precision (float32 softmax if no 1/sqrt(d) scaling)
|
||||
- [ ] Separate embeddings for dual-model setups (if applicable)
|
||||
|
||||
### VAE
|
||||
- [ ] Denormalization constants match training pipeline
|
||||
- [ ] Per-chunk decoding (not concatenated)
|
||||
- [ ] Temporal frame count correct (account for causal padding)
|
||||
- [ ] Weight keys mapped correctly (encoder vs decoder)
|
||||
- [ ] Weights stored/loaded in float32 (not bfloat16)
|
||||
- [ ] Temporal downsample/upsample order matches reference
|
||||
- [ ] RMS_norm vs L2_norm: check actual implementation, not class name
|
||||
- [ ] Chunked encoding for I2V (if applicable)
|
||||
- [ ] Reshape axis ordering correct ([B,C,T,H,W] → transpose before reshape)
|
||||
|
||||
### Pipeline Orchestration
|
||||
- [ ] Position encodings consistent across stages/chunks
|
||||
- [ ] History slicing and conditioning correct
|
||||
- [ ] Noise generation matches (distribution, correlation structure)
|
||||
- [ ] Multi-chunk output visually consistent (no progressive degradation)
|
||||
- [ ] Dimension auto-alignment (divisible by patch_size × vae_stride)
|
||||
- [ ] Dtype-aware padding (mx.zeros with explicit dtype)
|
||||
|
||||
### Output
|
||||
- [ ] Frame count matches expected (account for warmup/trim)
|
||||
- [ ] FPS correct
|
||||
- [ ] Color range [0, 255] uint8 for video
|
||||
- [ ] No first-frame duplication artifacts
|
||||
- [ ] Video codec correct (imageio/libx264 preferred over cv2/mp4v on macOS)
|
||||
|
||||
### Performance
|
||||
- [ ] No float32 promotion cascades (check with profiler)
|
||||
- [ ] Using mx.fast kernels (rms_norm, layer_norm, sdpa)
|
||||
- [ ] Time embedding computed once per sample (not per position)
|
||||
- [ ] Memory cleanup (delete temporaries before mx.eval)
|
||||
|
||||
---
|
||||
|
||||
## 13. Diagnostic Tools
|
||||
|
||||
### General video diagnostics (`scripts/video/`)
|
||||
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| `compare_videos.py` | PSNR, SSIM, temporal coherence, color fidelity between two videos |
|
||||
| `video_quality.py` | Sharpness, stability, defect detection, chunk boundary analysis |
|
||||
|
||||
```bash
|
||||
# Quick quality check
|
||||
python scripts/video/video_quality.py output.mp4 --chunk-size 32
|
||||
|
||||
# Compare against reference
|
||||
python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4
|
||||
```
|
||||
|
||||
### Model-specific diagnostics (`scripts/helios/`)
|
||||
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| `analyze_boundaries.py` | Detailed boundary quality metrics for Helios |
|
||||
| `run_reference.py` | Run PyTorch reference on MPS |
|
||||
| `compare_pipelines.py` | Compare scheduler/pipeline mechanics |
|
||||
| `compare_models.py` | Cross-framework model output comparison |
|
||||
|
||||
### Inline debugging pattern
|
||||
|
||||
Add temporary debug output to the diffusion loop:
|
||||
|
||||
```python
|
||||
for i, sigma in enumerate(sigmas):
|
||||
flow = model(latent, sigma, text_emb)
|
||||
latent = scheduler.step(latent, flow, sigma, sigma_next)
|
||||
|
||||
# Debug: track statistics
|
||||
print(f"[step {i}] sigma={sigma:.4f} "
|
||||
f"latent: mean={latent.mean():.6f} std={latent.std():.6f} "
|
||||
f"flow: mean={flow.mean():.6f} std={flow.std():.6f}")
|
||||
|
||||
# Debug: save for cross-framework comparison
|
||||
if os.environ.get("DEBUG"):
|
||||
mx.save(f"/tmp/debug_step_{i}.npz", {
|
||||
"latent": latent, "flow": flow, "sigma": mx.array(sigma)
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Takeaways
|
||||
|
||||
1. **Precision is the #1 bug source** — bfloat16 residuals, scheduler math, type
|
||||
promotion, modulation vectors. Copy the reference's `.float()` and `autocast` zones.
|
||||
|
||||
2. **Don't add what the reference doesn't have** — cross-fade, overlap decode,
|
||||
temporal blending. If the reference works without it, you probably have a bug
|
||||
elsewhere.
|
||||
|
||||
3. **Silent failures are the hardest bugs** — underscore-prefixed weights, `strict=False`
|
||||
weight loading, wrong normalization class names. Always verify weight load counts
|
||||
and output statistics.
|
||||
|
||||
4. **Component isolation → integration testing** — verify each part matches, then
|
||||
debug their interaction.
|
||||
|
||||
5. **Statistical comparison beats visual inspection** — mean drift, contrast ratios,
|
||||
and cosine similarity catch bugs before they're visible.
|
||||
|
||||
6. **Autoregressive errors compound** — a 1% error per chunk becomes 10% by chunk 10.
|
||||
Fix precision first, add corrections second.
|
||||
|
||||
7. **MLX has unique pitfalls** — underscore attribute names, reshape axis ordering,
|
||||
dtype-unaware padding, and float32 promotion cascades. Know your framework.
|
||||
|
||||
8. **Text preprocessing matters** — Unicode normalization, fullwidth chars, HTML entities.
|
||||
A single mismatched comma can break CFG guidance.
|
||||
|
||||
9. **VAE is deceptively complex** — causal padding, temporal frame counts, chunked vs
|
||||
batch processing, norm implementations. Budget significant debugging time for VAE.
|
||||
BIN
examples/poodles-wan.gif
Normal file
BIN
examples/poodles-wan.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.4 MiB |
@@ -4,26 +4,25 @@ from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.ltx_2.audio_vae import (
|
||||
AudioDecoder,
|
||||
AudioEncoder,
|
||||
AudioLatentShape,
|
||||
AudioPatchifier,
|
||||
PerChannelStatistics,
|
||||
Vocoder,
|
||||
decode_audio,
|
||||
AudioPatchifier,
|
||||
AudioLatentShape,
|
||||
PerChannelStatistics,
|
||||
)
|
||||
|
||||
# Conditioning
|
||||
from mlx_video.models.ltx_2.conditioning import (
|
||||
VideoConditionByLatentIndex,
|
||||
)
|
||||
from mlx_video.models.ltx_2.conditioning import VideoConditionByLatentIndex
|
||||
|
||||
# Utilities
|
||||
from mlx_video.models.ltx_2.utils import (
|
||||
convert_audio_encoder,
|
||||
get_model_path,
|
||||
load_safetensors,
|
||||
load_config,
|
||||
load_safetensors,
|
||||
save_weights,
|
||||
)
|
||||
from mlx_video.models.wan_2 import WanModel, WanModelConfig
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
@@ -45,4 +44,7 @@ __all__ = [
|
||||
"load_safetensors",
|
||||
"load_config",
|
||||
"save_weights",
|
||||
# Wan Models
|
||||
"WanModel",
|
||||
"WanModelConfig",
|
||||
]
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
"""Stub — delegates to mlx_video.models.ltx_2.utils."""
|
||||
from mlx_video.models.ltx_2.utils import * # noqa: F401,F403
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Entry point stub — delegates to mlx_video.models.ltx_2.generate."""
|
||||
from mlx_video.models.ltx_2.generate import main, generate_video
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
22
mlx_video/lora/__init__.py
Normal file
22
mlx_video/lora/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""LoRA support for mlx-video."""
|
||||
|
||||
from mlx_video.lora.apply import (
|
||||
LoRALinear,
|
||||
apply_lora_to_linear,
|
||||
apply_loras_to_model,
|
||||
apply_loras_to_weights,
|
||||
)
|
||||
from mlx_video.lora.loader import load_lora_weights, load_multiple_loras
|
||||
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
|
||||
|
||||
__all__ = [
|
||||
"LoRAConfig",
|
||||
"LoRAWeights",
|
||||
"AppliedLoRA",
|
||||
"load_lora_weights",
|
||||
"load_multiple_loras",
|
||||
"apply_lora_to_linear",
|
||||
"apply_loras_to_weights",
|
||||
"apply_loras_to_model",
|
||||
"LoRALinear",
|
||||
]
|
||||
421
mlx_video/lora/apply.py
Normal file
421
mlx_video/lora/apply.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""Apply LoRA weights to model layers."""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
|
||||
def apply_lora_to_linear(
|
||||
linear_weight: mx.array,
|
||||
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
|
||||
) -> mx.array:
|
||||
"""Apply one or more LoRAs to a linear layer weight.
|
||||
|
||||
Args:
|
||||
linear_weight: Original weight matrix [out_features, in_features]
|
||||
lora_weights_and_strengths: List of (LoRAWeights, strength) tuples
|
||||
|
||||
Returns:
|
||||
Modified weight with LoRA deltas applied (preserves original dtype)
|
||||
"""
|
||||
orig_dtype = linear_weight.dtype
|
||||
modified_weight = linear_weight
|
||||
|
||||
for weights, strength in lora_weights_and_strengths:
|
||||
scale = weights.scale
|
||||
# Compute delta in float32 for precision, then cast back to avoid
|
||||
# promoting model weights (e.g. bfloat16 → float32 causes ~1.5x slowdown)
|
||||
delta = (weights.lora_B @ weights.lora_A) * (scale * strength)
|
||||
modified_weight = modified_weight + delta.astype(orig_dtype)
|
||||
|
||||
return modified_weight
|
||||
|
||||
|
||||
def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
"""Normalize LoRA module name to match Wan2.2 MLX model weight keys.
|
||||
|
||||
Handles:
|
||||
- Stripping common prefixes (diffusion_model., model., etc.)
|
||||
- FFN key mapping: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
|
||||
- Embedding key mapping: text_embedding.0 → text_embedding_0, etc.
|
||||
- Time projection: time_projection.1 → time_projection
|
||||
- Patch embedding: patch_embedding → patch_embedding_proj
|
||||
|
||||
Args:
|
||||
lora_key: Original LoRA module name
|
||||
model_keys: Set of all model weight keys
|
||||
|
||||
Returns:
|
||||
Normalized key that matches model weights
|
||||
"""
|
||||
# Try the key as-is first
|
||||
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
|
||||
return lora_key
|
||||
|
||||
# Common prefixes to strip
|
||||
prefixes_to_strip = [
|
||||
"model.diffusion_model.",
|
||||
"diffusion_model.",
|
||||
"base_model.model.",
|
||||
"model.",
|
||||
]
|
||||
|
||||
candidates = [lora_key]
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
candidates.append(lora_key[len(prefix) :])
|
||||
|
||||
for candidate in candidates:
|
||||
# Try as-is
|
||||
if f"{candidate}.weight" in model_keys or candidate in model_keys:
|
||||
return candidate
|
||||
|
||||
# Apply Wan2.2 key transformations
|
||||
transformed = candidate
|
||||
|
||||
# FFN: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
|
||||
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
|
||||
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
|
||||
if transformed.endswith(".ffn.0"):
|
||||
transformed = transformed[: -len(".ffn.0")] + ".ffn.fc1"
|
||||
if transformed.endswith(".ffn.2"):
|
||||
transformed = transformed[: -len(".ffn.2")] + ".ffn.fc2"
|
||||
|
||||
# Text embedding: text_embedding.0 → text_embedding_0
|
||||
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
|
||||
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
|
||||
if transformed.endswith("text_embedding.0"):
|
||||
transformed = transformed[: -len("text_embedding.0")] + "text_embedding_0"
|
||||
if transformed.endswith("text_embedding.2"):
|
||||
transformed = transformed[: -len("text_embedding.2")] + "text_embedding_1"
|
||||
|
||||
# Time embedding: time_embedding.0 → time_embedding_0
|
||||
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
|
||||
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
|
||||
if transformed.endswith("time_embedding.0"):
|
||||
transformed = transformed[: -len("time_embedding.0")] + "time_embedding_0"
|
||||
if transformed.endswith("time_embedding.2"):
|
||||
transformed = transformed[: -len("time_embedding.2")] + "time_embedding_1"
|
||||
|
||||
# Time projection: time_projection.1 → time_projection
|
||||
transformed = transformed.replace("time_projection.1.", "time_projection.")
|
||||
if transformed.endswith("time_projection.1"):
|
||||
transformed = transformed[: -len("time_projection.1")] + "time_projection"
|
||||
|
||||
# Patch embedding: patch_embedding → patch_embedding_proj
|
||||
if (
|
||||
"patch_embedding" in transformed
|
||||
and "patch_embedding_proj" not in transformed
|
||||
):
|
||||
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
|
||||
|
||||
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||
return transformed
|
||||
|
||||
# Return best attempt with prefix stripped
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
return lora_key[len(prefix) :]
|
||||
|
||||
return lora_key
|
||||
|
||||
|
||||
# Also support LTX-style key normalization
|
||||
def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
"""Normalize LoRA module name to match LTX MLX model weight keys."""
|
||||
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
|
||||
return lora_key
|
||||
|
||||
prefixes_to_strip = [
|
||||
"model.diffusion_model.",
|
||||
"diffusion_model.",
|
||||
"model.",
|
||||
]
|
||||
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
normalized = lora_key[len(prefix) :]
|
||||
|
||||
if f"{normalized}.weight" in model_keys or normalized in model_keys:
|
||||
return normalized
|
||||
|
||||
transformed = normalized
|
||||
if transformed.endswith(".to_out.0"):
|
||||
transformed = transformed[: -len(".to_out.0")] + ".to_out"
|
||||
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
|
||||
transformed = transformed.replace(
|
||||
".audio_ff.net.0.proj.", ".audio_ff.proj_in."
|
||||
)
|
||||
transformed = transformed.replace(
|
||||
".audio_ff.net.0.proj", ".audio_ff.proj_in"
|
||||
)
|
||||
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
|
||||
|
||||
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||
return transformed
|
||||
|
||||
# Try transformations on the original key
|
||||
transformed = lora_key
|
||||
if transformed.endswith(".to_out.0"):
|
||||
transformed = transformed[: -len(".to_out.0")] + ".to_out"
|
||||
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
|
||||
|
||||
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||
return transformed
|
||||
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
return lora_key[len(prefix) :]
|
||||
|
||||
return lora_key
|
||||
|
||||
|
||||
def _normalize_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
"""Normalize LoRA module name to match model weight keys.
|
||||
|
||||
Auto-detects whether to use Wan2.2 or LTX key normalization based
|
||||
on the presence of architecture-specific keys in the model.
|
||||
"""
|
||||
# Detect model architecture from keys
|
||||
is_wan = any("self_attn.q.weight" in k for k in model_keys)
|
||||
|
||||
if is_wan:
|
||||
return _normalize_wan_lora_key(lora_key, model_keys)
|
||||
else:
|
||||
return _normalize_ltx_lora_key(lora_key, model_keys)
|
||||
|
||||
|
||||
def apply_loras_to_weights(
|
||||
model_weights: Dict[str, mx.array],
|
||||
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
|
||||
verbose: bool = False,
|
||||
quantization_bits: int = 0,
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Apply LoRAs to model weights.
|
||||
|
||||
Args:
|
||||
model_weights: Original model state dictionary
|
||||
module_to_loras: Dictionary mapping module names to lists of
|
||||
(LoRAWeights, strength) tuples
|
||||
verbose: If True, print detailed debug information
|
||||
quantization_bits: If >0, weights are quantized at this bit width.
|
||||
Quantized layers are dequantized before LoRA application
|
||||
and re-quantized after.
|
||||
|
||||
Returns:
|
||||
New state dictionary with LoRA-modified weights
|
||||
"""
|
||||
modified_weights = dict(model_weights)
|
||||
model_keys = set(model_weights.keys())
|
||||
|
||||
applied_count = 0
|
||||
skipped_count = 0
|
||||
skipped_modules = []
|
||||
|
||||
for module_name, loras in module_to_loras.items():
|
||||
normalized_name = _normalize_lora_key(module_name, model_keys)
|
||||
weight_key = f"{normalized_name}.weight"
|
||||
|
||||
if weight_key not in modified_weights:
|
||||
if normalized_name not in modified_weights:
|
||||
skipped_count += 1
|
||||
skipped_modules.append(module_name)
|
||||
if verbose and skipped_count <= 5:
|
||||
print(
|
||||
f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND"
|
||||
)
|
||||
similar = [
|
||||
k
|
||||
for k in list(model_keys)[:1000]
|
||||
if normalized_name.split(".")[-1] in k
|
||||
][:3]
|
||||
if similar:
|
||||
print(f" Similar keys: {similar}")
|
||||
continue
|
||||
weight_key = normalized_name
|
||||
|
||||
original_weight = modified_weights[weight_key]
|
||||
|
||||
# Handle quantized weights: dequantize → apply delta → re-quantize
|
||||
scales_key = f"{normalized_name}.scales"
|
||||
biases_key = f"{normalized_name}.biases"
|
||||
is_quantized = (
|
||||
original_weight.dtype == mx.uint32
|
||||
and scales_key in modified_weights
|
||||
and biases_key in modified_weights
|
||||
)
|
||||
|
||||
if is_quantized:
|
||||
scales = modified_weights[scales_key]
|
||||
biases = modified_weights[biases_key]
|
||||
group_size = (original_weight.shape[-1] * 32) // (
|
||||
scales.shape[-1] * quantization_bits
|
||||
)
|
||||
dequantized = mx.dequantize(
|
||||
original_weight,
|
||||
scales,
|
||||
biases,
|
||||
group_size=group_size,
|
||||
bits=quantization_bits,
|
||||
)
|
||||
modified = apply_lora_to_linear(dequantized, loras)
|
||||
# Re-quantize with same parameters
|
||||
new_w, new_scales, new_biases = mx.quantize(
|
||||
modified, group_size=group_size, bits=quantization_bits
|
||||
)
|
||||
modified_weights[weight_key] = new_w
|
||||
modified_weights[scales_key] = new_scales
|
||||
modified_weights[biases_key] = new_biases
|
||||
else:
|
||||
modified_weights[weight_key] = apply_lora_to_linear(original_weight, loras)
|
||||
|
||||
applied_count += 1
|
||||
|
||||
if applied_count > 0:
|
||||
print(f" ✓ Applied to {applied_count} modules")
|
||||
if skipped_count > 0:
|
||||
print(f" ⚠ Skipped {skipped_count} incompatible modules")
|
||||
|
||||
return modified_weights
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
"""Linear layer with on-the-fly LoRA application.
|
||||
|
||||
Wraps nn.Linear or nn.QuantizedLinear, computing LoRA delta at runtime:
|
||||
output = base_linear(x) + (x @ lora_A.T @ lora_B.T) * scale * strength
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
linear: nn.Module,
|
||||
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
self.lora_weights_and_strengths = lora_weights_and_strengths
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
output = self.linear(x)
|
||||
for weights, strength in self.lora_weights_and_strengths:
|
||||
scale = weights.scale
|
||||
lora_out = x @ weights.lora_A.T @ weights.lora_B.T
|
||||
output = output + (scale * strength * lora_out)
|
||||
return output
|
||||
|
||||
|
||||
def apply_loras_to_model(
|
||||
model: nn.Module,
|
||||
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
|
||||
verbose: bool = False,
|
||||
) -> int:
|
||||
"""Apply LoRAs to a model by merging into weights.
|
||||
|
||||
For QuantizedLinear layers: dequantizes to bf16, merges LoRA delta, and
|
||||
replaces with a regular nn.Linear (no per-step overhead, no re-quantization
|
||||
precision loss). Non-LoRA layers stay quantized.
|
||||
|
||||
For nn.Linear layers: merges LoRA delta directly into the weight.
|
||||
|
||||
Args:
|
||||
model: The model to apply LoRAs to
|
||||
module_to_loras: Dictionary mapping module names to (LoRAWeights, strength) lists
|
||||
verbose: Print debug info
|
||||
|
||||
Returns:
|
||||
Number of modules modified
|
||||
"""
|
||||
# Build a set of model module paths for key normalization
|
||||
module_paths = set()
|
||||
for name, _ in model.named_modules():
|
||||
module_paths.add(name)
|
||||
module_paths.add(f"{name}.weight")
|
||||
|
||||
# Map LoRA keys → model module paths
|
||||
lora_to_module = {}
|
||||
for lora_key in module_to_loras:
|
||||
normalized = _normalize_lora_key(lora_key, module_paths)
|
||||
if normalized.endswith(".weight"):
|
||||
normalized = normalized[: -len(".weight")]
|
||||
lora_to_module[lora_key] = normalized
|
||||
|
||||
applied_count = 0
|
||||
dequant_count = 0
|
||||
skipped = []
|
||||
|
||||
for lora_key, loras in module_to_loras.items():
|
||||
module_path = lora_to_module[lora_key]
|
||||
parts = module_path.split(".")
|
||||
|
||||
# Traverse to the parent module
|
||||
parent = model
|
||||
try:
|
||||
for part in parts[:-1]:
|
||||
parent = (
|
||||
getattr(parent, part) if not part.isdigit() else parent[int(part)]
|
||||
)
|
||||
leaf_name = parts[-1]
|
||||
target = (
|
||||
getattr(parent, leaf_name)
|
||||
if not leaf_name.isdigit()
|
||||
else parent[int(leaf_name)]
|
||||
)
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
skipped.append(lora_key)
|
||||
if verbose:
|
||||
print(f" DEBUG: '{lora_key}' -> '{module_path}' -> module not found")
|
||||
continue
|
||||
|
||||
if isinstance(target, nn.QuantizedLinear):
|
||||
# Dequantize → merge LoRA → replace with bf16 Linear
|
||||
weight = mx.dequantize(
|
||||
target.weight,
|
||||
target.scales,
|
||||
target.biases,
|
||||
group_size=target.group_size,
|
||||
bits=target.bits,
|
||||
)
|
||||
merged = apply_lora_to_linear(weight, loras)
|
||||
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
|
||||
new_linear.weight = merged
|
||||
if "bias" in target:
|
||||
new_linear.bias = target.bias
|
||||
if leaf_name.isdigit():
|
||||
parent[int(leaf_name)] = new_linear
|
||||
else:
|
||||
setattr(parent, leaf_name, new_linear)
|
||||
dequant_count += 1
|
||||
applied_count += 1
|
||||
elif isinstance(target, nn.Linear):
|
||||
# Merge directly into weight
|
||||
target.weight = apply_lora_to_linear(target.weight, loras)
|
||||
applied_count += 1
|
||||
else:
|
||||
skipped.append(lora_key)
|
||||
if verbose:
|
||||
print(
|
||||
f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear"
|
||||
)
|
||||
continue
|
||||
|
||||
if applied_count > 0:
|
||||
msg = f" ✓ Applied to {applied_count} modules"
|
||||
if dequant_count > 0:
|
||||
msg += f" ({dequant_count} dequantized to bf16)"
|
||||
print(msg)
|
||||
if skipped:
|
||||
print(f" ⚠ Skipped {len(skipped)} incompatible modules")
|
||||
|
||||
return applied_count
|
||||
122
mlx_video/lora/loader.py
Normal file
122
mlx_video/lora/loader.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""LoRA weight loading utilities."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.lora.types import LoRAConfig, LoRAWeights
|
||||
|
||||
|
||||
def load_lora_weights(lora_path: Path) -> Dict[str, LoRAWeights]:
|
||||
"""Load LoRA weights from a safetensors file.
|
||||
|
||||
Supports both key conventions:
|
||||
- {module_name}.lora_A.weight / {module_name}.lora_B.weight
|
||||
- {module_name}.lora_down.weight / {module_name}.lora_up.weight
|
||||
|
||||
Args:
|
||||
lora_path: Path to the LoRA safetensors file
|
||||
|
||||
Returns:
|
||||
Dictionary mapping module names to LoRAWeights objects
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the LoRA file doesn't exist
|
||||
ValueError: If the LoRA file format is invalid
|
||||
"""
|
||||
if not lora_path.exists():
|
||||
raise FileNotFoundError(f"LoRA file not found: {lora_path}")
|
||||
|
||||
all_weights = mx.load(str(lora_path))
|
||||
|
||||
# Group weights by module name, handling both naming conventions
|
||||
lora_weights = {}
|
||||
module_names = set()
|
||||
|
||||
for key in all_weights.keys():
|
||||
# Format 1: {module}.lora_A.weight / {module}.lora_B.weight
|
||||
match = re.match(r"(.+)\.lora_([AB])\.weight$", key)
|
||||
if match:
|
||||
module_names.add(match.group(1))
|
||||
continue
|
||||
# Format 2: {module}.lora_down.weight / {module}.lora_up.weight
|
||||
match = re.match(r"(.+)\.lora_(down|up)\.weight$", key)
|
||||
if match:
|
||||
module_names.add(match.group(1))
|
||||
|
||||
for module_name in module_names:
|
||||
# Try both key conventions
|
||||
key_a = f"{module_name}.lora_A.weight"
|
||||
key_b = f"{module_name}.lora_B.weight"
|
||||
if key_a not in all_weights or key_b not in all_weights:
|
||||
key_a = f"{module_name}.lora_down.weight"
|
||||
key_b = f"{module_name}.lora_up.weight"
|
||||
if key_a not in all_weights or key_b not in all_weights:
|
||||
continue
|
||||
|
||||
lora_a = all_weights[key_a]
|
||||
lora_b = all_weights[key_b]
|
||||
|
||||
if lora_a.ndim != 2 or lora_b.ndim != 2:
|
||||
raise ValueError(
|
||||
f"Invalid LoRA shape for {module_name}: "
|
||||
f"lora_A={lora_a.shape}, lora_B={lora_b.shape}"
|
||||
)
|
||||
|
||||
rank = lora_a.shape[0]
|
||||
if lora_b.shape[1] != rank:
|
||||
raise ValueError(
|
||||
f"LoRA rank mismatch for {module_name}: "
|
||||
f"lora_A rank={rank}, lora_B rank={lora_b.shape[1]}"
|
||||
)
|
||||
|
||||
# Check for per-module alpha stored as a scalar tensor
|
||||
alpha_key = f"{module_name}.alpha"
|
||||
if alpha_key in all_weights:
|
||||
alpha = float(all_weights[alpha_key].item())
|
||||
else:
|
||||
alpha = float(rank)
|
||||
|
||||
lora_weights[module_name] = LoRAWeights(
|
||||
lora_A=lora_a,
|
||||
lora_B=lora_b,
|
||||
rank=rank,
|
||||
alpha=alpha,
|
||||
module_name=module_name,
|
||||
)
|
||||
|
||||
if not lora_weights:
|
||||
raise ValueError(f"No valid LoRA weights found in {lora_path}")
|
||||
|
||||
return lora_weights
|
||||
|
||||
|
||||
def load_multiple_loras(
|
||||
configs: List[LoRAConfig],
|
||||
) -> Dict[str, List[tuple]]:
|
||||
"""Load multiple LoRA configurations.
|
||||
|
||||
Args:
|
||||
configs: List of LoRAConfig objects
|
||||
|
||||
Returns:
|
||||
Dictionary mapping module names to lists of (LoRAWeights, strength) tuples.
|
||||
"""
|
||||
module_to_loras: Dict[str, list] = {}
|
||||
|
||||
for config in configs:
|
||||
lora_weights = load_lora_weights(config.path)
|
||||
|
||||
for module_name, weights in lora_weights.items():
|
||||
if config.target_modules is not None:
|
||||
if module_name not in config.target_modules:
|
||||
continue
|
||||
|
||||
if module_name not in module_to_loras:
|
||||
module_to_loras[module_name] = []
|
||||
|
||||
module_to_loras[module_name].append((weights, config.strength))
|
||||
|
||||
return module_to_loras
|
||||
74
mlx_video/lora/types.py
Normal file
74
mlx_video/lora/types.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Data structures for LoRA support."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAWeights:
|
||||
"""Container for LoRA weight matrices.
|
||||
|
||||
Attributes:
|
||||
lora_A: Low-rank matrix A of shape [rank, in_features]
|
||||
lora_B: Low-rank matrix B of shape [out_features, rank]
|
||||
rank: Rank of the LoRA decomposition
|
||||
alpha: LoRA scaling parameter (default: rank)
|
||||
module_name: Target module name in the model
|
||||
"""
|
||||
|
||||
lora_A: mx.array
|
||||
lora_B: mx.array
|
||||
rank: int
|
||||
alpha: float
|
||||
module_name: str
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""Compute the scale factor: alpha / rank."""
|
||||
return self.alpha / self.rank
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
"""Configuration for a single LoRA.
|
||||
|
||||
Attributes:
|
||||
path: Path to the LoRA safetensors file
|
||||
strength: Strength/weight to apply this LoRA (typically 0.0-2.0)
|
||||
target_modules: Optional list of module names to apply LoRA to.
|
||||
If None, applies to all available modules in the LoRA.
|
||||
"""
|
||||
|
||||
path: Path
|
||||
strength: float = 1.0
|
||||
target_modules: Optional[list[str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate and normalize the configuration."""
|
||||
self.path = Path(self.path)
|
||||
if not self.path.exists():
|
||||
raise FileNotFoundError(f"LoRA file not found: {self.path}")
|
||||
if self.strength < 0:
|
||||
raise ValueError(f"LoRA strength must be non-negative, got {self.strength}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppliedLoRA:
|
||||
"""Represents a LoRA applied to a specific module.
|
||||
|
||||
Attributes:
|
||||
weights: The LoRA weight matrices
|
||||
strength: Application strength for this LoRA
|
||||
"""
|
||||
|
||||
weights: LoRAWeights
|
||||
strength: float
|
||||
|
||||
def compute_delta(self) -> mx.array:
|
||||
"""Compute the weight delta: strength * scale * (lora_B @ lora_A)."""
|
||||
scale = self.weights.scale
|
||||
delta = self.weights.lora_B @ self.weights.lora_A
|
||||
return scale * self.strength * delta
|
||||
@@ -1,2 +1,2 @@
|
||||
|
||||
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.wan_2 import WanModel, WanModelConfig
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
from mlx_video.models.ltx_2.ltx_2 import LTXModel, X0Model
|
||||
|
||||
@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
|
||||
self.linear = nn.Linear(
|
||||
embedding_dim, embedding_coefficient * embedding_dim, bias=True
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -63,8 +64,12 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
self.size_emb_dim = size_emb_dim
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
|
||||
self.time_proj = Timesteps(
|
||||
timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
|
||||
)
|
||||
|
||||
if use_additional_conditions and size_emb_dim > 0:
|
||||
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
||||
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
# Add additional conditions if enabled
|
||||
if self.use_additional_conditions and self.size_emb_dim > 0:
|
||||
if resolution is not None and aspect_ratio is not None:
|
||||
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
|
||||
additional_embeds = self.additional_embedder(
|
||||
resolution, aspect_ratio, hidden_dtype
|
||||
)
|
||||
timesteps_emb = timesteps_emb + additional_embeds
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Audio VAE module for LTX-2 audio generation."""
|
||||
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_processor import ensure_stereo, load_audio, waveform_to_mel
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
|
||||
@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -103,6 +105,8 @@ def make_attn(
|
||||
elif attn_type == AttentionType.NONE:
|
||||
return Identity()
|
||||
elif attn_type == AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
raise NotImplementedError(
|
||||
f"Attention type {attn_type.value} is not supported yet."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
|
||||
@@ -4,10 +4,9 @@ Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrog
|
||||
using librosa for macOS/MLX compatibility.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_audio(
|
||||
@@ -99,14 +98,16 @@ def waveform_to_mel(
|
||||
|
||||
for ch in range(channels):
|
||||
# Magnitude spectrogram (power=1.0)
|
||||
S = np.abs(librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
))
|
||||
S = np.abs(
|
||||
librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
)
|
||||
)
|
||||
|
||||
# Mel filterbank with slaney normalization
|
||||
mel_basis = librosa.filters.mel(
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
@@ -39,7 +39,9 @@ def build_mid_block(
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
mid["attn_1"] = (
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type)
|
||||
if add_attention
|
||||
else None
|
||||
)
|
||||
mid["block_2"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
@@ -93,7 +95,10 @@ class AudioEncoder(nn.Module):
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
||||
config.in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -125,7 +130,10 @@ class AudioEncoder(nn.Module):
|
||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||
self.conv_out = make_conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1,
|
||||
block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
@@ -160,7 +168,11 @@ class AudioEncoder(nn.Module):
|
||||
continue
|
||||
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
@@ -168,11 +180,14 @@ class AudioEncoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
|
||||
model_path = Path(model_path)
|
||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
config = AudioEncoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
encoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
encoder.load_weights(list(weights.items()), strict=True)
|
||||
@@ -265,7 +280,6 @@ class AudioDecoder(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
@@ -305,7 +319,11 @@ class AudioDecoder(nn.Module):
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
config.z_channels,
|
||||
base_block_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
@@ -334,9 +352,15 @@ class AudioDecoder(nn.Module):
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.norm_out = build_normalization_layer(
|
||||
final_block_channels, normtype=self.norm_type
|
||||
)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
final_block_channels,
|
||||
config.out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -371,7 +395,11 @@ class AudioDecoder(nn.Module):
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
value = (
|
||||
value
|
||||
if check_array_shape(value)
|
||||
else mx.transpose(value, (0, 2, 3, 1))
|
||||
)
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
@@ -380,17 +408,19 @@ class AudioDecoder(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(
|
||||
json.load(open(model_path / "config.json"))
|
||||
)
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
@@ -414,7 +444,9 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
return self._adjust_output_shape(h, target_shape)
|
||||
|
||||
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
||||
def _denormalize_latents(
|
||||
self, sample: mx.array
|
||||
) -> tuple[mx.array, AudioLatentShape]:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
# sample shape: (B, H, W, C) in MLX format
|
||||
latent_shape = AudioLatentShape(
|
||||
@@ -436,7 +468,9 @@ class AudioDecoder(nn.Module):
|
||||
batch=latent_shape.batch,
|
||||
channels=self.out_ch,
|
||||
frames=target_frames,
|
||||
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
||||
mel_bins=(
|
||||
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
|
||||
),
|
||||
)
|
||||
|
||||
return sample, target_shape
|
||||
@@ -462,7 +496,10 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
||||
:,
|
||||
: min(current_time, target_time),
|
||||
: min(current_freq, target_freq),
|
||||
:target_channels,
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
@@ -514,7 +551,9 @@ class AudioDecoder(nn.Module):
|
||||
return mx.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
||||
def decode_audio(
|
||||
latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
|
||||
) -> mx.array:
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
|
||||
@@ -53,8 +53,16 @@ class CausalConv2d(nn.Module):
|
||||
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
# Non-causal: symmetric padding
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
|
||||
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
|
||||
self.padding = (
|
||||
pad_h // 2,
|
||||
pad_h - pad_h // 2,
|
||||
pad_w // 2,
|
||||
pad_w - pad_w // 2,
|
||||
)
|
||||
elif self.causality_axis in (
|
||||
CausalityAxis.WIDTH,
|
||||
CausalityAxis.WIDTH_COMPATIBILITY,
|
||||
):
|
||||
# Causal on width: pad left (before width axis)
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
@@ -90,7 +98,10 @@ class CausalConv2d(nn.Module):
|
||||
if any(p > 0 for p in self.padding):
|
||||
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
||||
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
||||
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
|
||||
x = mx.pad(
|
||||
x,
|
||||
[(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)],
|
||||
)
|
||||
|
||||
return self.conv(x)
|
||||
|
||||
@@ -124,7 +135,14 @@ def make_conv2d(
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(
|
||||
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
causality_axis,
|
||||
)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in MLX conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -116,10 +118,14 @@ def build_downsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if i_level != num_resolutions - 1:
|
||||
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["downsample"] = Downsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
down_modules[i_level] = stage
|
||||
|
||||
@@ -51,7 +51,9 @@ def build_normalization_layer(
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
||||
return nn.GroupNorm(
|
||||
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
if normtype == NormType.PIXEL:
|
||||
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""ResNet blocks for audio VAE and vocoder."""
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
self.conv1 = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if temb_channels > 0:
|
||||
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
|
||||
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||
self.dropout_rate = dropout
|
||||
self.conv2 = make_conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
|
||||
if temb is not None and self.temb_channels > 0:
|
||||
# temb: (B, temb_channels) -> (B, out_channels)
|
||||
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
||||
h = h + mx.expand_dims(
|
||||
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
|
||||
)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
@@ -124,10 +128,14 @@ def build_upsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if level != 0:
|
||||
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["upsample"] = Upsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res *= 2
|
||||
|
||||
up_modules[level] = stage
|
||||
|
||||
@@ -7,8 +7,8 @@ Supports:
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -32,7 +32,9 @@ class Snake(nn.Module):
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.alpha = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
@@ -48,8 +50,12 @@ class SnakeBeta(nn.Module):
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.alpha = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
self.beta = (
|
||||
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
alpha = self.alpha
|
||||
@@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array:
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
|
||||
def kaiser_sinc_filter1d(
|
||||
cutoff: float, half_width: float, kernel_size: int
|
||||
) -> mx.array:
|
||||
"""Compute a Kaiser-windowed sinc filter."""
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
@@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
||||
|
||||
# Kaiser window - compute using scipy-compatible formula
|
||||
import numpy as np
|
||||
|
||||
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
|
||||
|
||||
if even:
|
||||
@@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
|
||||
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
|
||||
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
|
||||
import numpy as np
|
||||
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
@@ -187,10 +197,16 @@ class UpSample1d(nn.Module):
|
||||
self.kernel_size = filt.shape[2]
|
||||
self.filter = filt
|
||||
else:
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
self.pad_left = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
)
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
self.filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
@@ -215,10 +231,12 @@ class UpSample1d(nn.Module):
|
||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||
|
||||
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
|
||||
x = self.ratio * mx.conv_transpose1d(
|
||||
x, filt, stride=self.stride
|
||||
) # (N*C, L', 1)
|
||||
|
||||
# Trim padding
|
||||
x = x[:, self.pad_left:-self.pad_right, :]
|
||||
x = x[:, self.pad_left : -self.pad_right, :]
|
||||
|
||||
x = x.reshape(n, c, -1) # (N, C, L')
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||
@@ -285,16 +303,24 @@ class AMPBlock1(nn.Module):
|
||||
|
||||
self.convs1 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=d, padding=get_padding(kernel_size, d),
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=get_padding(kernel_size, d),
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
self.convs2 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=1, padding=get_padding(kernel_size, 1),
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
for i in range(len(dilation))
|
||||
}
|
||||
@@ -348,7 +374,9 @@ class STFTFn(nn.Module):
|
||||
y = mx.concatenate([first, y], axis=1)
|
||||
|
||||
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
|
||||
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
|
||||
basis = mx.transpose(
|
||||
self.forward_basis.astype(y.dtype), (0, 2, 1)
|
||||
) # (514, K, 1)
|
||||
|
||||
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
|
||||
spec = mx.conv1d(y, basis, stride=self.hop_length)
|
||||
@@ -358,8 +386,10 @@ class STFTFn(nn.Module):
|
||||
real = spec[..., :n_freqs]
|
||||
imag = spec[..., n_freqs:]
|
||||
|
||||
magnitude = mx.sqrt(real ** 2 + imag ** 2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
|
||||
magnitude = mx.sqrt(real**2 + imag**2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
|
||||
real.dtype
|
||||
)
|
||||
|
||||
# Output: (B, T_frames, n_freqs) in MLX channels-last
|
||||
return magnitude, phase
|
||||
@@ -368,7 +398,9 @@ class STFTFn(nn.Module):
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram from precomputed STFT bases."""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
|
||||
def __init__(
|
||||
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
|
||||
n_freqs = filter_length // 2 + 1
|
||||
@@ -385,7 +417,9 @@ class MelSTFT(nn.Module):
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
# magnitude: (B, T_frames, n_freqs)
|
||||
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
|
||||
mel = (
|
||||
magnitude @ self.mel_basis.astype(magnitude.dtype).T
|
||||
) # (B, T_frames, n_mels)
|
||||
log_mel = mx.log(mx.clip(mel, 1e-5, None))
|
||||
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
|
||||
return mx.transpose(log_mel, (0, 2, 1))
|
||||
@@ -415,8 +449,11 @@ class Vocoder(nn.Module):
|
||||
|
||||
in_channels = 128 if config.stereo else 64
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels, config.upsample_initial_channel,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
in_channels,
|
||||
config.upsample_initial_channel,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
)
|
||||
|
||||
# Upsampling layers
|
||||
@@ -424,11 +461,13 @@ class Vocoder(nn.Module):
|
||||
for i, (stride, kernel_size) in enumerate(
|
||||
zip(config.upsample_rates, config.upsample_kernel_sizes)
|
||||
):
|
||||
in_ch = config.upsample_initial_channel // (2 ** i)
|
||||
in_ch = config.upsample_initial_channel // (2**i)
|
||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups[i] = nn.ConvTranspose1d(
|
||||
in_ch, out_ch,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
|
||||
@@ -442,7 +481,9 @@ class Vocoder(nn.Module):
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = AMPBlock1(
|
||||
ch, kernel_size, tuple(dilations),
|
||||
ch,
|
||||
kernel_size,
|
||||
tuple(dilations),
|
||||
activation=config.activation,
|
||||
)
|
||||
block_idx += 1
|
||||
@@ -455,10 +496,14 @@ class Vocoder(nn.Module):
|
||||
for kernel_size, dilations in zip(
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||
self.resblocks[block_idx] = resblock_class(
|
||||
ch, kernel_size, tuple(dilations)
|
||||
)
|
||||
block_idx += 1
|
||||
|
||||
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
|
||||
final_channels = config.upsample_initial_channel // (
|
||||
2 ** len(config.upsample_rates)
|
||||
)
|
||||
|
||||
# Post-activation
|
||||
if self.is_amp:
|
||||
@@ -468,8 +513,11 @@ class Vocoder(nn.Module):
|
||||
# Final conv
|
||||
out_channels = 2 if config.stereo else 1
|
||||
self.conv_post = nn.Conv1d(
|
||||
final_channels, out_channels,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
final_channels,
|
||||
out_channels,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=3,
|
||||
bias=config.use_bias_at_final,
|
||||
)
|
||||
|
||||
@@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module):
|
||||
"""
|
||||
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
|
||||
_, _, length_low_rate = x.shape
|
||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
output_length = (
|
||||
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
)
|
||||
|
||||
# Pad to hop_length multiple
|
||||
remainder = length_low_rate % self.hop_length
|
||||
@@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
"""Conditioning modules for LTX-2 video generation."""
|
||||
|
||||
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||
from mlx_video.models.ltx_2.conditioning.latent import (
|
||||
VideoConditionByLatentIndex,
|
||||
apply_conditioning,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
|
||||
frame_idx: Frame index to condition (0 = first frame)
|
||||
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
|
||||
"""
|
||||
|
||||
latent: mx.array
|
||||
frame_idx: int = 0
|
||||
strength: float = 1.0
|
||||
@@ -41,6 +42,7 @@ class LatentState:
|
||||
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
|
||||
1.0 = full denoise, 0.0 = keep clean
|
||||
"""
|
||||
|
||||
latent: mx.array
|
||||
clean_latent: mx.array
|
||||
denoise_mask: mx.array
|
||||
@@ -130,15 +132,15 @@ def apply_conditioning(
|
||||
if frame_idx <= i < end_idx:
|
||||
# Use conditioning latent
|
||||
cond_idx = i - frame_idx
|
||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
|
||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||
else:
|
||||
# Keep original
|
||||
latent_list.append(state.latent[:, :, i:i+1])
|
||||
clean_list.append(state.clean_latent[:, :, i:i+1])
|
||||
mask_list.append(state.denoise_mask[:, :, i:i+1])
|
||||
latent_list.append(state.latent[:, :, i : i + 1])
|
||||
clean_list.append(state.clean_latent[:, :, i : i + 1])
|
||||
mask_list.append(state.denoise_mask[:, :, i : i + 1])
|
||||
|
||||
state.latent = mx.concatenate(latent_list, axis=2)
|
||||
state.clean_latent = mx.concatenate(clean_list, axis=2)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
@@ -22,9 +21,11 @@ class LTXRopeType(Enum):
|
||||
SPLIT = "split"
|
||||
TWO_D = "2d"
|
||||
|
||||
|
||||
class AttentionType(Enum):
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
|
||||
@@ -46,7 +47,7 @@ class BaseModelConfig:
|
||||
if v is not None:
|
||||
if isinstance(v, Enum):
|
||||
result[k] = v.value
|
||||
elif hasattr(v, 'to_dict'):
|
||||
elif hasattr(v, "to_dict"):
|
||||
result[k] = v.to_dict()
|
||||
else:
|
||||
result[k] = v
|
||||
@@ -68,26 +69,30 @@ class VideoVAEConfig(BaseModelConfig):
|
||||
out_channels: int = 128
|
||||
latent_channels: int = 128
|
||||
patch_size: int = 4
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
])
|
||||
decoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
])
|
||||
encoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
)
|
||||
decoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -111,7 +116,9 @@ class LTXModelConfig(BaseModelConfig):
|
||||
audio_in_channels: int = 128
|
||||
audio_out_channels: int = 128
|
||||
audio_cross_attention_dim: int = 2048
|
||||
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
|
||||
audio_caption_channels: int = (
|
||||
3840 # Input dim for audio text embeddings (same as video)
|
||||
)
|
||||
|
||||
# Positional embedding config
|
||||
positional_embedding_theta: float = 10000.0
|
||||
@@ -196,7 +203,6 @@ class LTXModelConfig(BaseModelConfig):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
@@ -237,8 +243,8 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
# Import here to avoid circular imports
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
from .audio_vae.normalization import NormType
|
||||
|
||||
# Convert causality_axis string to enum
|
||||
if isinstance(self.causality_axis, str):
|
||||
@@ -252,6 +258,7 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioEncoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
@@ -282,8 +289,8 @@ class AudioEncoderModelConfig(BaseModelConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
from .audio_vae.normalization import NormType
|
||||
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
@@ -334,6 +341,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
|
||||
dropout: float = 0.0
|
||||
timestep_conditioning: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderModelConfig(BaseModelConfig):
|
||||
convolution_dimensions: int = 3
|
||||
@@ -343,21 +351,24 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
norm_layer: Enum = None
|
||||
latent_log_var: Enum = None
|
||||
encoder_spatial_padding_mode: Enum = None
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2})
|
||||
])
|
||||
encoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
|
||||
if self.norm_layer is None:
|
||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||
@@ -371,7 +382,9 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
if isinstance(self.latent_log_var, str):
|
||||
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
||||
if isinstance(self.encoder_spatial_padding_mode, str):
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(
|
||||
self.encoder_spatial_padding_mode
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
|
||||
@@ -49,7 +49,6 @@ from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ─── Key prefix routing ──────────────────────────────────────────────────────
|
||||
|
||||
TRANSFORMER_PREFIX = "model.diffusion_model."
|
||||
@@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
new_key = key[len(TRANSFORMER_PREFIX):]
|
||||
new_key = key[len(TRANSFORMER_PREFIX) :]
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
@@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
else:
|
||||
continue
|
||||
elif key.startswith(VAE_DECODER_PREFIX):
|
||||
new_key = key[len(VAE_DECODER_PREFIX):]
|
||||
new_key = key[len(VAE_DECODER_PREFIX) :]
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if value.dtype != mx.float32:
|
||||
value = value.astype(mx.float32)
|
||||
elif key.startswith(VAE_ENCODER_PREFIX):
|
||||
new_key = key[len(VAE_ENCODER_PREFIX):]
|
||||
new_key = key[len(VAE_ENCODER_PREFIX) :]
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_DECODER_PREFIX):
|
||||
new_key = key[len(AUDIO_DECODER_PREFIX):]
|
||||
new_key = key[len(AUDIO_DECODER_PREFIX) :]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
@@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_ENCODER_PREFIX):
|
||||
new_key = key[len(AUDIO_ENCODER_PREFIX):]
|
||||
new_key = key[len(AUDIO_ENCODER_PREFIX) :]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
@@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if not key.startswith(VOCODER_PREFIX):
|
||||
continue
|
||||
|
||||
new_key = key[len(VOCODER_PREFIX):]
|
||||
new_key = key[len(VOCODER_PREFIX) :]
|
||||
|
||||
# Handle Conv1d/ConvTranspose1d weight shape conversion
|
||||
if "weight" in new_key and value.ndim == 3:
|
||||
@@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array
|
||||
# aggregate_embed weights (text_embedding_projection.*)
|
||||
for key, value in weights.items():
|
||||
if key.startswith(TEXT_PROJ_PREFIX):
|
||||
new_key = key[len(TEXT_PROJ_PREFIX):]
|
||||
new_key = key[len(TEXT_PROJ_PREFIX) :]
|
||||
extracted[new_key] = value
|
||||
|
||||
# video_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(VIDEO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX):]
|
||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
|
||||
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
# audio_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(AUDIO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX):]
|
||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
|
||||
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
@@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path):
|
||||
# ─── Source resolution ─────────────────────────────────────────────────────────
|
||||
|
||||
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
|
||||
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
|
||||
MONOLITHIC_PATTERN = re.compile(
|
||||
r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
|
||||
)
|
||||
|
||||
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
|
||||
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
|
||||
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$")
|
||||
UPSCALER_PATTERN = re.compile(
|
||||
r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
|
||||
)
|
||||
|
||||
|
||||
def resolve_source(source: str, variant: str) -> Path:
|
||||
@@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
|
||||
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
|
||||
"""Infer VAE decoder config from weights."""
|
||||
# Check for timestep conditioning keys
|
||||
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
|
||||
has_timestep = any(
|
||||
"last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
|
||||
)
|
||||
|
||||
# Count channel multipliers from up_blocks
|
||||
max_block = -1
|
||||
@@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
config = infer_transformer_config(transformer_weights)
|
||||
save_config(config, output_path / "transformer")
|
||||
t_params = sum(v.size for v in transformer_weights.values())
|
||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
||||
print(
|
||||
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
|
||||
)
|
||||
|
||||
# 2. VAE Decoder
|
||||
print(" [2/7] VAE Decoder...")
|
||||
@@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
]
|
||||
else:
|
||||
upscaler_files = [
|
||||
f.name for f in source_dir.iterdir()
|
||||
f.name
|
||||
for f in source_dir.iterdir()
|
||||
if f.is_file() and UPSCALER_PATTERN.match(f.name)
|
||||
]
|
||||
|
||||
@@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
||||
if all_converted < total_keys:
|
||||
known_prefixes = (
|
||||
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX,
|
||||
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX,
|
||||
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX,
|
||||
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX,
|
||||
TRANSFORMER_PREFIX,
|
||||
VAE_DECODER_PREFIX,
|
||||
VAE_ENCODER_PREFIX,
|
||||
VAE_STATS_PREFIX,
|
||||
AUDIO_DECODER_PREFIX,
|
||||
AUDIO_ENCODER_PREFIX,
|
||||
AUDIO_STATS_PREFIX,
|
||||
VOCODER_PREFIX,
|
||||
TEXT_PROJ_PREFIX,
|
||||
VIDEO_CONNECTOR_PREFIX,
|
||||
AUDIO_CONNECTOR_PREFIX,
|
||||
)
|
||||
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)]
|
||||
skipped = [
|
||||
k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
|
||||
]
|
||||
if skipped:
|
||||
print(f" Skipped {len(skipped)} keys:")
|
||||
for k in sorted(skipped)[:20]:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +1,14 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
LTXRopeType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx_2.transformer import (
|
||||
@@ -58,11 +57,17 @@ class TransformerArgsPreprocessor:
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
timestep_emb, embedded_timestep = self.adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
|
||||
# Reshape to (batch, tokens, dim)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
timestep_emb = mx.reshape(
|
||||
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||
)
|
||||
embedded_timestep = mx.reshape(
|
||||
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||
)
|
||||
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
@@ -74,9 +79,15 @@ class TransformerArgsPreprocessor:
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
timestep_emb, embedded_timestep = adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
timestep_emb = mx.reshape(
|
||||
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
|
||||
)
|
||||
embedded_timestep = mx.reshape(
|
||||
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
|
||||
)
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_context(
|
||||
@@ -107,7 +118,9 @@ class TransformerArgsPreprocessor:
|
||||
# Convert boolean/int mask to float mask
|
||||
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
||||
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
||||
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
mask = mx.reshape(
|
||||
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
)
|
||||
return mask
|
||||
|
||||
def _prepare_positional_embeddings(
|
||||
@@ -132,9 +145,15 @@ class TransformerArgsPreprocessor:
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
timestep, embedded_timestep = self._prepare_timestep(
|
||||
modality.timesteps, x.shape[0], hidden_dtype=x.dtype
|
||||
)
|
||||
context, attention_mask = self._prepare_context(
|
||||
modality.context, x, modality.context_mask
|
||||
)
|
||||
attention_mask = self._prepare_attention_mask(
|
||||
attention_mask, modality.latent.dtype
|
||||
)
|
||||
|
||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||
if modality.positional_embeddings is not None:
|
||||
@@ -152,8 +171,13 @@ class TransformerArgsPreprocessor:
|
||||
prompt_timestep = None
|
||||
prompt_embedded_timestep = None
|
||||
if self.prompt_adaln is not None and modality.sigma is not None:
|
||||
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
|
||||
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
|
||||
prompt_timestep, prompt_embedded_timestep = (
|
||||
self._prepare_timestep_with_adaln(
|
||||
self.prompt_adaln,
|
||||
modality.sigma,
|
||||
x.shape[0],
|
||||
hidden_dtype=x.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
return TransformerArgs(
|
||||
@@ -229,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
)
|
||||
|
||||
# Prepare cross-attention timestep embeddings
|
||||
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
cross_scale_shift_timestep, cross_gate_timestep = (
|
||||
self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
)
|
||||
)
|
||||
|
||||
return replace(
|
||||
@@ -254,11 +280,19 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
|
||||
timestep.reshape(-1), hidden_dtype=hidden_dtype
|
||||
)
|
||||
scale_shift_timestep = mx.reshape(
|
||||
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
|
||||
)
|
||||
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||
gate_timestep, _ = self.cross_gate_adaln(
|
||||
timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
|
||||
)
|
||||
gate_timestep = mx.reshape(
|
||||
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
|
||||
)
|
||||
|
||||
return scale_shift_timestep, gate_timestep
|
||||
|
||||
@@ -285,18 +319,25 @@ class LTXModel(nn.Module):
|
||||
self._init_video(config)
|
||||
|
||||
if config.model_type.is_audio_enabled():
|
||||
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
|
||||
self.audio_positional_embedding_max_pos = (
|
||||
config.audio_positional_embedding_max_pos
|
||||
)
|
||||
self.audio_num_attention_heads = config.audio_num_attention_heads
|
||||
self.audio_inner_dim = config.audio_inner_dim
|
||||
self._init_audio(config)
|
||||
|
||||
# Initialize cross-modal components
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
if (
|
||||
config.model_type.is_video_enabled()
|
||||
and config.model_type.is_audio_enabled()
|
||||
):
|
||||
cross_pe_max_pos = max(
|
||||
config.positional_embedding_max_pos[0],
|
||||
config.audio_positional_embedding_max_pos[0],
|
||||
)
|
||||
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
|
||||
self.av_ca_timestep_scale_multiplier = (
|
||||
config.av_ca_timestep_scale_multiplier
|
||||
)
|
||||
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
||||
self._init_audio_video(config)
|
||||
|
||||
@@ -308,10 +349,14 @@ class LTXModel(nn.Module):
|
||||
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||
|
||||
adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=adaln_coefficient
|
||||
)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=2
|
||||
)
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
@@ -323,13 +368,19 @@ class LTXModel(nn.Module):
|
||||
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
||||
|
||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||
self.audio_patchify_proj = nn.Linear(
|
||||
config.audio_in_channels, self.audio_inner_dim, bias=True
|
||||
)
|
||||
|
||||
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
|
||||
)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim, embedding_coefficient=2
|
||||
)
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.audio_caption_channels,
|
||||
@@ -338,7 +389,9 @@ class LTXModel(nn.Module):
|
||||
|
||||
# Output components
|
||||
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
|
||||
self.audio_norm_out = nn.LayerNorm(
|
||||
self.audio_inner_dim, eps=config.norm_eps, affine=False
|
||||
)
|
||||
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
||||
|
||||
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
||||
@@ -361,8 +414,13 @@ class LTXModel(nn.Module):
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
|
||||
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
def _init_preprocessors(
|
||||
self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
|
||||
) -> None:
|
||||
if (
|
||||
config.model_type.is_video_enabled()
|
||||
and config.model_type.is_audio_enabled()
|
||||
):
|
||||
# Multi-modal preprocessors
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
@@ -468,7 +526,8 @@ class LTXModel(nn.Module):
|
||||
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
|
||||
for idx, block in self.transformer_blocks.items():
|
||||
video, audio = block(
|
||||
video=video, audio=audio,
|
||||
video=video,
|
||||
audio=audio,
|
||||
skip_video_self_attn=(idx in stg_v_set),
|
||||
skip_audio_self_attn=(idx in stg_a_set),
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
@@ -526,8 +585,12 @@ class LTXModel(nn.Module):
|
||||
raise ValueError("Audio is not enabled for this model")
|
||||
|
||||
# Preprocess arguments
|
||||
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
video_args = (
|
||||
self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
)
|
||||
audio_args = (
|
||||
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
)
|
||||
|
||||
# Process transformer blocks
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
@@ -577,7 +640,10 @@ class LTXModel(nn.Module):
|
||||
|
||||
if not key.startswith("model.diffusion_model."):
|
||||
continue
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
if (
|
||||
"audio_embeddings_connector" in key
|
||||
or "video_embeddings_connector" in key
|
||||
):
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
@@ -612,9 +678,11 @@ class LTXModel(nn.Module):
|
||||
for weight_file in model_path.glob("*.safetensors"):
|
||||
weights.update(mx.load(str(weight_file)))
|
||||
|
||||
|
||||
sanitized = model.sanitize(weights)
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
sanitized = {
|
||||
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
|
||||
for k, v in sanitized.items()
|
||||
}
|
||||
|
||||
model.load_weights(list(sanitized.items()), strict=strict)
|
||||
mx.eval(model.parameters())
|
||||
@@ -639,13 +707,18 @@ class X0Model(nn.Module):
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
vx, ax = self.velocity_model(
|
||||
video, audio,
|
||||
video,
|
||||
audio,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
denoised_video = (
|
||||
to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
)
|
||||
denoised_audio = (
|
||||
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
)
|
||||
|
||||
return denoised_video, denoised_audio
|
||||
@@ -1,9 +1,10 @@
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
||||
def bilateral_filter(
|
||||
image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75
|
||||
) -> np.ndarray:
|
||||
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
||||
|
||||
Args:
|
||||
@@ -17,6 +18,7 @@ def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sig
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
|
||||
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
|
||||
except ImportError:
|
||||
# Fallback to simple Gaussian blur if cv2 not available
|
||||
@@ -35,14 +37,20 @@ def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
|
||||
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
||||
except ImportError:
|
||||
# Simple box blur fallback
|
||||
from scipy.ndimage import uniform_filter
|
||||
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
|
||||
|
||||
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(
|
||||
np.uint8
|
||||
)
|
||||
|
||||
|
||||
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
|
||||
def unsharp_mask(
|
||||
image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Apply unsharp masking to enhance edges after blur.
|
||||
|
||||
Args:
|
||||
@@ -56,6 +64,7 @@ def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, am
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
|
||||
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
|
||||
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
|
||||
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
||||
@@ -81,23 +90,23 @@ def reduce_grid_artifacts(
|
||||
if method == "bilateral":
|
||||
d = max(3, int(5 * strength))
|
||||
sigma = 50 + 50 * strength
|
||||
processed = np.stack([
|
||||
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
||||
for frame in video
|
||||
])
|
||||
processed = np.stack(
|
||||
[
|
||||
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
||||
for frame in video
|
||||
]
|
||||
)
|
||||
elif method == "gaussian":
|
||||
kernel_size = max(3, int(3 + 4 * strength))
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
processed = np.stack([
|
||||
gaussian_blur(frame, kernel_size=kernel_size)
|
||||
for frame in video
|
||||
])
|
||||
processed = np.stack(
|
||||
[gaussian_blur(frame, kernel_size=kernel_size) for frame in video]
|
||||
)
|
||||
elif method == "frequency":
|
||||
processed = np.stack([
|
||||
remove_grid_frequency(frame, grid_size=8)
|
||||
for frame in video
|
||||
])
|
||||
processed = np.stack(
|
||||
[remove_grid_frequency(frame, grid_size=8) for frame in video]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
|
||||
@@ -160,6 +169,3 @@ def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
|
||||
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||
"""
|
||||
# x: (..., dim) where dim is even
|
||||
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
||||
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
||||
return mx.reshape(rotated, x.shape)
|
||||
|
||||
|
||||
def apply_rotary_emb_1d(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
@@ -228,9 +228,9 @@ def get_fractional_positions(
|
||||
Fractional positions in range [-1, 1] after scaling
|
||||
"""
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), (
|
||||
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
)
|
||||
assert n_pos_dims == len(
|
||||
max_pos
|
||||
), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
|
||||
# Divide each dimension by its max position
|
||||
fractional_positions = []
|
||||
@@ -392,11 +392,15 @@ def precompute_freqs_cis(
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
num_attention_heads, rope_type
|
||||
indices_grid,
|
||||
dim,
|
||||
theta,
|
||||
max_pos,
|
||||
use_middle_indices_grid,
|
||||
num_attention_heads,
|
||||
rope_type,
|
||||
)
|
||||
|
||||
# Keep positions in float32 for RoPE computation.
|
||||
@@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision(
|
||||
# Compute frequencies: outer product
|
||||
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
||||
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
|
||||
freq_indices, (1, 1, 1, -1)
|
||||
)
|
||||
# freqs: (B, T, n_dims, num_indices)
|
||||
|
||||
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
||||
|
||||
@@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def phi(j: int, neg_h: float) -> float:
|
||||
"""Compute phi_j(z) where z = -h (negative step size in log-space).
|
||||
|
||||
@@ -43,6 +42,7 @@ def get_res2s_coefficients(
|
||||
Returns:
|
||||
(a21, b1, b2): RK coefficients.
|
||||
"""
|
||||
|
||||
def get_phi(j: int, neg_h: float) -> float:
|
||||
cache_key = (j, neg_h)
|
||||
if cache_key in phi_cache:
|
||||
@@ -69,6 +69,7 @@ def get_res2s_coefficients(
|
||||
# SDE noise injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_sde_coeff(
|
||||
sigma_next: float,
|
||||
) -> tuple[float, float, float]:
|
||||
@@ -139,7 +140,9 @@ def sde_noise_step(
|
||||
denoised_next = sample_f32 - sigma * eps_next
|
||||
|
||||
# Mix deterministic and stochastic components
|
||||
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
x_noised = (
|
||||
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
)
|
||||
|
||||
return x_noised
|
||||
|
||||
@@ -148,6 +151,7 @@ def sde_noise_step(
|
||||
# Noise generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def channelwise_normalize(x: mx.array) -> mx.array:
|
||||
"""Normalize each channel to zero mean and unit variance over spatial dims.
|
||||
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
||||
|
||||
from mlx_video.utils import rms_norm, apply_quantization
|
||||
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
|
||||
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from mlx_vlm.models.gemma3.config import TextConfig
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
from mlx_video.utils import apply_quantization, rms_norm
|
||||
|
||||
# Path to system prompts
|
||||
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||
@@ -36,7 +36,6 @@ def _load_system_prompt(prompt_name: str) -> str:
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
|
||||
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
# Create config matching LTX-2 text encoder requirements
|
||||
@@ -59,15 +58,25 @@ class LanguageModel(nn.Module):
|
||||
|
||||
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
||||
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
|
||||
mx.full(combined.shape, min_val, dtype=dtype))
|
||||
min_val = (
|
||||
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
)
|
||||
mask = mx.where(
|
||||
combined,
|
||||
mx.zeros(combined.shape, dtype=dtype),
|
||||
mx.full(combined.shape, min_val, dtype=dtype),
|
||||
)
|
||||
return mask[:, None, :, :]
|
||||
else:
|
||||
# No padding mask, just causal
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||
mx.full((seq_len, seq_len), min_val, dtype=dtype))
|
||||
min_val = (
|
||||
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
)
|
||||
mask = mx.where(
|
||||
causal_mask,
|
||||
mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||
mx.full((seq_len, seq_len), min_val, dtype=dtype),
|
||||
)
|
||||
return mask[None, None, :, :] # (1, 1, seq, seq)
|
||||
|
||||
def __call__(
|
||||
@@ -91,7 +100,11 @@ class LanguageModel(nn.Module):
|
||||
batch_size, seq_len = inputs.shape
|
||||
|
||||
# Get embeddings
|
||||
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
|
||||
h = (
|
||||
input_embeddings
|
||||
if input_embeddings is not None
|
||||
else self.model.embed_tokens(inputs)
|
||||
)
|
||||
|
||||
# Apply Gemma scaling
|
||||
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||
@@ -103,11 +116,12 @@ class LanguageModel(nn.Module):
|
||||
if cache is None:
|
||||
cache = [None] * len(self.model.layers)
|
||||
|
||||
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||
full_causal_mask = self._create_causal_mask_with_padding(
|
||||
seq_len, attention_mask, h.dtype
|
||||
)
|
||||
|
||||
sliding_mask = full_causal_mask
|
||||
|
||||
|
||||
num_layers = len(self.model.layers)
|
||||
for i, layer in enumerate(self.model.layers):
|
||||
is_global = (
|
||||
@@ -147,9 +161,9 @@ class LanguageModel(nn.Module):
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
||||
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
|
||||
sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16)
|
||||
else:
|
||||
sanitized[key[len(prefix):]] = value
|
||||
sanitized[key[len(prefix) :]] = value
|
||||
return sanitized
|
||||
|
||||
@property
|
||||
@@ -158,6 +172,7 @@ class LanguageModel(nn.Module):
|
||||
|
||||
def make_cache(self):
|
||||
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
||||
|
||||
caches = []
|
||||
for i in range(len(self.layers)):
|
||||
if (
|
||||
@@ -172,6 +187,7 @@ class LanguageModel(nn.Module):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str):
|
||||
import json
|
||||
|
||||
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
||||
config_file = Path(model_path) / "config.json"
|
||||
config_dict = {}
|
||||
@@ -179,7 +195,9 @@ class LanguageModel(nn.Module):
|
||||
with open(config_file, "r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
|
||||
language_model = cls(
|
||||
config=TextConfig.from_dict(config_dict["text_config"])
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Config file not found at {model_path}")
|
||||
|
||||
@@ -188,19 +206,18 @@ class LanguageModel(nn.Module):
|
||||
for i, wf in enumerate(weight_files):
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
if hasattr(language_model, "sanitize"):
|
||||
weights = language_model.sanitize(weights=weights)
|
||||
|
||||
|
||||
apply_quantization(model=language_model, weights=weights, quantization=quantization)
|
||||
apply_quantization(
|
||||
model=language_model, weights=weights, quantization=quantization
|
||||
)
|
||||
|
||||
language_model.load_weights(list(weights.items()), strict=False)
|
||||
|
||||
return language_model
|
||||
|
||||
|
||||
|
||||
class ConnectorAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Reshape to (B, H, T, D) for SPLIT RoPE
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
q = mx.reshape(
|
||||
q, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(
|
||||
k, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(
|
||||
v, (batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
if pe is not None:
|
||||
q = self._apply_split_rope(q, pe[0], pe[1])
|
||||
@@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module):
|
||||
|
||||
class ConnectorTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
num_heads: int = 30,
|
||||
head_dim: int = 128,
|
||||
has_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn1 = ConnectorAttention(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
self.attn1 = ConnectorAttention(
|
||||
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||
)
|
||||
self.ff = ConnectorFeedForward(dim)
|
||||
|
||||
def __call__(
|
||||
@@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
|
||||
|
||||
self.transformer_1d_blocks = {
|
||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
|
||||
i: ConnectorTransformerBlock(
|
||||
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||
def _precompute_freqs_cis(
|
||||
self, seq_len: int, dtype: mx.Dtype
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
|
||||
|
||||
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
|
||||
@@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Binary mask: 1 for valid tokens, 0 for padded
|
||||
# attention_mask is additive: 0 for valid, large negative for padded
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(
|
||||
mx.int32
|
||||
) # (batch, seq)
|
||||
|
||||
# Tile registers to match sequence length, cast to hidden_states dtype
|
||||
num_tiles = seq_len // self.num_learnable_registers
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
|
||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(
|
||||
dtype
|
||||
) # (seq_len, dim)
|
||||
|
||||
# Process each batch item (PyTorch uses advanced indexing)
|
||||
result_list = []
|
||||
@@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Extract valid tokens (where mask is 1)
|
||||
# Since we have left-padded input, valid tokens are at the end
|
||||
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
|
||||
valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim)
|
||||
|
||||
# Pad with zeros on the right to get back to seq_len
|
||||
pad_length = seq_len - num_valid
|
||||
if pad_length > 0:
|
||||
padding = mx.zeros((pad_length, dim), dtype=dtype)
|
||||
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
||||
adjusted = mx.concatenate(
|
||||
[valid_tokens, padding], axis=0
|
||||
) # (seq_len, dim)
|
||||
else:
|
||||
adjusted = valid_tokens
|
||||
|
||||
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
|
||||
flipped_mask = mx.concatenate([
|
||||
mx.ones((num_valid,), dtype=mx.int32),
|
||||
mx.zeros((pad_length,), dtype=mx.int32)
|
||||
], axis=0) # (seq,)
|
||||
flipped_mask = mx.concatenate(
|
||||
[
|
||||
mx.ones((num_valid,), dtype=mx.int32),
|
||||
mx.zeros((pad_length,), dtype=mx.int32),
|
||||
],
|
||||
axis=0,
|
||||
) # (seq,)
|
||||
|
||||
# Combine: valid tokens at front, registers at back
|
||||
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
|
||||
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
||||
combined = (
|
||||
flipped_mask_expanded * adjusted
|
||||
+ (1 - flipped_mask_expanded) * registers
|
||||
)
|
||||
result_list.append(combined)
|
||||
|
||||
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||
@@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module):
|
||||
|
||||
# Process through transformer blocks
|
||||
for i in range(len(self.transformer_1d_blocks)):
|
||||
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
|
||||
hidden_states = self.transformer_1d_blocks[i](
|
||||
hidden_states, attention_mask, freqs_cis
|
||||
)
|
||||
|
||||
# Final RMS norm
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
@@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module):
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
|
||||
def norm_and_concat_hidden_states(
|
||||
hidden_states: List[mx.array],
|
||||
attention_mask: mx.array,
|
||||
@@ -567,8 +615,12 @@ def norm_and_concat_hidden_states(
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
|
||||
x_for_min = mx.where(
|
||||
mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype)
|
||||
)
|
||||
x_for_max = mx.where(
|
||||
mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype)
|
||||
)
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
@@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms(
|
||||
dtype = encoded_text.dtype
|
||||
|
||||
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
|
||||
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L)
|
||||
variance = mx.mean(
|
||||
encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True
|
||||
) # (B, T, 1, L)
|
||||
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
|
||||
normed = normed.astype(dtype)
|
||||
|
||||
@@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
|
||||
class GemmaFeaturesExtractor(nn.Module):
|
||||
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
|
||||
|
||||
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
|
||||
def __init__(
|
||||
self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
|
||||
|
||||
@@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module):
|
||||
|
||||
if mode == "video":
|
||||
target_dim = self.video_aggregate_embed.weight.shape[0]
|
||||
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
return self.video_aggregate_embed(
|
||||
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||
)
|
||||
else:
|
||||
target_dim = self.audio_aggregate_embed.weight.shape[0]
|
||||
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
|
||||
|
||||
|
||||
|
||||
return self.audio_aggregate_embed(
|
||||
_rescale_norm(normed, target_dim, self.embedding_dim)
|
||||
)
|
||||
|
||||
|
||||
class AudioEmbeddingsConnector(nn.Module):
|
||||
@@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module):
|
||||
video_output_dim = 4096
|
||||
audio_output_dim = 2048
|
||||
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
|
||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
|
||||
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
|
||||
video_output_dim=video_output_dim,
|
||||
audio_output_dim=audio_output_dim,
|
||||
bias=True,
|
||||
@@ -728,33 +785,53 @@ class LTX2TextEncoder(nn.Module):
|
||||
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
|
||||
# config (nested under config.transformer.connector_positional_embedding_max_pos)
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=video_output_dim, num_heads=32, head_dim=128,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
dim=video_output_dim,
|
||||
num_heads=32,
|
||||
head_dim=128,
|
||||
num_layers=8,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
has_gate_logits=True,
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=audio_output_dim, num_heads=32, head_dim=64,
|
||||
num_layers=8, num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], has_gate_logits=True,
|
||||
dim=audio_output_dim,
|
||||
num_heads=32,
|
||||
head_dim=64,
|
||||
num_layers=8,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096],
|
||||
has_gate_logits=True,
|
||||
)
|
||||
else:
|
||||
# LTX-2: shared feature extractor, 3840-dim connectors
|
||||
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
|
||||
self.feature_extractor = GemmaFeaturesExtractor(
|
||||
feature_input_dim, hidden_dim
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1],
|
||||
)
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim, num_heads=30, head_dim=128,
|
||||
num_layers=2, num_learnable_registers=128,
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[1],
|
||||
)
|
||||
|
||||
self.processor = None
|
||||
|
||||
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||
def load(
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
text_encoder_path: Optional[str] = "google/gemma-3-12b-it",
|
||||
):
|
||||
|
||||
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
|
||||
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
|
||||
@@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
if transformer_weights:
|
||||
self._load_feature_extractors(transformer_weights, is_reformatted)
|
||||
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
|
||||
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
|
||||
self._load_connector(
|
||||
"video_embeddings_connector", transformer_weights, is_reformatted
|
||||
)
|
||||
self._load_connector(
|
||||
"audio_embeddings_connector", transformer_weights, is_reformatted
|
||||
)
|
||||
else:
|
||||
print("WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!")
|
||||
print(
|
||||
"WARNING: No transformer weights found for text projection connectors. "
|
||||
"Text conditioning will use uninitialized weights!"
|
||||
)
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer_path = model_path / "tokenizer"
|
||||
if tokenizer_path.exists():
|
||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
str(tokenizer_path), trust_remote_code=True
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
text_encoder_path, trust_remote_code=True
|
||||
)
|
||||
except Exception:
|
||||
self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True)
|
||||
self.processor = AutoTokenizer.from_pretrained(
|
||||
"google/gemma-3-12b-it", trust_remote_code=True
|
||||
)
|
||||
# Set left padding to match official LTX-2 text encoder
|
||||
self.processor.padding_side = "left"
|
||||
|
||||
@@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module):
|
||||
submodule.bias = weights[b_key]
|
||||
else:
|
||||
# LTX-2: single aggregate_embed
|
||||
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
|
||||
agg_key = (
|
||||
"aggregate_embed.weight"
|
||||
if is_reformatted
|
||||
else "text_embedding_projection.aggregate_embed.weight"
|
||||
)
|
||||
if agg_key in weights:
|
||||
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
|
||||
|
||||
@@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module):
|
||||
prefix = f"{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(prefix):
|
||||
connector_weights[key[len(prefix):]] = value
|
||||
connector_weights[key[len(prefix) :]] = value
|
||||
else:
|
||||
mono_prefix = f"model.diffusion_model.{name}."
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mono_prefix):
|
||||
connector_weights[key[len(mono_prefix):]] = value
|
||||
connector_weights[key[len(mono_prefix) :]] = value
|
||||
|
||||
if not connector_weights:
|
||||
return
|
||||
@@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module):
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||
_, all_hidden_states = self.language_model(
|
||||
inputs=input_ids,
|
||||
input_embeddings=None,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
|
||||
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
|
||||
video_features = self.feature_extractor_v2(
|
||||
all_hidden_states, attention_mask, mode="video"
|
||||
)
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
additive_mask = (
|
||||
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
)
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
video_embeddings, _ = self.video_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
|
||||
audio_features = self.feature_extractor_v2(
|
||||
all_hidden_states, attention_mask, mode="audio"
|
||||
)
|
||||
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
|
||||
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||
audio_features, audio_mask
|
||||
)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
@@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
video_features = self.feature_extractor(concat_hidden)
|
||||
additive_mask = (attention_mask - 1).astype(video_features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
additive_mask = (
|
||||
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
)
|
||||
|
||||
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
|
||||
video_embeddings, _ = self.video_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
|
||||
if return_audio_embeddings:
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(
|
||||
video_features, additive_mask
|
||||
)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
@@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
# Remove leading/trailing whitespace
|
||||
response = response.strip()
|
||||
# Remove any leading punctuation
|
||||
response = re.sub(r'^[^\w\s]+', '', response)
|
||||
response = re.sub(r"^[^\w\s]+", "", response)
|
||||
return response
|
||||
|
||||
def _apply_chat_template(
|
||||
@@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
elif isinstance(content, list):
|
||||
# Handle multimodal content (image + text)
|
||||
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
||||
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||
formatted += (
|
||||
f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||
)
|
||||
elif role == "assistant":
|
||||
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
||||
# Add generation prompt
|
||||
@@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module):
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||
except ImportError:
|
||||
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
|
||||
logging.warning(
|
||||
"mlx-lm not available for prompt enhancement. Using original prompt."
|
||||
)
|
||||
return prompt
|
||||
|
||||
if self.processor is None:
|
||||
@@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module):
|
||||
)
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
|
||||
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1))
|
||||
sampler = make_sampler(
|
||||
kwargs.get("temperature", 0.7),
|
||||
kwargs.get("top_p", 1.0),
|
||||
top_k=kwargs.get("top_k", -1),
|
||||
)
|
||||
logits_processors = make_logits_processors(
|
||||
kwargs.get("logit_bias", None),
|
||||
kwargs.get("repetition_penalty", 1.3),
|
||||
@@ -1079,7 +1202,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
for i, response in enumerate(generator):
|
||||
next_token = mx.array([response.token])
|
||||
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
|
||||
generated_tokens.append(next_token.squeeze())
|
||||
generated_tokens.append(response.token)
|
||||
generated_token_count += 1
|
||||
progress.update(task, advance=1)
|
||||
|
||||
@@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module):
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode only the new tokens
|
||||
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
||||
enhanced_prompt = self.processor.decode(
|
||||
generated_tokens, skip_special_tokens=True
|
||||
)
|
||||
|
||||
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
|
||||
def enhance_i2v(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||
encoder = LTX2TextEncoder()
|
||||
encoder.load(model_path=model_path)
|
||||
return encoder
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Optional, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.attention import Attention
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||
from mlx_video.utils import rms_norm
|
||||
|
||||
@@ -171,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
||||
timestep_reshaped = mx.reshape(
|
||||
timestep,
|
||||
(batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
)
|
||||
|
||||
# Extract the relevant indices
|
||||
@@ -225,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
)
|
||||
|
||||
# Squeeze the sequence dimension if it's 1
|
||||
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
|
||||
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
|
||||
scale_shift_squeezed = tuple(
|
||||
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada
|
||||
)
|
||||
gate_squeezed = tuple(
|
||||
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada
|
||||
)
|
||||
|
||||
return (*scale_shift_squeezed, *gate_squeezed)
|
||||
|
||||
@@ -258,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
# Check which modalities to run
|
||||
run_vx = video is not None and video.enabled and vx.size > 0
|
||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
|
||||
run_a2v = (
|
||||
run_vx
|
||||
and (audio is not None and audio.enabled and ax.size > 0)
|
||||
and not skip_cross_modal
|
||||
)
|
||||
run_v2a = (
|
||||
run_ax
|
||||
and (video is not None and video.enabled and vx.size > 0)
|
||||
and not skip_cross_modal
|
||||
)
|
||||
|
||||
# Process video self-attention and cross-attention with text
|
||||
if run_vx:
|
||||
@@ -269,7 +280,15 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn1(
|
||||
norm_vx,
|
||||
pe=video.positional_embeddings,
|
||||
skip_attention=skip_video_self_attn,
|
||||
)
|
||||
* vgate_msa
|
||||
)
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
@@ -278,11 +297,24 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
|
||||
)
|
||||
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
|
||||
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
|
||||
self.prompt_scale_shift_table,
|
||||
vx.shape[0],
|
||||
video.prompt_timesteps,
|
||||
slice(0, 2),
|
||||
)
|
||||
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
|
||||
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
|
||||
encoder_hidden_states = (
|
||||
video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
)
|
||||
vx = (
|
||||
vx
|
||||
+ self.attn2(
|
||||
attn_input,
|
||||
context=encoder_hidden_states,
|
||||
mask=video.context_mask,
|
||||
)
|
||||
* vgate_q
|
||||
)
|
||||
else:
|
||||
vx = vx + self.attn2(
|
||||
rms_norm(vx, eps=self.norm_eps),
|
||||
@@ -298,20 +330,46 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn1(
|
||||
norm_ax,
|
||||
pe=audio.positional_embeddings,
|
||||
skip_attention=skip_audio_self_attn,
|
||||
)
|
||||
* agate_msa
|
||||
)
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||
ashift_q, ascale_q, agate_q = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
|
||||
self.audio_scale_shift_table,
|
||||
ax.shape[0],
|
||||
audio.timesteps,
|
||||
slice(6, 9),
|
||||
)
|
||||
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
|
||||
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
|
||||
self.audio_prompt_scale_shift_table,
|
||||
ax.shape[0],
|
||||
audio.prompt_timesteps,
|
||||
slice(0, 2),
|
||||
)
|
||||
attn_input_a = (
|
||||
rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
)
|
||||
encoder_hidden_states_a = (
|
||||
audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
)
|
||||
ax = (
|
||||
ax
|
||||
+ self.audio_attn2(
|
||||
attn_input_a,
|
||||
context=encoder_hidden_states_a,
|
||||
mask=audio.context_mask,
|
||||
)
|
||||
* agate_q
|
||||
)
|
||||
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
|
||||
else:
|
||||
ax = ax + self.audio_attn2(
|
||||
rms_norm(ax, eps=self.norm_eps),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
|
||||
self.groups = groups
|
||||
|
||||
# Weight shape: (C_out, KD, KH, KW, C_in)
|
||||
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
scale = (
|
||||
1.0
|
||||
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
|
||||
shape=(
|
||||
out_channels,
|
||||
kernel_size[0],
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
in_channels,
|
||||
),
|
||||
)
|
||||
|
||||
if bias:
|
||||
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
|
||||
n, d, h, w, c = x.shape
|
||||
input_dtype = x.dtype
|
||||
|
||||
|
||||
x = x.astype(mx.float32)
|
||||
|
||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||
@@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module):
|
||||
self.den = den
|
||||
|
||||
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
|
||||
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1)
|
||||
self.conv = nn.Conv2d(
|
||||
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
|
||||
)
|
||||
self.pixel_shuffle = PixelShuffle2D(num, num)
|
||||
self.blur_down = BlurDownsample(stride=den)
|
||||
|
||||
@@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module):
|
||||
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x) # H*num, W*num
|
||||
x = self.blur_down(x) # H*num/den, W*num/den
|
||||
x = self.blur_down(x) # H*num/den, W*num/den
|
||||
|
||||
_, h_out, w_out, _ = x.shape
|
||||
x = mx.reshape(x, (n, d, h_out, w_out, c))
|
||||
@@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module):
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
"""Convert a float scale to a rational fraction (numerator, denominator)."""
|
||||
from fractions import Fraction
|
||||
|
||||
frac = Fraction(scale).limit_denominator(10)
|
||||
return frac.numerator, frac.denominator
|
||||
|
||||
@@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module):
|
||||
self.initial_norm = GroupNorm3d(32, mid_channels)
|
||||
|
||||
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
self.res_blocks = {
|
||||
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||
}
|
||||
|
||||
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
|
||||
self.upsampler = SpatialRationalResampler(
|
||||
mid_channels=mid_channels, scale=spatial_scale
|
||||
)
|
||||
else:
|
||||
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
|
||||
|
||||
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
self.post_upsample_res_blocks = {
|
||||
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||
}
|
||||
|
||||
# Final projection
|
||||
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
@@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module):
|
||||
Returns:
|
||||
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
|
||||
"""
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
print(
|
||||
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
|
||||
)
|
||||
|
||||
if debug:
|
||||
print(" [DEBUG] LatentUpsampler forward pass:")
|
||||
@@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
|
||||
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
|
||||
# Both formats may have upsampler.blur_down.kernel, so use channel count
|
||||
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
|
||||
conv_key = (
|
||||
"upsampler.conv.weight"
|
||||
if "upsampler.conv.weight" in raw_weights
|
||||
else "upsampler.0.weight"
|
||||
)
|
||||
if conv_key in raw_weights:
|
||||
out_channels = raw_weights[conv_key].shape[0]
|
||||
ratio = out_channels // mid_channels
|
||||
@@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||
rational_resampler = False
|
||||
spatial_scale = 2.0
|
||||
|
||||
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
||||
print(
|
||||
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
|
||||
)
|
||||
|
||||
# Create model
|
||||
upsampler = LatentUpsampler(
|
||||
|
||||
@@ -109,6 +109,7 @@ def convert_audio_encoder(
|
||||
return encoder_dir
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
vae_path = hf_hub_download(
|
||||
source_repo,
|
||||
"audio_vae/diffusion_pytorch_model.safetensors",
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
TilingConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||
# Height padding (axis 2)
|
||||
if pad_h > 0:
|
||||
# Get reflection indices - exclude boundary
|
||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||
top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
|
||||
:, :, ::-1, :, :
|
||||
] # Flip bottom portion
|
||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||
|
||||
# Width padding (axis 3)
|
||||
if pad_w > 0:
|
||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||
left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w - 1 : -1, :][
|
||||
:, :, :, ::-1, :
|
||||
] # Flip right portion
|
||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||
|
||||
return x
|
||||
@@ -126,7 +130,9 @@ class CausalConv3d(nn.Module):
|
||||
if self.time_kernel_size > 1:
|
||||
if use_causal:
|
||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||
first_frame_pad = mx.repeat(
|
||||
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
|
||||
)
|
||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||
else:
|
||||
# Non-causal: replicate first frame at start, last frame at end
|
||||
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
|
||||
"""
|
||||
b, d, h, w, c = x.shape
|
||||
|
||||
|
||||
total_elements = d * h * w * c
|
||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||
|
||||
@@ -191,7 +196,6 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
overlap = kernel_t - 1
|
||||
|
||||
|
||||
expected_output_frames = d - overlap
|
||||
|
||||
outputs = []
|
||||
|
||||
@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim
|
||||
in_channels=256, time_embed_dim=embedding_dim
|
||||
)
|
||||
|
||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||
def __call__(
|
||||
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
|
||||
) -> mx.array:
|
||||
timesteps_proj = get_timestep_embedding(
|
||||
timestep,
|
||||
embedding_dim=256,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0
|
||||
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||
return timesteps_emb
|
||||
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
|
||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||
|
||||
class ConvWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
return ConvWrapper()
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||
# Combine table with timestep embedding
|
||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||
ada_values = self.scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 4, C, 1, 1, 1)
|
||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||
channels = self.scale_shift_table.shape[1]
|
||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
|
||||
|
||||
# Time embedder for this block group: embed_dim = 4 * channels
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
|
||||
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
batch_size = x.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Build up blocks from config
|
||||
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
|
||||
block_type = block_def[0]
|
||||
ch = block_def[1]
|
||||
if block_type == "res":
|
||||
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
|
||||
num_layers = (
|
||||
block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
)
|
||||
self.up_blocks[idx] = ResBlockGroup(
|
||||
ch, num_layers, spatial_padding_mode, timestep_conditioning
|
||||
)
|
||||
elif block_type == "d2s":
|
||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
)
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
|
||||
self.conv_out = ConvOutWrapper()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
if (
|
||||
".conv.conv.weight" not in new_key
|
||||
and ".conv.conv.bias" not in new_key
|
||||
):
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
def from_pretrained(
|
||||
cls, model_path: Path, strict: bool = True
|
||||
) -> "LTX2VideoDecoder":
|
||||
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||
|
||||
Args:
|
||||
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
# Infer block structure from weights
|
||||
decoder_blocks = cls._infer_blocks(weights)
|
||||
|
||||
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
return final_blocks
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -552,20 +562,15 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
|
||||
@@ -575,7 +580,6 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||
@@ -584,18 +588,17 @@ class LTX2VideoDecoder(nn.Module):
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
scaled_timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
scaled_timestep.flatten(), hidden_dtype=x.dtype
|
||||
)
|
||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||
ada_values = self.last_scale_shift_table[
|
||||
None, :, :, None, None, None
|
||||
] # (1, 2, 128, 1, 1, 1)
|
||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
@@ -604,16 +607,13 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
def decode_tiled(
|
||||
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
|
||||
|
||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
use_chunked_conv = tiling_mode in (
|
||||
"conservative",
|
||||
"none",
|
||||
"auto",
|
||||
"default",
|
||||
"spatial",
|
||||
)
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
return self(
|
||||
sample,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=debug,
|
||||
chunked_conv=use_chunked_conv,
|
||||
)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
|
||||
@@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
|
||||
def encode_image(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Operations for Video VAE."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
|
||||
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||
|
||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||
x = mx.reshape(
|
||||
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
|
||||
)
|
||||
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
||||
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
||||
|
||||
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
||||
def __call__(
|
||||
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
|
||||
) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
def _chunked_conv_depth_to_space(
|
||||
self, x: mx.array, causal: bool = True
|
||||
) -> mx.array:
|
||||
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||
|
||||
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||
|
||||
@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
|
||||
# Apply right ramp (fade out)
|
||||
if ramp_right > 0:
|
||||
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
||||
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
|
||||
fade_out = [
|
||||
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
|
||||
]
|
||||
for i in range(ramp_right):
|
||||
mask[length - ramp_right + i] *= fade_out[i]
|
||||
|
||||
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_pixels < 64:
|
||||
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_size_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
||||
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_frames < 16:
|
||||
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_size_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
||||
@@ -113,15 +127,21 @@ class TilingConfig:
|
||||
def default(cls) -> "TilingConfig":
|
||||
"""Default tiling: 512px spatial, 64 frame temporal."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
||||
"""Spatial tiling only (for short videos with large resolution)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
|
||||
),
|
||||
temporal_config=None,
|
||||
)
|
||||
|
||||
@@ -130,23 +150,33 @@ class TilingConfig:
|
||||
"""Temporal tiling only (for long videos with small resolution)."""
|
||||
return cls(
|
||||
spatial_config=None,
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def aggressive(cls) -> "TilingConfig":
|
||||
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=256, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=32, tile_overlap_in_frames=8
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def conservative(cls) -> "TilingConfig":
|
||||
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=768, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=96, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -186,10 +216,14 @@ class TilingConfig:
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
||||
spatial_config = SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
)
|
||||
|
||||
if needs_temporal:
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
||||
temporal_config = TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
@@ -197,16 +231,21 @@ class TilingConfig:
|
||||
@dataclass
|
||||
class DimensionIntervals:
|
||||
"""Intervals for splitting a single dimension."""
|
||||
|
||||
starts: List[int]
|
||||
ends: List[int]
|
||||
left_ramps: List[int]
|
||||
right_ramps: List[int]
|
||||
|
||||
|
||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_spatial(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a spatial dimension into intervals."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
||||
starts = [i * (size - overlap) for i in range(amount)]
|
||||
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
|
||||
left_ramps = [0] + [overlap] * (amount - 1)
|
||||
right_ramps = [overlap] * (amount - 1) + [0]
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
|
||||
)
|
||||
|
||||
|
||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_temporal(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a temporal dimension into intervals with causal adjustment."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
# Start with spatial split
|
||||
intervals = split_in_spatial(size, overlap, dimension_size)
|
||||
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
|
||||
starts[i] = starts[i] - 1
|
||||
left_ramps[i] = left_ramps[i] + 1
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts,
|
||||
ends=intervals.ends,
|
||||
left_ramps=left_ramps,
|
||||
right_ramps=intervals.right_ramps,
|
||||
)
|
||||
|
||||
|
||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_temporal_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map temporal latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = 1 + (end - 1) * scale
|
||||
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, True
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_spatial_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map spatial latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = end * scale
|
||||
left_ramp_scaled = left_ramp * scale
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, False
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
@@ -315,7 +373,9 @@ def decode_with_tiling(
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
@@ -338,7 +398,9 @@ def decode_with_tiling(
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
@@ -347,7 +409,9 @@ def decode_with_tiling(
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
@@ -356,13 +420,23 @@ def decode_with_tiling(
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
@@ -385,13 +459,15 @@ def decode_with_tiling(
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
@@ -409,11 +485,37 @@ def decode_with_tiling(
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
@@ -445,10 +547,12 @@ def decode_with_tiling(
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
else:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
@@ -456,7 +560,10 @@ def decode_with_tiling(
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
@@ -473,7 +580,7 @@ def decode_with_tiling(
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
@@ -481,7 +588,7 @@ def decode_with_tiling(
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
|
||||
@@ -8,12 +8,15 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.ops import (
|
||||
PerChannelStatistics,
|
||||
patchify,
|
||||
unpatchify,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
|
||||
|
||||
class LogVarianceType(Enum):
|
||||
"""Log variance mode for VAE."""
|
||||
|
||||
PER_CHANNEL = "per_channel"
|
||||
UNIFORM = "uniform"
|
||||
CONSTANT = "constant"
|
||||
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
|
||||
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
|
||||
self.per_channel_statistics = PerChannelStatistics(
|
||||
latent_channels=config.out_channels
|
||||
)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = config.in_channels * config.patch_size ** 2
|
||||
in_channels = config.in_channels * config.patch_size**2
|
||||
feature_channels = config.out_channels
|
||||
|
||||
# Initial convolution
|
||||
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
block_name=block_name,
|
||||
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
|
||||
conv_out_channels = config.out_channels
|
||||
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
elif config.latent_log_var in {
|
||||
LogVarianceType.UNIFORM,
|
||||
LogVarianceType.CONSTANT,
|
||||
}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
|
||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = -30
|
||||
sample = mx.concatenate([
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
], axis=1)
|
||||
sample = mx.concatenate(
|
||||
[
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Split into means and logvar, normalize means
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
means = sample[:, : self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
|
||||
decoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
out_channels = out_channels * patch_size ** 2
|
||||
out_channels = out_channels * patch_size**2
|
||||
self.causal = causal
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
block_config = (
|
||||
{"num_layers": block_params}
|
||||
if isinstance(block_params, int)
|
||||
else block_params
|
||||
)
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
block_name=block_name,
|
||||
|
||||
349
mlx_video/models/wan_2/README.md
Normal file
349
mlx_video/models/wan_2/README.md
Normal file
@@ -0,0 +1,349 @@
|
||||
|
||||
## Wan2.1 / Wan2.2
|
||||
|
||||
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
|
||||
|
||||
They share the same model architecture — the difference is in the inference pipeline:
|
||||
|
||||
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|
||||
|---|--------|--------|--------|--------|
|
||||
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
|
||||
| **Pipeline** | Single model | Dual model | Dual model | Single model |
|
||||
| **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
|
||||
| **Resolution** | 480P (1.3B), 720P (14B) | 720P | 720P | 720P |
|
||||
| **Steps** | 50 | 40 | 40 | 40 |
|
||||
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
|
||||
| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
|
||||
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
|
||||
|
||||
### Step 1: Download Weights
|
||||
|
||||
Download the original PyTorch checkpoints from HuggingFace using the `huggingface-cli` tool (install with `pip install huggingface_hub`):
|
||||
|
||||
**Wan2.1**
|
||||
```bash
|
||||
# Text-to-Video 1.3B (fast, fits in ~4 GB)
|
||||
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir ./Wan2.1-T2V-1.3B
|
||||
|
||||
# Text-to-Video 14B
|
||||
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
|
||||
```
|
||||
|
||||
**Wan2.2**
|
||||
```bash
|
||||
# Text-to-Video 14B
|
||||
huggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B
|
||||
|
||||
# Image-to-Video 14B
|
||||
huggingface-cli download Wan-AI/Wan2.2-I2V-A14B --local-dir ./Wan2.2-I2V-A14B
|
||||
|
||||
# Text+Image-to-Video 5B (uses a different VAE — z_dim=48)
|
||||
huggingface-cli download Wan-AI/Wan2.2-TI2V-5B --local-dir ./Wan2.2-TI2V-5B
|
||||
```
|
||||
|
||||
Each downloaded directory will have this structure:
|
||||
|
||||
```
|
||||
Wan2.1-T2V-*/
|
||||
├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder
|
||||
├── Wan2.1_VAE.pth # 3D VAE
|
||||
└── diffusion_pytorch_model*.safetensors # transformer (single)
|
||||
|
||||
Wan2.2-T2V-A14B/ or Wan2.2-I2V-A14B/
|
||||
├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
├── Wan2.1_VAE.pth
|
||||
├── low_noise_model/ # dual-model low-noise transformer
|
||||
└── high_noise_model/ # dual-model high-noise transformer
|
||||
|
||||
Wan2.2-TI2V-5B/
|
||||
├── models_t5_umt5-xxl-enc-bf16.pth
|
||||
├── Wan2.2_VAE.pth # different VAE (z_dim=48)
|
||||
└── diffusion_pytorch_model*.safetensors # transformer (single)
|
||||
```
|
||||
|
||||
> **Wan2.2 I2V-14B** shares the same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`).
|
||||
|
||||
### Step 2: Convert to MLX Format
|
||||
|
||||
The conversion script auto-detects the model version from the directory structure (presence of `low_noise_model/` → Wan2.2 dual model) and the model type from `config.json` (I2V vs T2V).
|
||||
|
||||
#### Wan2.1 T2V 1.3B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-1.3B \
|
||||
--output-dir ./Wan2.1-T2V-1.3B-MLX
|
||||
```
|
||||
|
||||
#### Wan2.1 T2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-14B \
|
||||
--output-dir ./Wan2.1-T2V-14B-MLX
|
||||
```
|
||||
|
||||
#### Wan2.2 T2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-T2V-A14B \
|
||||
--output-dir ./Wan2.2-T2V-A14B-MLX
|
||||
```
|
||||
|
||||
#### Wan2.2 I2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-I2V-A14B \
|
||||
--output-dir ./Wan2.2-I2V-A14B-MLX
|
||||
```
|
||||
|
||||
The I2V model is auto-detected from `config.json`; the output will include a `vae_encoder.safetensors` used to encode the conditioning image.
|
||||
|
||||
#### Wan2.2 TI2V 5B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-TI2V-5B \
|
||||
--output-dir ./Wan2.2-TI2V-5B-MLX
|
||||
```
|
||||
|
||||
The TI2V model uses a different VAE (`z_dim=48`, `vae_stride=(4,16,16)`) and is auto-detected during conversion.
|
||||
|
||||
---
|
||||
|
||||
You can also pass `--model-version 2.1` or `--model-version 2.2` to force the version instead of relying on auto-detection.
|
||||
|
||||
#### Conversion Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
|
||||
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
|
||||
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
|
||||
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
|
||||
| `--quantize` | off | Quantize transformer weights for reduced memory |
|
||||
| `--bits` | `4` | Quantization bits: `4` or `8` |
|
||||
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
|
||||
|
||||
The converter produces:
|
||||
```
|
||||
wan_mlx/
|
||||
├── config.json # Model configuration
|
||||
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
|
||||
├── vae.safetensors # 3D VAE decoder
|
||||
├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only)
|
||||
├── model.safetensors # (Wan2.1) Single transformer
|
||||
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
|
||||
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
|
||||
```
|
||||
|
||||
### Step 3: Generate Video
|
||||
|
||||
#### Wan2.1 T2V 1.3B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.gemer \
|
||||
--model-dir ./Wan2.1-T2V-1.3B-MLX \
|
||||
--prompt "A cat playing piano in a cozy living room, cinematic lighting" \
|
||||
--width 832 --height 480 --num-frames 81 \
|
||||
--steps 50 --guide-scale 5.0 \
|
||||
--seed 42 \
|
||||
--output-path wan21_1b.mp4
|
||||
```
|
||||
|
||||
#### Wan2.1 T2V 14B
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.gemer \
|
||||
--model-dir ./Wan2.1-T2V-14B-MLX \
|
||||
--prompt "A woman walks through a misty forest at dawn, slow motion, cinematic" \
|
||||
--width 1280 --height 704 --num-frames 81 \
|
||||
--steps 50 --guide-scale 5.0 \
|
||||
--seed 42 \
|
||||
--output-path wan21_14b.mp4
|
||||
```
|
||||
|
||||
> **Tip**: If the first few frames look washed out or have color artifacts, add `--trim-first-frames 1` to generate 4 extra frames at the start and discard them. With the `unipc` scheduler (default), **10 steps** often gives satisfying results — useful for quick iteration.
|
||||
|
||||
#### Wan2.2 T2V 14B
|
||||
|
||||
Wan2.2 uses a dual-model pipeline (separate high-noise and low-noise transformers) and takes guidance as a `high,low` pair:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-T2V-A14B-MLX \
|
||||
--prompt "Two astronauts playing chess on the surface of the moon, dramatic lighting, 8K" \
|
||||
--negative-prompt "low quality, blurry, distorted" \
|
||||
--width 1280 --height 704 --num-frames 81 \
|
||||
--steps 40 --guide-scale "3.0,4.0" \
|
||||
--seed 42 \
|
||||
--output-path wan22_t2v.mp4
|
||||
```
|
||||
|
||||
> **Tip**: With the `unipc` scheduler (default), **10 steps** often produces satisfying results for 14B models — a significant speed-up with minimal quality loss. Try `--steps 10` for quick iterations.
|
||||
|
||||
#### Wan2.2 I2V 14B
|
||||
|
||||
Image-to-video: animates a starting image guided by a text prompt. Pass the image with `--image`:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-I2V-A14B-MLX \
|
||||
--image ./my_photo.png \
|
||||
--prompt "The person slowly turns their head and smiles, cinematic, natural lighting" \
|
||||
--negative-prompt "low quality, blurry, distorted" \
|
||||
--width 1280 --height 704 --num-frames 81 \
|
||||
--steps 40 --guide-scale "3.5,3.5" \
|
||||
--seed 42 \
|
||||
--output-path wan22_i2v.mp4
|
||||
```
|
||||
|
||||
> **Tip**: As with T2V, `--steps 10` with the `unipc` scheduler is often sufficient for fast prototyping.
|
||||
|
||||
#### Wan2.2 TI2V 5B
|
||||
|
||||
Text+image-to-video: a single-model variant with a larger VAE (`z_dim=48`). Resolution must be divisible by **32** (not 16 as with other models):
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-TI2V-5B-MLX \
|
||||
--image ./my_photo.png \
|
||||
--prompt "The subject waves hello, warm sunlight, film grain" \
|
||||
--width 1280 --height 704 --num-frames 41 \
|
||||
--steps 40 --guide-scale 5.0 \
|
||||
--seed 42 \
|
||||
--output-path wan22_ti2v.mp4
|
||||
```
|
||||
|
||||
> **Note**: The 5B model is fast — 40 steps run quickly and are recommended for best quality.
|
||||
|
||||
> **Frame count**: `--num-frames` must satisfy `4n+1` for all models (e.g. 5, 9, 13, 21, 41, 81, 101 …).
|
||||
|
||||
> **Resolution**: Always use the model's native resolution. While generation will succeed at other sizes, mismatched resolutions or aspect ratios are likely to produce visual artifacts. Preferred resolutions are:
|
||||
> - **480P** — 832×480 (landscape) or 480×832 (portrait) — for Wan2.1 1.3B
|
||||
> - **720P** — 1280×704 (landscape) or 704×1280 (portrait) — for Wan2.1 14B, Wan2.2 T2V/I2V/TI2V
|
||||
|
||||
#### Generation Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--model-dir` | (required) | Path to converted MLX model directory |
|
||||
| `--prompt` | (required) | Text prompt |
|
||||
| `--image` | — | Input image path (I2V and TI2V modes) |
|
||||
| `--negative-prompt` | config default | Negative guidance prompt |
|
||||
| `--width` | `1280` | Output width in pixels |
|
||||
| `--height` | `704` | Output height in pixels |
|
||||
| `--num-frames` | `81` | Number of frames (must be `4n+1`) |
|
||||
| `--steps` | config default | Diffusion steps |
|
||||
| `--guide-scale` | config default | Guidance scale; use `"high,low"` pair for Wan2.2 dual models |
|
||||
| `--shift` | config default | Noise schedule shift |
|
||||
| `--seed` | `-1` (random) | Random seed for reproducibility |
|
||||
| `--output-path` | `output.mp4` | Output video file path |
|
||||
| `--scheduler` | `unipc` | Solver: `euler`, `dpm++`, or `unipc` |
|
||||
| `--trim-first-frames` | `0` | Drop N leading frames (fixes first-frame artifacts on 14B models) |
|
||||
| `--tiling` | `auto` | VAE tiling: `auto`, `none`, `spatial`, `temporal` |
|
||||
|
||||
### Quantization (Reduced Memory)
|
||||
|
||||
Quantize the transformer weights to reduce memory usage by ~3.4×. Quantization is supported for all model variants and is especially important for running 14B models on devices with limited unified memory:
|
||||
|
||||
```bash
|
||||
# Convert with 4-bit quantization (works for any variant)
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-1.3B \
|
||||
--output-dir ./Wan2.1-T2V-1.3B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.1-T2V-14B \
|
||||
--output-dir ./Wan2.1-T2V-14B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-T2V-A14B \
|
||||
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-I2V-A14B \
|
||||
--output-dir ./Wan2.2-I2V-A14B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-TI2V-5B \
|
||||
--output-dir ./Wan2.2-TI2V-5B-MLX-Q4 \
|
||||
--quantize --bits 4 --group-size 64
|
||||
```
|
||||
|
||||
You can also quantize an already-converted MLX model without re-converting from PyTorch:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.convert \
|
||||
--checkpoint-dir ./Wan2.2-T2V-A14B-MLX \
|
||||
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
|
||||
--quantize-only --bits 4
|
||||
```
|
||||
|
||||
Quantized models are used exactly the same way — the quantization is auto-detected from `config.json`:
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
|
||||
--prompt "A cat playing piano"
|
||||
```
|
||||
|
||||
**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2) — 10 layers × N blocks = ~95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.
|
||||
|
||||
| Model | BF16 Size | 4-bit Size | Notes |
|
||||
|-------|-----------|------------|-------|
|
||||
| 1.3B | 2.7 GB | 799 MB | ~3.4x smaller |
|
||||
| 14B | ~28 GB | ~8 GB | Enables running on 16GB devices |
|
||||
|
||||
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
|
||||
|
||||
### Wan Model Specifications
|
||||
|
||||
**Transformer (14B)**
|
||||
- 40 layers, 40 attention heads, dim 5120, head dim 128
|
||||
- 3-way factorized RoPE (temporal + spatial)
|
||||
- 14.29B parameters
|
||||
|
||||
**Transformer (1.3B, Wan2.1 only)**
|
||||
- 30 layers, 12 attention heads, dim 1536, head dim 128
|
||||
- Same architecture, smaller scale
|
||||
|
||||
**Text Encoder** — UMT5-XXL (5.68B parameters)
|
||||
- 24 layers, 64 heads, dim 4096, vocab 256K
|
||||
|
||||
**VAE** — 3D causal convolution decoder (72.6M parameters)
|
||||
- Latent channels: 16
|
||||
- Compression: 4× temporal, 8× spatial
|
||||
|
||||
---
|
||||
|
||||
## LoRA Support
|
||||
|
||||
LoRA's can be used with the `--lora-high` and `--lora-low` command line switches.
|
||||
|
||||
For example, for using the the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, use the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
|
||||
|
||||
```bash
|
||||
python -m mlx_video.wan2.generate \
|
||||
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
|
||||
--width 480 \
|
||||
--height 704 \
|
||||
--num-frames 41 \
|
||||
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
|
||||
--steps 4 \
|
||||
--guide-scale 1 \
|
||||
--trim-first-frames 1 \
|
||||
--seed 2391784614 \
|
||||
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
|
||||
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
|
||||
```
|
||||
|
||||
## Enjoy
|
||||
|
||||

|
||||
2
mlx_video/models/wan_2/__init__.py
Normal file
2
mlx_video/models/wan_2/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
221
mlx_video/models/wan_2/attention.py
Normal file
221
mlx_video/models/wan_2/attention.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .rope import rope_apply
|
||||
|
||||
|
||||
def _linear_dtype(layer) -> mx.Dtype:
|
||||
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
|
||||
# Unwrap LoRA wrapper to get the underlying linear layer
|
||||
inner = getattr(layer, "linear", layer)
|
||||
if isinstance(inner, nn.QuantizedLinear):
|
||||
return inner.scales.dtype
|
||||
return inner.weight.dtype
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
"""RMS normalization with learnable scale."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.Module):
|
||||
"""LayerNorm computed in float32, with optional affine."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if elementwise_affine:
|
||||
self.weight = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.elementwise_affine:
|
||||
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
|
||||
else:
|
||||
return mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
"""Self-attention with QK normalization and 3-way factorized RoPE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
x_w = x.astype(w_dtype)
|
||||
|
||||
q = self.q(x_w)
|
||||
k = self.k(x_w)
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
|
||||
q = q.reshape(b, s, n, d)
|
||||
k = k.reshape(b, s, n, d)
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(
|
||||
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
k = rope_apply(
|
||||
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Use precomputed mask or build from seq_lens
|
||||
mask = attn_mask
|
||||
if mask is None and any(sl < s for sl in seq_lens):
|
||||
mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
|
||||
for i, sl in enumerate(seq_lens):
|
||||
mask[i, :, :, sl:] = -1e9
|
||||
|
||||
# Use memory-efficient scaled dot-product attention
|
||||
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class WanCrossAttention(nn.Module):
|
||||
"""Cross-attention: Q from hidden states, K/V from text context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def prepare_kv(self, context: mx.array) -> tuple:
|
||||
"""Pre-compute K and V projections for caching.
|
||||
|
||||
Args:
|
||||
context: [B, L_ctx, dim]
|
||||
|
||||
Returns:
|
||||
(k, v) each [B, N, L_ctx, D] ready for attention
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
# Cast to compute dtype for efficient matmul
|
||||
w_dtype = _linear_dtype(self.k)
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
q = self.q(x.astype(w_dtype))
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
# Optional context masking
|
||||
mask = None
|
||||
if context_lens is not None:
|
||||
ctx_len = k.shape[2]
|
||||
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
|
||||
for i, cl in enumerate(context_lens):
|
||||
mask[i, :, :, cl:] = -1e9
|
||||
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
||||
return self.o(out)
|
||||
129
mlx_video/models/wan_2/config.py
Normal file
129
mlx_video/models/wan_2/config.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
from mlx_video.models.ltx_2.config import BaseModelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class WanModelConfig(BaseModelConfig):
|
||||
"""Configuration for Wan T2V models (supports both 2.1 and 2.2)."""
|
||||
|
||||
model_type: str = "t2v"
|
||||
model_version: str = "2.2"
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2)
|
||||
text_len: int = 512
|
||||
in_dim: int = 16
|
||||
dim: int = 5120
|
||||
ffn_dim: int = 13824
|
||||
freq_dim: int = 256
|
||||
text_dim: int = 4096
|
||||
out_dim: int = 16
|
||||
num_heads: int = 40
|
||||
num_layers: int = 40
|
||||
window_size: Tuple[int, int] = (-1, -1)
|
||||
qk_norm: bool = True
|
||||
cross_attn_norm: bool = True
|
||||
eps: float = 1e-6
|
||||
|
||||
# VAE
|
||||
vae_stride: Tuple[int, int, int] = (4, 8, 8)
|
||||
vae_z_dim: int = 16
|
||||
|
||||
# Inference
|
||||
dual_model: bool = True
|
||||
boundary: float = 0.875
|
||||
sample_shift: float = 12.0
|
||||
sample_steps: int = 40
|
||||
sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
|
||||
num_train_timesteps: int = 1000
|
||||
sample_fps: int = 16
|
||||
frame_num: int = 81
|
||||
sample_neg_prompt: str = (
|
||||
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
|
||||
"最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
|
||||
"画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
|
||||
"杂乱的背景,三条腿,背景人很多,倒着走"
|
||||
)
|
||||
|
||||
# Resolution constraints
|
||||
max_area: int = 0 # 0 = no limit; e.g. 704*1280 for TI2V-5B
|
||||
t5_vocab_size: int = 256384
|
||||
t5_dim: int = 4096
|
||||
t5_dim_attn: int = 4096
|
||||
t5_dim_ffn: int = 10240
|
||||
t5_num_heads: int = 64
|
||||
t5_num_layers: int = 24
|
||||
t5_num_buckets: int = 32
|
||||
|
||||
@property
|
||||
def head_dim(self) -> int:
|
||||
return self.dim // self.num_heads
|
||||
|
||||
@classmethod
|
||||
def wan21_t2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
|
||||
return cls(
|
||||
model_version="2.1",
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan21_t2v_1_3b(cls) -> "WanModelConfig":
|
||||
"""Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
|
||||
return cls(
|
||||
model_version="2.1",
|
||||
dim=1536,
|
||||
ffn_dim=8960,
|
||||
num_heads=12,
|
||||
num_layers=30,
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan22_t2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def wan22_i2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 I2V 14B: dual model, image-to-video, 40 layers, dim=5120."""
|
||||
return cls(
|
||||
model_type="i2v",
|
||||
in_dim=36,
|
||||
out_dim=16,
|
||||
dual_model=True,
|
||||
boundary=0.900,
|
||||
sample_shift=5.0,
|
||||
sample_guide_scale=(3.5, 3.5),
|
||||
max_area=704 * 1280,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
|
||||
return cls(
|
||||
model_type="ti2v",
|
||||
dim=3072,
|
||||
ffn_dim=14336,
|
||||
in_dim=48,
|
||||
out_dim=48,
|
||||
num_heads=24,
|
||||
num_layers=30,
|
||||
vae_z_dim=48,
|
||||
vae_stride=(4, 16, 16),
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=40,
|
||||
sample_guide_scale=5.0,
|
||||
sample_fps=24,
|
||||
max_area=704 * 1280,
|
||||
)
|
||||
808
mlx_video/models/wan_2/convert.py
Normal file
808
mlx_video/models/wan_2/convert.py
Normal file
@@ -0,0 +1,808 @@
|
||||
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_torch_weights(path: str) -> Dict[str, mx.array]:
|
||||
"""Load PyTorch .pth weights and convert to MLX arrays.
|
||||
|
||||
Args:
|
||||
path: Path to .pth file
|
||||
|
||||
Returns:
|
||||
Dictionary of MLX arrays
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("PyTorch is required to load .pth weights: pip install torch")
|
||||
|
||||
logging.info(f"Loading weights from {path}")
|
||||
state_dict = torch.load(path, map_location="cpu", weights_only=True)
|
||||
|
||||
weights = {}
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
np_val = value.detach().float().numpy()
|
||||
weights[key] = mx.array(np_val)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
|
||||
"""Load safetensors weights as MLX arrays.
|
||||
|
||||
Args:
|
||||
path: Path to directory with safetensors files or single file
|
||||
|
||||
Returns:
|
||||
Dictionary of MLX arrays
|
||||
"""
|
||||
path = Path(path)
|
||||
weights = {}
|
||||
if path.is_file():
|
||||
weights = mx.load(str(path))
|
||||
elif path.is_dir():
|
||||
for sf in sorted(path.glob("*.safetensors")):
|
||||
weights.update(mx.load(str(sf)))
|
||||
return weights
|
||||
|
||||
|
||||
def sanitize_wan_transformer_weights(
|
||||
weights: Dict[str, mx.array]
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 transformer weight keys to MLX model structure.
|
||||
|
||||
Wan2.2 keys follow the pattern:
|
||||
patch_embedding.weight/bias
|
||||
text_embedding.{0,2}.weight/bias
|
||||
time_embedding.{0,2}.weight/bias
|
||||
time_projection.1.weight/bias
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
|
||||
blocks.{i}.self_attn.norm_q.weight
|
||||
blocks.{i}.self_attn.norm_k.weight
|
||||
blocks.{i}.norm3.weight/bias (if cross_attn_norm)
|
||||
blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
|
||||
blocks.{i}.cross_attn.norm_q.weight
|
||||
blocks.{i}.cross_attn.norm_k.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.{0,2}.weight/bias
|
||||
blocks.{i}.modulation
|
||||
head.norm.weight
|
||||
head.head.weight/bias
|
||||
head.modulation
|
||||
freqs (buffer)
|
||||
|
||||
MLX model uses:
|
||||
patch_embedding_proj.weight/bias (after patchify reshape)
|
||||
text_embedding_0.weight/bias, text_embedding_1.weight/bias
|
||||
time_embedding_0.weight/bias, time_embedding_1.weight/bias
|
||||
time_projection.weight/bias
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
|
||||
etc.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
|
||||
# MLX Linear expects [O, I*D*H*W] after we flatten in patchify
|
||||
if key == "patch_embedding.weight":
|
||||
# Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
|
||||
value = value.reshape(value.shape[0], -1)
|
||||
new_key = "patch_embedding_proj.weight"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key == "patch_embedding.bias":
|
||||
new_key = "patch_embedding_proj.bias"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
|
||||
if key.startswith("text_embedding.0."):
|
||||
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("text_embedding.2."):
|
||||
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
|
||||
if key.startswith("time_embedding.0."):
|
||||
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("time_embedding.2."):
|
||||
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time projection Sequential: 0=SiLU(no params), 1=Linear
|
||||
if key.startswith("time_projection.1."):
|
||||
new_key = key.replace("time_projection.1.", "time_projection.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
|
||||
new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
|
||||
new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")
|
||||
|
||||
# Skip the freqs buffer (we compute it in the model)
|
||||
if key == "freqs":
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.
|
||||
|
||||
Wan2.2 T5 keys:
|
||||
token_embedding.weight
|
||||
pos_embedding.embedding.weight (if shared_pos)
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.attn.{q,k,v,o}.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.gate.0.weight (gate linear)
|
||||
blocks.{i}.ffn.fc1.weight
|
||||
blocks.{i}.ffn.fc2.weight
|
||||
blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
|
||||
norm.weight
|
||||
|
||||
MLX T5Encoder structure:
|
||||
token_embedding.weight
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.attn.{q,k,v,o}.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
|
||||
blocks.{i}.ffn.fc1.weight
|
||||
blocks.{i}.ffn.fc2.weight
|
||||
blocks.{i}.pos_embedding.embedding.weight
|
||||
norm.weight
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
|
||||
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 VAE weight keys to MLX WanVAE structure.
|
||||
|
||||
Handles Conv3d and Conv2d weight transpositions for MLX format.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
|
||||
if "weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
|
||||
if "weight" in key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
# Map decoder keys to MLX decoder structure
|
||||
# Wan2.2 uses encoder/decoder with downsamples/upsamples
|
||||
# Need to adapt naming for our simplified structure
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _load_lora_configs(
|
||||
lora_configs: List[Tuple[str, float]],
|
||||
) -> Dict[str, list]:
|
||||
"""Load LoRA weights from config tuples, returning module_to_loras dict.
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.models.wan_2.generate import Colors
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
|
||||
configs = []
|
||||
for lora_path, strength in lora_configs:
|
||||
try:
|
||||
config = LoRAConfig(path=lora_path, strength=strength)
|
||||
configs.append(config)
|
||||
print(f" - {Path(lora_path).name} (strength: {strength})")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
|
||||
raise
|
||||
|
||||
module_to_loras = load_multiple_loras(configs)
|
||||
|
||||
if not module_to_loras:
|
||||
print(
|
||||
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
|
||||
)
|
||||
|
||||
return module_to_loras
|
||||
|
||||
|
||||
def load_and_apply_loras(
|
||||
model_weights: Dict[str, mx.array],
|
||||
lora_configs: Optional[List[Tuple[str, float]]] = None,
|
||||
verbose: bool = False,
|
||||
quantization_bits: int = 0,
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Load and apply LoRA weights to model weights by merging into weight dict.
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.models.wan_2.generate import Colors
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
|
||||
if not lora_configs:
|
||||
return model_weights
|
||||
|
||||
module_to_loras = _load_lora_configs(lora_configs)
|
||||
if not module_to_loras:
|
||||
return model_weights
|
||||
|
||||
print(
|
||||
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
|
||||
)
|
||||
if verbose:
|
||||
print(f" Model has {len(model_weights)} weight keys")
|
||||
|
||||
modified_weights = apply_loras_to_weights(
|
||||
model_weights,
|
||||
module_to_loras,
|
||||
verbose=verbose,
|
||||
quantization_bits=quantization_bits,
|
||||
)
|
||||
|
||||
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
|
||||
|
||||
return modified_weights
|
||||
|
||||
|
||||
def convert_wan_checkpoint(
|
||||
checkpoint_dir: str,
|
||||
output_dir: str,
|
||||
dtype: str = "bfloat16",
|
||||
model_version: str = "auto",
|
||||
quantize: bool = False,
|
||||
bits: int = 4,
|
||||
group_size: int = 64,
|
||||
):
|
||||
"""Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.
|
||||
|
||||
Wan2.2 expected structure:
|
||||
checkpoint_dir/
|
||||
models_t5_umt5-xxl-enc-bf16.pth
|
||||
Wan2.1_VAE.pth
|
||||
low_noise_model/ (safetensors)
|
||||
high_noise_model/ (safetensors)
|
||||
|
||||
Wan2.1 expected structure:
|
||||
checkpoint_dir/
|
||||
models_t5_umt5-xxl-enc-bf16.pth
|
||||
Wan2.1_VAE.pth
|
||||
diffusion_pytorch_model*.safetensors (single model)
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to Wan checkpoint directory
|
||||
output_dir: Path to output MLX model directory
|
||||
dtype: Target dtype
|
||||
model_version: "2.1", "2.2", or "auto" (detect from directory)
|
||||
quantize: Whether to quantize the transformer weights
|
||||
bits: Quantization bits (4 or 8)
|
||||
group_size: Quantization group size (32, 64, or 128)
|
||||
"""
|
||||
import json
|
||||
|
||||
checkpoint_dir = Path(checkpoint_dir)
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dtype_map = {
|
||||
"float16": mx.float16,
|
||||
"float32": mx.float32,
|
||||
"bfloat16": mx.bfloat16,
|
||||
}
|
||||
target_dtype = dtype_map.get(dtype, mx.bfloat16)
|
||||
|
||||
# Auto-detect version
|
||||
if model_version == "auto":
|
||||
if (checkpoint_dir / "low_noise_model").exists():
|
||||
model_version = "2.2"
|
||||
elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
|
||||
model_version = "2.2"
|
||||
else:
|
||||
model_version = "2.1"
|
||||
print(f"Auto-detected Wan{model_version} checkpoint")
|
||||
|
||||
is_dual = (checkpoint_dir / "low_noise_model").exists()
|
||||
|
||||
if is_dual:
|
||||
# Wan2.2: Convert dual transformer models
|
||||
low_noise_path = checkpoint_dir / "low_noise_model"
|
||||
if low_noise_path.exists():
|
||||
print("Converting low-noise transformer...")
|
||||
weights = load_safetensors_weights(str(low_noise_path))
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "low_noise_model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
|
||||
high_noise_path = checkpoint_dir / "high_noise_model"
|
||||
if high_noise_path.exists():
|
||||
print("Converting high-noise transformer...")
|
||||
weights = load_safetensors_weights(str(high_noise_path))
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "high_noise_model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
else:
|
||||
# Wan2.1: Convert single transformer model
|
||||
# Try safetensors in the checkpoint dir itself
|
||||
print("Converting transformer (single model)...")
|
||||
weights = load_safetensors_weights(str(checkpoint_dir))
|
||||
if not weights:
|
||||
# Fallback: look for .pth files
|
||||
for pth in sorted(checkpoint_dir.glob("*.pth")):
|
||||
if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
|
||||
print(f" Loading from {pth.name}...")
|
||||
weights = load_torch_weights(str(pth))
|
||||
break
|
||||
if weights:
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
else:
|
||||
print(" Warning: No transformer weights found!")
|
||||
|
||||
# Save config — detect model size from source config.json or transformer weights
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
def _detect_config():
|
||||
"""Detect config from source config.json or transformer weight shapes."""
|
||||
if is_dual:
|
||||
# Check source config.json for model_type (I2V vs T2V)
|
||||
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
|
||||
if src_cfg_path.exists():
|
||||
with open(src_cfg_path) as f:
|
||||
src_config = json.load(f)
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
|
||||
return WanModelConfig.wan22_i2v_14b()
|
||||
return WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
# Try reading source config.json first (most reliable)
|
||||
src_cfg_path = checkpoint_dir / "config.json"
|
||||
src_config = None
|
||||
if src_cfg_path.exists():
|
||||
with open(src_cfg_path) as f:
|
||||
src_config = json.load(f)
|
||||
|
||||
if src_config and "dim" in src_config:
|
||||
src_dim = src_config.get("dim", 5120)
|
||||
src_in_dim = src_config.get("in_dim", 16)
|
||||
src_out_dim = src_config.get("out_dim", 16)
|
||||
src_ffn_dim = src_config.get("ffn_dim", 13824)
|
||||
src_num_heads = src_config.get("num_heads", 40)
|
||||
src_num_layers = src_config.get("num_layers", 40)
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
src_text_len = src_config.get("text_len", 512)
|
||||
|
||||
print(
|
||||
f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}"
|
||||
)
|
||||
|
||||
# Use preset for known TI2V 5B configuration
|
||||
if src_model_type == "ti2v" and src_dim == 3072:
|
||||
return WanModelConfig.wan22_ti2v_5b()
|
||||
|
||||
is_22 = model_version == "2.2"
|
||||
|
||||
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
|
||||
vae_z = 48 if is_22 else 16
|
||||
vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
|
||||
fps = 24 if is_22 else 16
|
||||
|
||||
return WanModelConfig(
|
||||
model_type=src_model_type,
|
||||
model_version=model_version,
|
||||
dim=src_dim,
|
||||
ffn_dim=src_ffn_dim,
|
||||
in_dim=src_in_dim,
|
||||
out_dim=src_out_dim,
|
||||
num_heads=src_num_heads,
|
||||
num_layers=src_num_layers,
|
||||
text_len=src_text_len,
|
||||
vae_z_dim=vae_z,
|
||||
vae_stride=vae_s,
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
sample_fps=fps,
|
||||
)
|
||||
|
||||
# Fallback: detect from saved transformer weight shapes
|
||||
saved_model = output_dir / "model.safetensors"
|
||||
if saved_model.exists():
|
||||
det_weights = mx.load(str(saved_model))
|
||||
dim = None
|
||||
for k, v in det_weights.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
dim = v.shape[0]
|
||||
break
|
||||
del det_weights
|
||||
if dim is not None and dim <= 2048:
|
||||
print(f" Auto-detected 1.3B model (dim={dim})")
|
||||
return WanModelConfig.wan21_t2v_1_3b()
|
||||
|
||||
return WanModelConfig.wan21_t2v_14b()
|
||||
|
||||
config = _detect_config()
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
print(f" Saved config to {config_path}")
|
||||
|
||||
# Convert T5 encoder
|
||||
t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
|
||||
if t5_path.exists():
|
||||
print("Converting T5 encoder...")
|
||||
weights = load_torch_weights(str(t5_path))
|
||||
weights = sanitize_wan_t5_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "t5_encoder.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
|
||||
# Convert VAE (check both naming conventions)
|
||||
vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
|
||||
is_wan22_vae = False
|
||||
if not vae_path.exists():
|
||||
vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
|
||||
is_wan22_vae = True
|
||||
if vae_path.exists():
|
||||
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(
|
||||
weights, include_encoder=include_encoder
|
||||
)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
|
||||
# that cannot be recovered by upcasting at load time.
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
out_path = output_dir / "vae.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
|
||||
|
||||
# Quantize transformer weights if requested
|
||||
if quantize:
|
||||
print(
|
||||
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
|
||||
)
|
||||
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
|
||||
|
||||
print(f"\nConversion complete! Output: {output_dir}")
|
||||
|
||||
|
||||
def _quantize_predicate(path: str, module) -> bool:
|
||||
"""Return True for layers that should be quantized.
|
||||
|
||||
Targets heavyweight Linear layers in attention and FFN blocks.
|
||||
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
|
||||
"""
|
||||
if not hasattr(module, "to_quantized"):
|
||||
return False
|
||||
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
||||
quantize_patterns = (
|
||||
".self_attn.q",
|
||||
".self_attn.k",
|
||||
".self_attn.v",
|
||||
".self_attn.o",
|
||||
".cross_attn.q",
|
||||
".cross_attn.k",
|
||||
".cross_attn.v",
|
||||
".cross_attn.o",
|
||||
".ffn.fc1",
|
||||
".ffn.fc2",
|
||||
)
|
||||
return any(path.endswith(p) for p in quantize_patterns)
|
||||
|
||||
|
||||
def _quantize_saved_model(
|
||||
output_dir: Path,
|
||||
config,
|
||||
is_dual: bool,
|
||||
bits: int,
|
||||
group_size: int,
|
||||
source_dir: Path = None,
|
||||
):
|
||||
"""Load saved bf16 model, quantize, and re-save.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to write quantized weights to.
|
||||
config: WanModelConfig for creating the model.
|
||||
is_dual: Whether this is a dual-expert model.
|
||||
bits: Quantization bits.
|
||||
group_size: Quantization group size.
|
||||
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
|
||||
"""
|
||||
import json
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
if source_dir is None:
|
||||
source_dir = output_dir
|
||||
|
||||
model_names = []
|
||||
if is_dual:
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
if (source_dir / name).exists():
|
||||
model_names.append(name)
|
||||
else:
|
||||
if (source_dir / "model.safetensors").exists():
|
||||
model_names.append("model.safetensors")
|
||||
|
||||
for name in model_names:
|
||||
print(f" Quantizing {name}...")
|
||||
model = WanModel(config)
|
||||
weights = mx.load(str(source_dir / name))
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
del weights
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Apply quantization to targeted layers
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
# Save quantized weights
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
|
||||
# Validate: check for NaN/Inf in bias tensors (corruption canary)
|
||||
bad_keys = []
|
||||
for k, v in weights_dict.items():
|
||||
if k.endswith(".bias") and not k.endswith(".biases"):
|
||||
mx.eval(v)
|
||||
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
|
||||
bad_keys.append(k)
|
||||
if bad_keys:
|
||||
raise RuntimeError(
|
||||
f"Quantization produced corrupted weights in {model_path.name}: "
|
||||
f"{len(bad_keys)} bias tensors contain NaN/Inf "
|
||||
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
|
||||
)
|
||||
|
||||
mx.save_safetensors(str(output_dir / name), weights_dict)
|
||||
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
|
||||
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
|
||||
|
||||
# Free model before processing next file
|
||||
del model, weights_dict
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Update config.json with quantization metadata
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
cfg["quantization"] = {
|
||||
"group_size": group_size,
|
||||
"bits": bits,
|
||||
}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(cfg, f, indent=2)
|
||||
print(f" Updated config.json with quantization metadata")
|
||||
|
||||
|
||||
def quantize_mlx_model(
|
||||
mlx_model_dir: str,
|
||||
output_dir: str,
|
||||
bits: int = 4,
|
||||
group_size: int = 64,
|
||||
):
|
||||
"""Quantize an already-converted MLX model (skips PyTorch conversion).
|
||||
|
||||
Args:
|
||||
mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
|
||||
output_dir: Path to output quantized model directory.
|
||||
bits: Quantization bits (4 or 8).
|
||||
group_size: Quantization group size (32, 64, or 128).
|
||||
"""
|
||||
import json
|
||||
import shutil
|
||||
|
||||
src = Path(mlx_model_dir)
|
||||
dst = Path(output_dir)
|
||||
|
||||
config_path = src / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"No config.json found in {src}")
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
if cfg.get("quantization"):
|
||||
raise ValueError(
|
||||
f"Model at {src} is already quantized "
|
||||
f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source."
|
||||
)
|
||||
|
||||
# Detect dual vs single expert
|
||||
is_dual = (src / "low_noise_model.safetensors").exists() and (
|
||||
src / "high_noise_model.safetensors"
|
||||
).exists()
|
||||
|
||||
# Build model config
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config_dict = {
|
||||
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**config_dict)
|
||||
|
||||
# Copy non-transformer files to output dir (skip large model weights)
|
||||
transformer_files = {
|
||||
"low_noise_model.safetensors",
|
||||
"high_noise_model.safetensors",
|
||||
"model.safetensors",
|
||||
}
|
||||
if dst.resolve() != src.resolve():
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for f in src.iterdir():
|
||||
if f.is_file() and f.name not in transformer_files:
|
||||
shutil.copy2(f, dst / f.name)
|
||||
print(f"Copied non-transformer files from {src} to {dst}")
|
||||
|
||||
print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
|
||||
_quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
|
||||
|
||||
print(f"\nQuantization complete! Output: {dst}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
|
||||
parser.add_argument(
|
||||
"--checkpoint-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to Wan checkpoint directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="wan_mlx_model",
|
||||
help="Output path for MLX model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["float16", "float32", "bfloat16"],
|
||||
default="bfloat16",
|
||||
help="Target dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-version",
|
||||
type=str,
|
||||
choices=["2.1", "2.2", "auto"],
|
||||
default="auto",
|
||||
help="Wan model version (auto-detect by default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
action="store_true",
|
||||
help="Quantize transformer weights for faster inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize-only",
|
||||
action="store_true",
|
||||
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bits",
|
||||
type=int,
|
||||
choices=[4, 8],
|
||||
default=4,
|
||||
help="Quantization bits (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
choices=[32, 64, 128],
|
||||
default=64,
|
||||
help="Quantization group size (default: 64)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quantize_only:
|
||||
quantize_mlx_model(
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
else:
|
||||
convert_wan_checkpoint(
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
args.dtype,
|
||||
args.model_version,
|
||||
quantize=args.quantize,
|
||||
bits=args.bits,
|
||||
group_size=args.group_size,
|
||||
)
|
||||
977
mlx_video/models/wan_2/generate.py
Normal file
977
mlx_video/models/wan_2/generate.py
Normal file
@@ -0,0 +1,977 @@
|
||||
"""Wan2.2 Text-to-Video generation pipeline for MLX."""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan_2.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan_2.utils import (
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
load_vae_encoder,
|
||||
load_wan_model,
|
||||
)
|
||||
from mlx_video.models.wan_2.postprocess import save_video
|
||||
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output."""
|
||||
|
||||
CYAN = "\033[96m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
MAGENTA = "\033[95m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
# Backward-compat alias (tests and external code may use the old name)
|
||||
_build_i2v_mask = build_i2v_mask
|
||||
|
||||
|
||||
def _best_output_size(w, h, dw, dh, max_area):
|
||||
"""Compute the best output resolution that fits within max_area while
|
||||
preserving the input aspect ratio and satisfying alignment constraints.
|
||||
Matches the reference implementation's best_output_size().
|
||||
"""
|
||||
ratio = w / h
|
||||
ow = (max_area * ratio) ** 0.5
|
||||
oh = max_area / ow
|
||||
|
||||
# Option 1: process width first
|
||||
ow1 = int(ow // dw * dw)
|
||||
oh1 = int(max_area / ow1 // dh * dh)
|
||||
ratio1 = ow1 / oh1
|
||||
|
||||
# Option 2: process height first
|
||||
oh2 = int(oh // dh * dh)
|
||||
ow2 = int(max_area / oh2 // dw * dw)
|
||||
ratio2 = ow2 / oh2
|
||||
|
||||
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
|
||||
return ow1, oh1
|
||||
return ow2, oh2
|
||||
|
||||
|
||||
def generate_video(
|
||||
model_dir: str,
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
image: str | None = None,
|
||||
width: int = 1280,
|
||||
height: int = 704,
|
||||
num_frames: int = 81,
|
||||
steps: int = None,
|
||||
guide_scale: str | float | tuple = None,
|
||||
shift: float = None,
|
||||
seed: int = -1,
|
||||
output_path: str = "output.mp4",
|
||||
scheduler: str = "unipc",
|
||||
loras: list | None = None,
|
||||
loras_high: list | None = None,
|
||||
loras_low: list | None = None,
|
||||
tiling: str = "auto",
|
||||
no_compile: bool = False,
|
||||
trim_first_frames: int = 0,
|
||||
debug_latents: bool = False,
|
||||
):
|
||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||
|
||||
Args:
|
||||
model_dir: Path to converted MLX model directory
|
||||
prompt: Text prompt
|
||||
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
||||
image: Path to input image for I2V (None = T2V mode)
|
||||
width: Video width
|
||||
height: Video height
|
||||
num_frames: Number of frames (must be 4n+1)
|
||||
steps: Number of diffusion steps (None = use config default)
|
||||
guide_scale: Guidance scale: float for single, (low,high) for dual (None = config default)
|
||||
shift: Noise schedule shift (None = use config default)
|
||||
seed: Random seed (-1 for random)
|
||||
output_path: Output video path
|
||||
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
||||
loras: Optional list of (path, strength) tuples applied to all models
|
||||
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
||||
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
||||
tiling: Tiling mode for VAE decoding. Options:
|
||||
- "auto": Automatically determine tiling based on video size (default)
|
||||
- "none": Disable tiling
|
||||
- "default", "aggressive", "conservative": Preset tiling configs
|
||||
- "spatial": Spatial tiling only
|
||||
- "temporal": Temporal tiling only
|
||||
no_compile: If True, skip mx.compile on models (useful for debugging)
|
||||
trim_first_frames: Number of temporal latent positions to generate extra
|
||||
and discard from the start. Each position = 4 pixel frames. Use 1
|
||||
to fix first-frame artifacts on 14B models (generates 4 extra frames,
|
||||
discards first 4). Use 2 for more aggressive trimming. Default: 0.
|
||||
debug_latents: If True, print per-temporal-position latent statistics
|
||||
after denoising for diagnosing first-frame artifacts.
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
model_dir = Path(model_dir)
|
||||
|
||||
# Load config from model dir if available, otherwise auto-detect
|
||||
config_path = model_dir / "config.json"
|
||||
quantization = None
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
# Extract quantization config (not a model config field)
|
||||
quantization = config_dict.pop("quantization", None)
|
||||
# Handle tuple fields stored as lists in JSON
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
k: v
|
||||
for k, v in config_dict.items()
|
||||
if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Auto-detect: dual model files → 2.2, single model → 2.1
|
||||
if (model_dir / "low_noise_model.safetensors").exists():
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
else:
|
||||
# Detect 1.3B vs 14B from weight shapes
|
||||
model_path = model_dir / "model.safetensors"
|
||||
if model_path.exists():
|
||||
probe = mx.load(str(model_path), return_metadata=False)
|
||||
for k, v in probe.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
dim = v.shape[0]
|
||||
if dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
break
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
del probe
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
|
||||
is_dual = config.dual_model
|
||||
is_i2v = image is not None
|
||||
|
||||
# Validate config against actual weights (handles mismatched config.json)
|
||||
if not is_dual:
|
||||
model_path = model_dir / "model.safetensors"
|
||||
if model_path.exists():
|
||||
probe = mx.load(str(model_path), return_metadata=False)
|
||||
for k, v in probe.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
actual_dim = v.shape[0]
|
||||
if actual_dim != config.dim:
|
||||
print(
|
||||
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
|
||||
)
|
||||
if actual_dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
break
|
||||
del probe
|
||||
|
||||
# Auto-correct Wan2.2 VAE params from stale configs
|
||||
if config.in_dim == 48 and config.vae_z_dim != 48:
|
||||
print(
|
||||
f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
|
||||
)
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
**{
|
||||
f.name: getattr(config, f.name)
|
||||
for f in config.__dataclass_fields__.values()
|
||||
},
|
||||
"vae_z_dim": 48,
|
||||
"vae_stride": (4, 16, 16),
|
||||
"sample_fps": 24,
|
||||
}
|
||||
)
|
||||
|
||||
# Apply defaults from config if not overridden
|
||||
if steps is None:
|
||||
steps = config.sample_steps
|
||||
if shift is None:
|
||||
shift = config.sample_shift
|
||||
if guide_scale is None:
|
||||
guide_scale = config.sample_guide_scale
|
||||
|
||||
# Normalize guide_scale
|
||||
if isinstance(guide_scale, (int, float)):
|
||||
guide_scale = float(guide_scale)
|
||||
elif isinstance(guide_scale, str):
|
||||
parts = [float(x) for x in guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
|
||||
if isinstance(guide_scale, tuple):
|
||||
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
|
||||
else:
|
||||
cfg_disabled = guide_scale <= 1.0
|
||||
|
||||
# Validate frame count
|
||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||
|
||||
gen_frames = num_frames
|
||||
if trim_first_frames > 0:
|
||||
gen_frames = num_frames + trim_first_frames * 4
|
||||
print(
|
||||
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
|
||||
)
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
|
||||
# Resolve negative prompt: explicit user value > config default
|
||||
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
||||
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
||||
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
|
||||
if negative_prompt is None:
|
||||
neg_prompt_resolved = config.sample_neg_prompt
|
||||
else:
|
||||
neg_prompt_resolved = negative_prompt
|
||||
print(f"{Colors.CYAN}{'='*60}")
|
||||
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
|
||||
print(f"{'='*60}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Prompt: {prompt}")
|
||||
if is_i2v:
|
||||
print(f" Image: {image}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = (
|
||||
neg_prompt_resolved[:60] + "..."
|
||||
if len(neg_prompt_resolved) > 60
|
||||
else neg_prompt_resolved
|
||||
)
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(
|
||||
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
|
||||
)
|
||||
if cfg_disabled:
|
||||
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||
print(f"{Colors.RESET}")
|
||||
|
||||
# Seed
|
||||
if seed < 0:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
mx.random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
|
||||
|
||||
# Align dimensions to patch_size * vae_stride (required for patchify)
|
||||
vae_stride = config.vae_stride
|
||||
patch_size = config.patch_size
|
||||
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
|
||||
align_w = patch_size[2] * vae_stride[2]
|
||||
if height % align_h != 0 or width % align_w != 0:
|
||||
old_h, old_w = height, width
|
||||
height = (height // align_h) * align_h
|
||||
width = (width // align_w) * align_w
|
||||
if height == 0:
|
||||
height = align_h
|
||||
if width == 0:
|
||||
width = align_w
|
||||
print(
|
||||
f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
|
||||
)
|
||||
|
||||
# Enforce max_area constraint (model-specific resolution limit)
|
||||
if config.max_area > 0 and height * width > config.max_area:
|
||||
old_h, old_w = height, width
|
||||
width, height = _best_output_size(
|
||||
width, height, align_w, align_h, config.max_area
|
||||
)
|
||||
print(
|
||||
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Compute target latent shape
|
||||
z_dim = config.vae_z_dim
|
||||
t_latent = (gen_frames - 1) // vae_stride[0] + 1
|
||||
h_latent = height // vae_stride[1]
|
||||
w_latent = width // vae_stride[2]
|
||||
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
||||
|
||||
# Sequence length for transformer
|
||||
seq_len = math.ceil(
|
||||
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
|
||||
)
|
||||
|
||||
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
||||
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
||||
|
||||
# Load T5 encoder
|
||||
t1 = time.time()
|
||||
print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
|
||||
t5_path = model_dir / "t5_encoder.safetensors"
|
||||
t5_encoder = load_t5_encoder(t5_path, config)
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
|
||||
# Encode prompts
|
||||
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
||||
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
||||
if cfg_disabled:
|
||||
context_null = None
|
||||
mx.eval(context)
|
||||
else:
|
||||
context_null = encode_text(
|
||||
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
|
||||
)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
del t5_encoder
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
# I2V: encode image to latent space
|
||||
z_img = None
|
||||
i2v_mask = None
|
||||
i2v_mask_tokens = None
|
||||
y_i2v = None
|
||||
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
|
||||
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
|
||||
if is_i2v:
|
||||
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
|
||||
t_img = time.time()
|
||||
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
|
||||
if is_i2v_channel_concat:
|
||||
# I2V-14B: encode full video (first frame = image, rest = zeros)
|
||||
# and construct y tensor with mask + encoded latents
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(image).convert("RGB")
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize(
|
||||
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||
)
|
||||
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
img_arr = mx.array(
|
||||
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
|
||||
) # [H, W, 3]
|
||||
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
||||
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
||||
video = mx.concatenate(
|
||||
[
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
|
||||
mx.eval(z_video)
|
||||
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
|
||||
|
||||
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
||||
msk = mx.concatenate(
|
||||
[
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
|
||||
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
||||
mx.eval(y_i2v)
|
||||
|
||||
del vae_enc, img_arr, img_chw, video, z_video, msk
|
||||
else:
|
||||
# TI2V-5B: encode single image, blend with noise via mask
|
||||
img_tensor = preprocess_image(image, width, height)
|
||||
mx.eval(img_tensor)
|
||||
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
mx.eval(z_img)
|
||||
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
|
||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||
|
||||
del vae_enc, img_tensor
|
||||
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||
|
||||
# Load transformer models
|
||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||
if quantization:
|
||||
print(
|
||||
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
|
||||
)
|
||||
t2 = time.time()
|
||||
|
||||
# Merge per-model LoRAs with shared LoRAs
|
||||
_loras_low = (loras or []) + (loras_low or []) or None
|
||||
_loras_high = (loras or []) + (loras_high or []) or None
|
||||
_loras_single = loras
|
||||
|
||||
if is_dual:
|
||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||
low_noise_model = load_wan_model(
|
||||
low_noise_path, config, quantization, loras=_loras_low
|
||||
)
|
||||
high_noise_model = load_wan_model(
|
||||
high_noise_path, config, quantization, loras=_loras_high
|
||||
)
|
||||
else:
|
||||
single_model = load_wan_model(
|
||||
model_dir / "model.safetensors", config, quantization, loras=_loras_single
|
||||
)
|
||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||
|
||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||
# Each model has its own text_embedding weights, so dual models need separate embeddings
|
||||
if cfg_disabled:
|
||||
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context])
|
||||
context_emb_high = high_noise_model.embed_text([context])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cond_low = context_emb_low[0:1]
|
||||
context_cond_high = context_emb_high[0:1]
|
||||
else:
|
||||
context_emb = single_model.embed_text([context])
|
||||
mx.eval(context_emb)
|
||||
context_cond = context_emb[0:1]
|
||||
else:
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate(
|
||||
[context_emb_low[0:1], context_emb_low[1:2]], axis=0
|
||||
)
|
||||
context_cfg_high = mx.concatenate(
|
||||
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
|
||||
)
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
||||
|
||||
# Precompute cross-attention K/V caches (constant across all steps)
|
||||
if cfg_disabled:
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cond)
|
||||
mx.eval(cross_kv)
|
||||
else:
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cfg)
|
||||
mx.eval(cross_kv)
|
||||
|
||||
# Precompute RoPE frequencies (grid sizes are constant across all steps)
|
||||
f_grid = t_latent // patch_size[0]
|
||||
h_grid = h_latent // patch_size[1]
|
||||
w_grid = w_latent // patch_size[2]
|
||||
if cfg_disabled:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
|
||||
else:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
||||
if is_dual:
|
||||
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
|
||||
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||
else:
|
||||
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin)
|
||||
|
||||
# Setup scheduler
|
||||
_schedulers = {
|
||||
"euler": FlowMatchEulerScheduler,
|
||||
"dpm++": FlowDPMPP2MScheduler,
|
||||
"unipc": FlowUniPCScheduler,
|
||||
}
|
||||
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
|
||||
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
|
||||
sched.set_timesteps(steps, shift=shift)
|
||||
|
||||
# Generate initial noise
|
||||
noise = mx.random.normal(target_shape)
|
||||
|
||||
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
# Boundary for model switching (dual model only)
|
||||
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
|
||||
|
||||
# Diffusion loop
|
||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||
t3 = time.time()
|
||||
|
||||
# Compile model forward for faster denoising
|
||||
if not no_compile:
|
||||
models_to_compile = (
|
||||
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||
)
|
||||
for m in models_to_compile:
|
||||
m._compiled = mx.compile(m)
|
||||
|
||||
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
# Select model, cached K/V, and precomputed RoPE
|
||||
if is_dual:
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
kv = cross_kv_high
|
||||
rcs = rope_cos_sin_high
|
||||
else:
|
||||
model = low_noise_model
|
||||
kv = cross_kv_low
|
||||
rcs = rope_cos_sin_low
|
||||
else:
|
||||
model = single_model
|
||||
kv = cross_kv
|
||||
rcs = rope_cos_sin
|
||||
|
||||
# Use compiled forward when available (faster after first trace)
|
||||
_call = getattr(model, "_compiled", model)
|
||||
|
||||
if cfg_disabled:
|
||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = t_tokens # [1, L]
|
||||
else:
|
||||
t_batch = mx.array([timestep_val])
|
||||
|
||||
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
if is_dual:
|
||||
ctx = (
|
||||
context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
)
|
||||
else:
|
||||
ctx = context_cond
|
||||
preds = _call(
|
||||
[latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred = preds[0]
|
||||
del preds
|
||||
else:
|
||||
# CFG: batch cond + uncond into single B=2 forward pass
|
||||
if is_dual:
|
||||
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||
else:
|
||||
gs = (
|
||||
guide_scale
|
||||
if isinstance(guide_scale, (int, float))
|
||||
else guide_scale[0]
|
||||
)
|
||||
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
|
||||
else:
|
||||
t_batch = mx.array([timestep_val, timestep_val])
|
||||
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
ctx = (
|
||||
context_cfg
|
||||
if not is_dual
|
||||
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
|
||||
)
|
||||
preds = _call(
|
||||
[latents, latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
del noise_pred_cond, noise_pred_uncond, preds
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
# TI2V-5B: re-apply mask to keep first frame frozen
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred
|
||||
mx.eval(latents)
|
||||
|
||||
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
|
||||
|
||||
# Diagnostic: per-temporal-position latent statistics
|
||||
if debug_latents:
|
||||
lat_np = np.array(latents) # [C, T, H, W]
|
||||
n_t = lat_np.shape[1]
|
||||
print(
|
||||
f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
|
||||
)
|
||||
for t_pos in range(min(n_t, 8)):
|
||||
frame = lat_np[:, t_pos, :, :]
|
||||
print(
|
||||
f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
|
||||
)
|
||||
if n_t > 8:
|
||||
interior = lat_np[:, 4:, :, :]
|
||||
print(
|
||||
f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Free transformer models and text embeddings
|
||||
if is_dual:
|
||||
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
||||
if cfg_disabled:
|
||||
del context_cond_low, context_cond_high
|
||||
else:
|
||||
del context_cfg_low, context_cfg_high
|
||||
else:
|
||||
del single_model, cross_kv
|
||||
if cfg_disabled:
|
||||
del context_cond
|
||||
else:
|
||||
del context_cfg
|
||||
del model, kv, context
|
||||
if context_null is not None:
|
||||
del context_null
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Load VAE and decode
|
||||
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
|
||||
t4 = time.time()
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
vae = load_vae_decoder(vae_path, config)
|
||||
|
||||
is_wan22_vae = config.vae_z_dim == 48
|
||||
|
||||
# Temporal extend: prepend reflected latent frames to the VAE input so that
|
||||
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
|
||||
# This gives the first real frame a full temporal receptive field of real data.
|
||||
# Select tiling configuration
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig
|
||||
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
elif tiling == "auto":
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
elif tiling == "default":
|
||||
tiling_config = TilingConfig.default()
|
||||
elif tiling == "aggressive":
|
||||
tiling_config = TilingConfig.aggressive()
|
||||
elif tiling == "conservative":
|
||||
tiling_config = TilingConfig.conservative()
|
||||
elif tiling == "spatial":
|
||||
tiling_config = TilingConfig.spatial_only()
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(
|
||||
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
|
||||
)
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = (
|
||||
f"{tiling_config.spatial_config.tile_size_in_pixels}px"
|
||||
if tiling_config.spatial_config
|
||||
else "none"
|
||||
)
|
||||
temporal_info = (
|
||||
f"{tiling_config.temporal_config.tile_size_in_frames}f"
|
||||
if tiling_config.temporal_config
|
||||
else "none"
|
||||
)
|
||||
print(
|
||||
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
|
||||
)
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan_2.vae22 import denormalize_latents
|
||||
|
||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||
z = latents.transpose(1, 2, 3, 0)[None]
|
||||
z = denormalize_latents(z)
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(z, tiling_config)
|
||||
else:
|
||||
video = vae(z)
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [T', H', W', 3]
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(latents[None], tiling_config)
|
||||
else:
|
||||
video = vae.decode(latents[None])
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [3, T', H, W]
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
||||
|
||||
# Trim first N temporal chunks if requested (avoids first-frame artifacts)
|
||||
if trim_first_frames > 0:
|
||||
trim_pixels = trim_first_frames * 4
|
||||
video = video[trim_pixels:]
|
||||
print(
|
||||
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
|
||||
)
|
||||
|
||||
save_video(video, output_path, fps=config.sample_fps)
|
||||
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument(
|
||||
"--model-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to converted MLX model directory",
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-negative-prompt",
|
||||
action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width", type=int, default=1280, help="Video width (default: 1280)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=704,
|
||||
help="Video height (default: 704; 720p models use 704)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of diffusion steps (default: from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide-scale",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Guidance scale: single float or low,high pair",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Noise schedule shift (default: from config)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--output-path", type=str, default="output.mp4", help="Output video path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="unipc",
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-high",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-low",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiling",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=[
|
||||
"auto",
|
||||
"none",
|
||||
"default",
|
||||
"aggressive",
|
||||
"conservative",
|
||||
"spatial",
|
||||
"temporal",
|
||||
],
|
||||
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-compile",
|
||||
action="store_true",
|
||||
help="Disable mx.compile on models (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim-first-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
metavar="N",
|
||||
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-latents",
|
||||
action="store_true",
|
||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse guide scale
|
||||
guide_scale = None
|
||||
if args.guide_scale is not None:
|
||||
parts = [float(x) for x in args.guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
|
||||
neg_prompt = args.negative_prompt
|
||||
if args.no_negative_prompt:
|
||||
neg_prompt = ""
|
||||
|
||||
# Parse LoRA configs: convert [path, strength_str] → (path, float)
|
||||
def _parse_lora_args(lora_list):
|
||||
if not lora_list:
|
||||
return None
|
||||
return [(path, float(strength)) for path, strength in lora_list]
|
||||
|
||||
generate_video(
|
||||
model_dir=args.model_dir,
|
||||
prompt=args.prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
image=args.image,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
num_frames=args.num_frames,
|
||||
steps=args.steps,
|
||||
guide_scale=guide_scale,
|
||||
shift=args.shift,
|
||||
seed=args.seed,
|
||||
output_path=args.output_path,
|
||||
scheduler=args.scheduler,
|
||||
loras=_parse_lora_args(args.lora),
|
||||
loras_high=_parse_lora_args(args.lora_high),
|
||||
loras_low=_parse_lora_args(args.lora_low),
|
||||
tiling=args.tiling,
|
||||
no_compile=args.no_compile,
|
||||
trim_first_frames=args.trim_first_frames,
|
||||
debug_latents=args.debug_latents,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
60
mlx_video/models/wan_2/i2v_utils.py
Normal file
60
mlx_video/models/wan_2/i2v_utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Image-to-Video utility functions for Wan2.2."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
|
||||
"""Load, resize, center-crop, and normalize an image for I2V.
|
||||
|
||||
Args:
|
||||
image_path: Path to input image
|
||||
width: Target width
|
||||
height: Target height
|
||||
|
||||
Returns:
|
||||
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Resize so that the image covers the target size (LANCZOS)
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize(
|
||||
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||
)
|
||||
|
||||
# Center crop
|
||||
x1 = (img.width - width) // 2
|
||||
y1 = (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
|
||||
# To tensor: [H, W, 3] float32 in [-1, 1]
|
||||
arr = np.array(img, dtype=np.float32) / 255.0
|
||||
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
|
||||
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
|
||||
|
||||
|
||||
def build_i2v_mask(z_shape, patch_size):
|
||||
"""Build temporal mask for I2V: first frame = 0, rest = 1.
|
||||
|
||||
Args:
|
||||
z_shape: Latent shape (C, T, H, W) in channels-first
|
||||
patch_size: (pt, ph, pw) patch size
|
||||
|
||||
Returns:
|
||||
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
|
||||
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
|
||||
"""
|
||||
C, T, H, W = z_shape
|
||||
mask = mx.ones(z_shape)
|
||||
# Zero out the first temporal position
|
||||
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
|
||||
|
||||
# Token-level mask for per-token timesteps: subsample to patch grid
|
||||
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
|
||||
pt, ph, pw = patch_size
|
||||
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
|
||||
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
|
||||
return mask, mask_tokens
|
||||
41
mlx_video/models/wan_2/postprocess.py
Normal file
41
mlx_video/models/wan_2/postprocess.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||
"""Save video frames to MP4.
|
||||
|
||||
Args:
|
||||
frames: Video frames [T, H, W, 3] uint8
|
||||
output_path: Output file path
|
||||
fps: Frames per second
|
||||
"""
|
||||
try:
|
||||
import imageio
|
||||
|
||||
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
except ImportError:
|
||||
try:
|
||||
import cv2
|
||||
|
||||
h, w = frames.shape[1], frames.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
||||
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
||||
for frame in frames:
|
||||
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
writer.release()
|
||||
except (ImportError, Exception):
|
||||
# Last resort: save as individual PNGs
|
||||
from PIL import Image
|
||||
|
||||
out_dir = Path(output_path).parent / Path(output_path).stem
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, frame in enumerate(frames):
|
||||
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
||||
print(
|
||||
f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)"
|
||||
)
|
||||
176
mlx_video/models/wan_2/rope.py
Normal file
176
mlx_video/models/wan_2/rope.py
Normal file
@@ -0,0 +1,176 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
|
||||
"""Precompute RoPE frequency parameters as complex numbers.
|
||||
|
||||
Returns:
|
||||
Complex frequency tensor of shape [max_seq_len, dim // 2].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = (
|
||||
np.arange(max_seq_len, dtype=np.float64)[:, None]
|
||||
* (
|
||||
1.0
|
||||
/ np.power(
|
||||
theta,
|
||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||
)
|
||||
)[None, :]
|
||||
)
|
||||
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
|
||||
cos_freqs = np.cos(freqs).astype(np.float32)
|
||||
sin_freqs = np.sin(freqs).astype(np.float32)
|
||||
return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))
|
||||
|
||||
|
||||
def rope_apply(
|
||||
x: mx.array,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
precomputed_cos_sin: tuple | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply 3-way factorized RoPE to Q or K tensor.
|
||||
|
||||
Args:
|
||||
x: Shape [B, L, num_heads, head_dim]
|
||||
grid_sizes: List of (F, H, W) tuples per batch element
|
||||
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
|
||||
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
|
||||
"""
|
||||
b, s, n, d = x.shape
|
||||
half_d = d // 2
|
||||
|
||||
if precomputed_cos_sin is not None:
|
||||
cos_f, sin_f = precomputed_cos_sin
|
||||
# Check if all batch elements have the same grid (common for CFG B=2)
|
||||
f0, h0, w0 = grid_sizes[0]
|
||||
seq_len = f0 * h0 * w0
|
||||
all_same_grid = (
|
||||
all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
|
||||
)
|
||||
|
||||
if all_same_grid:
|
||||
# Vectorized path: apply RoPE to all batch elements at once
|
||||
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
|
||||
x_real = x_seq[..., 0]
|
||||
x_imag = x_seq[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
|
||||
b, seq_len, n, d
|
||||
)
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
||||
return x_rotated
|
||||
else:
|
||||
# Per-element path for mixed grid sizes
|
||||
outputs = []
|
||||
for i in range(b):
|
||||
f, h, w = grid_sizes[i]
|
||||
sl = f * h * w
|
||||
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
|
||||
x_real = x_i[..., 0]
|
||||
x_imag = x_i[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
|
||||
if sl < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
|
||||
outputs.append(x_rotated)
|
||||
return mx.stack(outputs)
|
||||
|
||||
# Cast freqs to input dtype to prevent float32 promotion cascade
|
||||
if freqs.dtype != x.dtype:
|
||||
freqs = freqs.astype(x.dtype)
|
||||
|
||||
# Split frequency dimensions: temporal gets more capacity
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
|
||||
# Split freqs along dim axis
|
||||
freqs_t = freqs[:, :d_t] # [1024, d_t, 2]
|
||||
freqs_h = freqs[:, d_t : d_t + d_h] # [1024, d_h, 2]
|
||||
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] # [1024, d_w, 2]
|
||||
|
||||
outputs = []
|
||||
for i in range(b):
|
||||
f, h, w = grid_sizes[i]
|
||||
seq_len = f * h * w
|
||||
|
||||
# Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
|
||||
x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)
|
||||
|
||||
# Build per-position frequencies by expanding along grid dims
|
||||
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
|
||||
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
|
||||
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
|
||||
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||
|
||||
# Concatenate: [f*h*w, half_d, 2]
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
|
||||
# Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
|
||||
cos_f = freqs_i[..., 0] # [seq_len, 1, half_d]
|
||||
sin_f = freqs_i[..., 1] # [seq_len, 1, half_d]
|
||||
|
||||
x_real = x_i[..., 0] # [seq_len, n, half_d]
|
||||
x_imag = x_i[..., 1] # [seq_len, n, half_d]
|
||||
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
|
||||
# Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)
|
||||
|
||||
# Handle padding: keep non-rotated tokens after seq_len
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)
|
||||
|
||||
outputs.append(x_rotated)
|
||||
|
||||
return mx.stack(outputs)
|
||||
|
||||
|
||||
def rope_precompute_cos_sin(
|
||||
grid_sizes: list, freqs: mx.array, dtype: type = mx.float32
|
||||
) -> tuple:
|
||||
"""Precompute cos/sin frequency tensors for constant grid sizes.
|
||||
|
||||
Call once before the diffusion loop. Pass result as precomputed_cos_sin
|
||||
to rope_apply to skip per-step broadcast/concat.
|
||||
|
||||
Args:
|
||||
grid_sizes: List of (F, H, W) tuples (must be same for all batch elements)
|
||||
freqs: Precomputed frequencies [1024, d//2, 2]
|
||||
dtype: Target dtype for the output tensors
|
||||
|
||||
Returns:
|
||||
(cos_f, sin_f) each [seq_len, 1, half_d]
|
||||
"""
|
||||
if freqs.dtype != dtype:
|
||||
freqs = freqs.astype(dtype)
|
||||
|
||||
f, h, w = grid_sizes[0]
|
||||
seq_len = f * h * w
|
||||
half_d = freqs.shape[1]
|
||||
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
|
||||
freqs_t = freqs[:, :d_t]
|
||||
freqs_h = freqs[:, d_t : d_t + d_h]
|
||||
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]
|
||||
|
||||
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
return freqs_i[..., 0], freqs_i[..., 1]
|
||||
447
mlx_video/models/wan_2/scheduler.py
Normal file
447
mlx_video/models/wan_2/scheduler.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Flow matching schedulers for Wan2.2 inference.
|
||||
|
||||
Provides Euler, DPM++2M, and UniPC solvers for flow matching diffusion.
|
||||
Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps
|
||||
for the same quality as Euler.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _compute_sigmas(
|
||||
num_steps: int, shift: float = 1.0, num_train_timesteps: int = 1000
|
||||
) -> np.ndarray:
|
||||
"""Compute shifted sigma schedule matching official Wan2.2 scheduler.
|
||||
|
||||
The reference creates FlowUniPCMultistepScheduler with shift=1 (identity)
|
||||
in the constructor, deriving sigma_max/sigma_min from the unshifted
|
||||
training schedule. Then set_timesteps() builds a linspace between those
|
||||
unshifted bounds and applies the actual shift once.
|
||||
|
||||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||||
"""
|
||||
# sigma bounds from unshifted training schedule (constructor uses shift=1)
|
||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
|
||||
# Interpolate, then apply shift once (matching set_timesteps)
|
||||
sigmas = np.linspace(sigma_max, sigma_min, num_steps + 1)[:-1]
|
||||
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||||
|
||||
return np.append(sigmas, 0.0).astype(np.float32)
|
||||
|
||||
|
||||
class FlowMatchEulerScheduler:
|
||||
"""1st-order Euler scheduler for flow matching diffusion."""
|
||||
|
||||
def __init__(self, num_train_timesteps: int = 1000):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
# Integer timesteps to match reference (model trained with int timesteps)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
# Store as Python floats to avoid .item() sync in step()
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||
dt = (
|
||||
self._sigmas_float[self._step_index + 1]
|
||||
- self._sigmas_float[self._step_index]
|
||||
)
|
||||
x_next = sample + dt * model_output
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
|
||||
|
||||
class FlowDPMPP2MScheduler:
|
||||
"""DPM-Solver++(2M) for flow matching diffusion.
|
||||
|
||||
2nd-order multistep solver that reuses the previous step's model output
|
||||
for a correction term. Falls back to 1st order on the first and
|
||||
(optionally) last step. Reference: Wan2.2 fm_solvers.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
lower_order_final: bool = True,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.lower_order_final = lower_order_final
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
# Store sigmas as Python floats for scalar math
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
self._num_steps = num_steps
|
||||
self._prev_x0 = None # previous x0 prediction for 2nd-order correction
|
||||
|
||||
@staticmethod
|
||||
def _lambda(sigma: float) -> float:
|
||||
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
|
||||
|
||||
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
|
||||
matching torch.log behavior in the official code.
|
||||
"""
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log((1.0 - sigma) / sigma)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""DPM++(2M) step for flow matching.
|
||||
|
||||
Converts velocity prediction to x0, then applies 1st or 2nd order
|
||||
update depending on available history.
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_cur = s[i]
|
||||
sigma_next = s[i + 1]
|
||||
|
||||
# Convert velocity -> x0 prediction: x0 = sample - sigma * v
|
||||
x0 = sample - sigma_cur * model_output
|
||||
|
||||
# Decide order: 1st for first step, last step (if lower_order_final
|
||||
# and few steps), otherwise 2nd
|
||||
use_first_order = self._prev_x0 is None or (
|
||||
self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
|
||||
)
|
||||
|
||||
if use_first_order or sigma_next == 0.0:
|
||||
# 1st order DPM++ (equivalent to DDIM):
|
||||
# x_next = (σ_next/σ_cur)*x - (α_next*(exp(-h)-1))*x0
|
||||
if sigma_next == 0.0:
|
||||
x_next = x0
|
||||
else:
|
||||
lambda_cur = self._lambda(sigma_cur)
|
||||
lambda_next = self._lambda(sigma_next)
|
||||
h = lambda_next - lambda_cur
|
||||
alpha_next = 1.0 - sigma_next
|
||||
coeff_x = sigma_next / sigma_cur
|
||||
coeff_x0 = alpha_next * math.expm1(-h)
|
||||
x_next = coeff_x * sample - coeff_x0 * x0
|
||||
else:
|
||||
# 2nd order DPM++(2M) with midpoint correction
|
||||
sigma_prev = s[i - 1]
|
||||
lambda_prev = self._lambda(sigma_prev)
|
||||
lambda_cur = self._lambda(sigma_cur)
|
||||
lambda_next = self._lambda(sigma_next)
|
||||
|
||||
h = lambda_next - lambda_cur
|
||||
h_0 = lambda_cur - lambda_prev
|
||||
r0 = h_0 / h
|
||||
|
||||
# D0 = current x0, D1 = correction from previous x0
|
||||
D0 = x0
|
||||
D1 = (1.0 / r0) * (x0 - self._prev_x0)
|
||||
|
||||
alpha_next = 1.0 - sigma_next
|
||||
exp_neg_h_m1 = math.expm1(-h) # exp(-h) - 1
|
||||
|
||||
x_next = (
|
||||
(sigma_next / sigma_cur) * sample
|
||||
- (alpha_next * exp_neg_h_m1) * D0
|
||||
- 0.5 * (alpha_next * exp_neg_h_m1) * D1
|
||||
)
|
||||
|
||||
self._prev_x0 = x0
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
self._prev_x0 = None
|
||||
|
||||
|
||||
class FlowUniPCScheduler:
|
||||
"""UniPC (Unified Predictor-Corrector) for flow matching diffusion.
|
||||
|
||||
Multi-step predictor-corrector solver with configurable order.
|
||||
The corrector refines each step using the model output that was already
|
||||
computed, costing no extra model evaluations. Official Wan2.2 default.
|
||||
Reference: Wan2.2 fm_solvers_unipc.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
solver_order: int = 2,
|
||||
lower_order_final: bool = True,
|
||||
disable_corrector: list | None = None,
|
||||
use_corrector: bool = True,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.solver_order = solver_order
|
||||
self.lower_order_final = lower_order_final
|
||||
self._use_corrector = use_corrector
|
||||
self.disable_corrector = set(disable_corrector or [])
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(
|
||||
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
|
||||
)
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
self._num_steps = num_steps
|
||||
self._lower_order_nums = 0
|
||||
# Model output (x0) history for multi-step, stored newest-last
|
||||
self._model_outputs = [None] * self.solver_order
|
||||
self._last_sample = None # sample before prediction (for corrector)
|
||||
self._this_order = 1
|
||||
|
||||
@staticmethod
|
||||
def _lambda(sigma: float) -> float:
|
||||
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
|
||||
|
||||
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
|
||||
matching torch.log behavior in the official code.
|
||||
"""
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log((1.0 - sigma) / sigma)
|
||||
|
||||
def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array:
|
||||
"""Convert velocity prediction to x0: x0 = sample - sigma * v."""
|
||||
sigma = self._sigmas_float[self._step_index]
|
||||
return sample - sigma * velocity
|
||||
|
||||
def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array:
|
||||
"""UniP predictor with B(h)=expm1(-h) basis (bh2 variant).
|
||||
|
||||
Matches official multistep_uni_p_bh_update: computes rhos_p via
|
||||
linalg.solve for order >= 3; order <= 2 uses analytic rhos_p=[0.5].
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_s0 = s[i]
|
||||
sigma_t = s[i + 1]
|
||||
|
||||
if sigma_t == 0.0:
|
||||
return x0
|
||||
|
||||
lambda_s0 = self._lambda(sigma_s0)
|
||||
lambda_t = self._lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h # negated for predict_x0
|
||||
|
||||
alpha_t = 1.0 - sigma_t
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
|
||||
m0 = self._model_outputs[-1]
|
||||
# Base prediction
|
||||
x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0
|
||||
|
||||
if order >= 2 and m0 is not None:
|
||||
rks = []
|
||||
D1s = []
|
||||
for k in range(1, order):
|
||||
si_idx = i - k
|
||||
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
|
||||
break
|
||||
mk = self._model_outputs[-(k + 1)]
|
||||
sigma_sk = s[si_idx]
|
||||
lambda_sk = self._lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
if math.isinf(rk):
|
||||
break
|
||||
rks.append(rk)
|
||||
D1s.append((mk - m0) / rk)
|
||||
|
||||
if D1s:
|
||||
effective_order = len(D1s) + 1
|
||||
if effective_order <= 2:
|
||||
# Analytic solution for order 2
|
||||
rhos_p = [0.5]
|
||||
else:
|
||||
rks_arr = np.array(rks, dtype=np.float64)
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows = []
|
||||
b_vals = []
|
||||
for j in range(1, effective_order):
|
||||
R_rows.append(rks_arr ** (j - 1))
|
||||
b_vals.append(float(h_phi_k * factorial_i / B_h))
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_p = np.linalg.solve(R, b).tolist()
|
||||
|
||||
pred_res = sum(r * d for r, d in zip(rhos_p, D1s))
|
||||
x_t = x_t - (alpha_t * B_h) * pred_res
|
||||
|
||||
return x_t
|
||||
|
||||
def _uni_c_bh2(
|
||||
self,
|
||||
model_x0: mx.array,
|
||||
last_sample: mx.array,
|
||||
this_sample: mx.array,
|
||||
order: int,
|
||||
) -> mx.array:
|
||||
"""UniC corrector with B(h)=expm1(-h) basis (bh2 variant).
|
||||
|
||||
Matches official multistep_uni_c_bh_update: computes rhos_c via
|
||||
linalg.solve for order >= 2 (not hardcoded 0.5).
|
||||
"""
|
||||
i = self._step_index
|
||||
s = self._sigmas_float
|
||||
|
||||
sigma_s0 = s[i - 1]
|
||||
sigma_t = s[i]
|
||||
|
||||
if sigma_t == 0.0:
|
||||
return this_sample
|
||||
|
||||
lambda_s0 = self._lambda(sigma_s0)
|
||||
lambda_t = self._lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h # negated for predict_x0
|
||||
|
||||
alpha_t = 1.0 - sigma_t
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
|
||||
m0 = self._model_outputs[-1]
|
||||
# Re-derive base from last_sample
|
||||
x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0
|
||||
|
||||
D1_t = model_x0 - m0
|
||||
|
||||
# Gather rks and D1s from history
|
||||
rks = []
|
||||
D1s = []
|
||||
for k in range(1, order):
|
||||
si_idx = i - (k + 1)
|
||||
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
|
||||
break
|
||||
mk = self._model_outputs[-(k + 1)]
|
||||
sigma_sk = s[si_idx]
|
||||
lambda_sk = self._lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
if math.isinf(rk):
|
||||
break # History references sigma=1.0 boundary; reduce order
|
||||
rks.append(rk)
|
||||
D1s.append((mk - m0) / rk)
|
||||
rks.append(1.0)
|
||||
effective_order = len(rks) # = len(D1s) + 1
|
||||
|
||||
# Compute rhos_c coefficients
|
||||
if effective_order == 1:
|
||||
rhos_c = [0.5]
|
||||
else:
|
||||
rks_arr = np.array(rks, dtype=np.float64)
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows = []
|
||||
b_vals = []
|
||||
for j in range(1, effective_order + 1):
|
||||
R_rows.append(rks_arr ** (j - 1))
|
||||
b_vals.append(float(h_phi_k * factorial_i / B_h))
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_c = np.linalg.solve(R, b).tolist()
|
||||
|
||||
# Apply correction
|
||||
corr_res = mx.zeros_like(D1_t)
|
||||
for k_idx, d1 in enumerate(D1s):
|
||||
corr_res = corr_res + rhos_c[k_idx] * d1
|
||||
x_t = x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t)
|
||||
return x_t
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""UniPC step: correct current, then predict next."""
|
||||
i = self._step_index
|
||||
|
||||
# Convert velocity -> x0
|
||||
x0 = self._convert_output(model_output, sample)
|
||||
|
||||
# 1. Corrector: refine current sample if we have history
|
||||
use_corrector = (
|
||||
self._use_corrector
|
||||
and i > 0
|
||||
and (i - 1) not in self.disable_corrector
|
||||
and self._last_sample is not None
|
||||
)
|
||||
if use_corrector:
|
||||
sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order)
|
||||
|
||||
# 2. Shift model output history
|
||||
for k in range(self.solver_order - 1):
|
||||
self._model_outputs[k] = self._model_outputs[k + 1]
|
||||
self._model_outputs[-1] = x0
|
||||
|
||||
# 3. Determine prediction order
|
||||
if self.lower_order_final:
|
||||
this_order = min(self.solver_order, self._num_steps - i)
|
||||
else:
|
||||
this_order = self.solver_order
|
||||
self._this_order = min(this_order, self._lower_order_nums + 1)
|
||||
|
||||
# 4. Predict next sample
|
||||
self._last_sample = sample
|
||||
x_next = self._uni_p_bh2(x0, sample, self._this_order)
|
||||
|
||||
if self._lower_order_nums < self.solver_order:
|
||||
self._lower_order_nums += 1
|
||||
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
self._step_index = 0
|
||||
self._lower_order_nums = 0
|
||||
self._model_outputs = [None] * self.solver_order
|
||||
self._last_sample = None
|
||||
self._this_order = 1
|
||||
239
mlx_video/models/wan_2/text_encoder.py
Normal file
239
mlx_video/models/wan_2/text_encoder.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
"""RMS-based layer normalization (T5 style)."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
"""T5-style relative position bias with bucketing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_buckets: int,
|
||||
num_heads: int,
|
||||
bidirectional: bool = True,
|
||||
max_dist: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
|
||||
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
rel_pos_f = rel_pos.astype(mx.float32)
|
||||
rel_pos_large = max_exact + (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
rel_pos_large = mx.minimum(
|
||||
rel_pos_large,
|
||||
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
||||
)
|
||||
|
||||
rel_buckets = rel_buckets + mx.where(
|
||||
is_small, rel_pos.astype(mx.int32), rel_pos_large
|
||||
)
|
||||
return rel_buckets
|
||||
|
||||
def __call__(self, lq: int, lk: int) -> mx.array:
|
||||
positions_k = mx.arange(lk)[None, :] # [1, lk]
|
||||
positions_q = mx.arange(lq)[:, None] # [lq, 1]
|
||||
rel_pos = positions_k - positions_q # [lq, lk]
|
||||
|
||||
buckets = self._relative_position_bucket(rel_pos)
|
||||
embeds = self.embedding(buckets) # [lq, lk, num_heads]
|
||||
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
|
||||
return embeds
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
"""T5-style multi-head attention (no scaling)."""
|
||||
|
||||
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
assert dim_attn % num_heads == 0
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array | None = None,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
context = x if context is None else context
|
||||
b, n, c = x.shape[0], self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
|
||||
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
|
||||
v = self.v(context).reshape(b, -1, n, c)
|
||||
|
||||
# T5 uses no scaling — compute attention manually with float32 softmax
|
||||
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
# Using SDPA with bfloat16 inputs causes precision loss in softmax
|
||||
# since unscaled logits can be very large (no 1/sqrt(d) division).
|
||||
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# QK^T (no scaling) — compute in float32 for precision
|
||||
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
|
||||
|
||||
# Add position bias
|
||||
if pos_bias is not None:
|
||||
attn = attn + pos_bias.astype(mx.float32)
|
||||
|
||||
# Apply attention mask (use dtype min like official, not -1e9)
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
|
||||
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
|
||||
attn = attn + additive_mask
|
||||
|
||||
# Softmax in float32 (matches official), then cast back
|
||||
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
|
||||
|
||||
# Attention @ V
|
||||
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
"""Gated feed-forward: gate(x) * fc1(x) -> fc2."""
|
||||
|
||||
def __init__(self, dim: int, dim_ffn: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = nn.GELU(approx="tanh")
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
|
||||
|
||||
|
||||
class T5SelfAttentionBlock(nn.Module):
|
||||
"""T5 encoder block: self-attention + FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_attn: int,
|
||||
dim_ffn: int,
|
||||
num_heads: int,
|
||||
num_buckets: int,
|
||||
shared_pos: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_pos = shared_pos
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn)
|
||||
self.pos_embedding = (
|
||||
None
|
||||
if shared_pos
|
||||
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
|
||||
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
|
||||
x = x + self.ffn(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
"""T5 Encoder (UMT5-XXL configuration)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 256384,
|
||||
dim: int = 4096,
|
||||
dim_attn: int = 4096,
|
||||
dim_ffn: int = 10240,
|
||||
num_heads: int = 64,
|
||||
num_layers: int = 24,
|
||||
num_buckets: int = 32,
|
||||
shared_pos: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.pos_embedding = (
|
||||
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
if shared_pos
|
||||
else None
|
||||
)
|
||||
self.blocks = [
|
||||
T5SelfAttentionBlock(
|
||||
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
ids: Token IDs [B, L]
|
||||
mask: Attention mask [B, L]
|
||||
|
||||
Returns:
|
||||
Hidden states [B, L, dim]
|
||||
"""
|
||||
x = self.token_embedding(ids)
|
||||
|
||||
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask=mask, pos_bias=e)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
338
mlx_video/models/wan_2/tiling.py
Normal file
338
mlx_video/models/wan_2/tiling.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Wan-specific tiled VAE decoding.
|
||||
|
||||
Re-exports all tiling utilities from the LTX VAE tiling module and provides
|
||||
a Wan-specific ``decode_with_tiling`` that adds ``causal_temporal`` support
|
||||
for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale
|
||||
output frames rather than LTX's 1+(T-1)*scale mapping).
|
||||
|
||||
# TODO: This function can be refactored to consolidate with
|
||||
# mlx_video.models.ltx_2.video_vae.tiling.decode_with_tiling once the
|
||||
# causal_temporal generalisation is accepted upstream.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
TilingConfig,
|
||||
map_spatial_slice,
|
||||
map_temporal_slice,
|
||||
split_in_spatial,
|
||||
split_in_temporal,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SpatialTilingConfig",
|
||||
"TemporalTilingConfig",
|
||||
"TilingConfig",
|
||||
"decode_with_tiling",
|
||||
"map_spatial_slice",
|
||||
"map_temporal_slice",
|
||||
"split_in_spatial",
|
||||
"split_in_temporal",
|
||||
]
|
||||
|
||||
|
||||
def decode_with_tiling(
|
||||
decoder_fn,
|
||||
latents: mx.array,
|
||||
tiling_config: TilingConfig,
|
||||
spatial_scale: int = 32,
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
causal_temporal: bool = True,
|
||||
timestep: Optional[mx.array] = None,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
Args:
|
||||
decoder_fn: Decoder function to call for each tile.
|
||||
latents: Input latents of shape (B, C, F, H, W).
|
||||
tiling_config: Tiling configuration.
|
||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
causal_temporal: Whether the decoder uses causal temporal mapping where
|
||||
T input frames produce 1+(T-1)*scale output frames. When False, uses
|
||||
simple scaling where T frames produce T*scale output frames.
|
||||
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
|
||||
timestep: Optional timestep for conditioning.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
start_idx: Starting frame index in the full video.
|
||||
|
||||
Returns:
|
||||
Decoded video.
|
||||
"""
|
||||
import gc
|
||||
|
||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||
|
||||
# Compute output shape
|
||||
out_f = (
|
||||
(1 + (f_latent - 1) * temporal_scale)
|
||||
if causal_temporal
|
||||
else (f_latent * temporal_scale)
|
||||
)
|
||||
out_h = h_latent * spatial_scale
|
||||
out_w = w_latent * spatial_scale
|
||||
|
||||
# Get tile size and overlap in latent space
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_cfg = tiling_config.spatial_config
|
||||
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
|
||||
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
|
||||
else:
|
||||
spatial_tile_size = max(h_latent, w_latent)
|
||||
spatial_overlap = 0
|
||||
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_cfg = tiling_config.temporal_config
|
||||
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
|
||||
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
|
||||
else:
|
||||
temporal_tile_size = f_latent
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
if causal_temporal:
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
else:
|
||||
temporal_intervals = split_in_spatial(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
num_t_tiles = len(temporal_intervals.starts)
|
||||
num_h_tiles = len(height_intervals.starts)
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles # noqa: F841
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
|
||||
mx.eval(output, weights)
|
||||
|
||||
tile_idx = 0
|
||||
for t_idx in range(num_t_tiles):
|
||||
t_start = temporal_intervals.starts[t_idx]
|
||||
t_end = temporal_intervals.ends[t_idx]
|
||||
t_left = temporal_intervals.left_ramps[t_idx]
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
if causal_temporal:
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
else:
|
||||
out_t_slice, t_mask = map_spatial_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
h_end = height_intervals.ends[h_idx]
|
||||
h_left = height_intervals.left_ramps[h_idx]
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
w_end = width_intervals.ends[w_idx]
|
||||
w_left = width_intervals.left_ramps[w_idx]
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
del tile_latents
|
||||
|
||||
# Get actual decoded dimensions
|
||||
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
|
||||
expected_t = out_t_slice.stop - out_t_slice.start
|
||||
expected_h = out_h_slice.stop - out_h_slice.start
|
||||
expected_w = out_w_slice.stop - out_w_slice.start
|
||||
|
||||
# Handle potential size mismatches (use minimum)
|
||||
actual_t = min(decoded_t, expected_t)
|
||||
actual_h = min(decoded_h, expected_h)
|
||||
actual_w = min(decoded_w, expected_w)
|
||||
|
||||
# Build blend mask
|
||||
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
|
||||
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
|
||||
# Compute output coordinates
|
||||
t_out_start = out_t_slice.start
|
||||
t_out_end = t_out_start + actual_t
|
||||
h_out_start = out_h_slice.start
|
||||
h_out_end = h_out_start + actual_h
|
||||
w_out_start = out_w_slice.start
|
||||
w_out_end = w_out_start + actual_w
|
||||
|
||||
# Weighted accumulation
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
mx.eval(output, weights)
|
||||
|
||||
# Clean up tile-specific arrays
|
||||
del tile_output_slice, weighted_tile, blend_mask
|
||||
del t_mask_slice, h_mask_slice, w_mask_slice
|
||||
|
||||
tile_idx += 1
|
||||
|
||||
# Periodic garbage collection and cache clearing
|
||||
if tile_idx % 4 == 0:
|
||||
gc.collect()
|
||||
try:
|
||||
mx.clear_cache()
|
||||
except Exception:
|
||||
pass # May not be available on all platforms
|
||||
|
||||
# After completing all spatial tiles for this temporal tile,
|
||||
# check if any frames are now finalized (no future tiles will contribute)
|
||||
if on_frames_ready is not None and num_t_tiles > 1:
|
||||
# Determine the finalized frame boundary
|
||||
# Frames before the start of the next tile's output region are finalized
|
||||
if t_idx < num_t_tiles - 1:
|
||||
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||
# Map to output frame index (first frame of next tile's contribution)
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
elif causal_temporal:
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
else:
|
||||
next_tile_start_out = next_tile_start_latent * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
if next_tile_start_out > emitted:
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
del finalized_output, finalized_weights
|
||||
gc.collect()
|
||||
|
||||
# Normalize by weights
|
||||
weights = mx.maximum(weights, 1e-8)
|
||||
output = output / weights
|
||||
mx.eval(output)
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
104
mlx_video/models/wan_2/transformer.py
Normal file
104
mlx_video/models/wan_2/transformer.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Self-attention
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||
|
||||
# Cross-attention (with optional norm on context)
|
||||
self.norm3 = (
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
|
||||
)
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
||||
|
||||
# Feed-forward
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
e: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
# Modulation: compute in float32 for precision, matching the reference
|
||||
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
|
||||
# By keeping modulation in float32, type promotion ensures the residual
|
||||
# stream stays float32 throughout all 30 layers (gate * output + x → float32).
|
||||
mod = self.modulation + e # float32
|
||||
e0, e1, e2, e3, e4, e5 = (
|
||||
mod[:, :, 0, :], # shift for self-attn
|
||||
mod[:, :, 1, :], # scale for self-attn
|
||||
mod[:, :, 2, :], # gate for self-attn
|
||||
mod[:, :, 3, :], # shift for ffn
|
||||
mod[:, :, 4, :], # scale for ffn
|
||||
mod[:, :, 5, :], # gate for ffn
|
||||
)
|
||||
|
||||
# Self-attention with modulation (hidden state stays in w_dtype)
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(
|
||||
x_mod,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y * e5
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanFFN(nn.Module):
|
||||
"""Gated feed-forward network with GELU(tanh) activation."""
|
||||
|
||||
def __init__(self, dim: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, ffn_dim)
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(_linear_dtype(self.fc1))
|
||||
return self.fc2(self.act(self.fc1(x_w)))
|
||||
191
mlx_video/models/wan_2/utils.py
Normal file
191
mlx_video/models/wan_2/utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Wan model loading utilities."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def load_wan_model(
|
||||
model_path: Path,
|
||||
config,
|
||||
quantization: dict | None = None,
|
||||
loras: list | None = None,
|
||||
):
|
||||
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||
|
||||
Args:
|
||||
model_path: Path to model safetensors file
|
||||
config: WanModelConfig
|
||||
quantization: Optional dict with 'bits' and 'group_size' keys.
|
||||
If provided, creates QuantizedLinear stubs before loading.
|
||||
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||
"""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
|
||||
if quantization:
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=quantization["group_size"],
|
||||
bits=quantization["bits"],
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
|
||||
if loras:
|
||||
if quantization:
|
||||
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
|
||||
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
|
||||
from mlx_video.models.wan_2.convert import _load_lora_configs
|
||||
from mlx_video.lora import apply_loras_to_model
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
module_to_loras = _load_lora_configs(loras)
|
||||
apply_loras_to_model(model, module_to_loras)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
else:
|
||||
# Weight merging: fold LoRA into bf16 weights before loading
|
||||
from mlx_video.models.wan_2.convert import load_and_apply_loras
|
||||
|
||||
weights = load_and_apply_loras(dict(weights), loras)
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
|
||||
def load_t5_encoder(model_path: Path, config):
|
||||
"""Load T5 text encoder.
|
||||
|
||||
Weights are upcast to float32 for maximum precision — the T5 encoder
|
||||
only runs once per generation, so performance impact is negligible.
|
||||
This matches the official which computes softmax in float32 explicitly.
|
||||
"""
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=config.t5_vocab_size,
|
||||
dim=config.t5_dim,
|
||||
dim_attn=config.t5_dim_attn,
|
||||
dim_ffn=config.t5_dim_ffn,
|
||||
num_heads=config.t5_num_heads,
|
||||
num_layers=config.t5_num_layers,
|
||||
num_buckets=config.t5_num_buckets,
|
||||
shared_pos=False,
|
||||
)
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
encoder.load_weights(list(weights.items()))
|
||||
mx.eval(encoder.parameters())
|
||||
return encoder
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: Path, config=None):
|
||||
"""Load VAE decoder (skips encoder weights with strict=False).
|
||||
|
||||
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
|
||||
For Wan2.1 (vae_z_dim=16), uses WanVAE.
|
||||
"""
|
||||
is_wan22 = config is not None and config.vae_z_dim == 48
|
||||
|
||||
if is_wan22:
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
|
||||
|
||||
vae = Wan22VAEDecoder(z_dim=48)
|
||||
else:
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def load_vae_encoder(model_path: Path, config=None):
|
||||
"""Load VAE encoder for I2V image encoding.
|
||||
|
||||
For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder.
|
||||
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
|
||||
"""
|
||||
if config is not None and config.vae_z_dim == 16:
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
else:
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
|
||||
|
||||
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
"""Clean text matching official Wan2.2 tokenizer preprocessing.
|
||||
|
||||
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
|
||||
double HTML unescape, and whitespace normalization. Critical for
|
||||
correct tokenization of the Chinese negative prompt.
|
||||
"""
|
||||
import html
|
||||
import re
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
|
||||
text = ftfy.fix_text(text)
|
||||
except ImportError:
|
||||
pass
|
||||
text = html.unescape(html.unescape(text))
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def encode_text(
|
||||
encoder,
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
text_len: int = 512,
|
||||
) -> mx.array:
|
||||
"""Encode text prompt using T5 encoder.
|
||||
|
||||
Args:
|
||||
encoder: T5Encoder model
|
||||
tokenizer: HuggingFace tokenizer
|
||||
prompt: Text prompt
|
||||
text_len: Maximum text length
|
||||
|
||||
Returns:
|
||||
Text embeddings [L, dim]
|
||||
"""
|
||||
prompt = _clean_text(prompt)
|
||||
tokens = tokenizer(
|
||||
prompt,
|
||||
max_length=text_len,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
ids = mx.array(tokens["input_ids"])
|
||||
mask = mx.array(tokens["attention_mask"])
|
||||
|
||||
embeddings = encoder(ids, mask=mask)
|
||||
|
||||
# Return only non-padding tokens
|
||||
seq_len = int(mask.sum().item())
|
||||
return embeddings[0, :seq_len]
|
||||
629
mlx_video/models/wan_2/vae.py
Normal file
629
mlx_video/models/wan_2/vae.py
Normal file
@@ -0,0 +1,629 @@
|
||||
"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8).
|
||||
|
||||
Module structure mirrors original PyTorch checkpoint key hierarchy
|
||||
so weights load directly without key sanitization.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization statistics for z_dim=16
|
||||
VAE_MEAN = [
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
0.1075,
|
||||
-0.1745,
|
||||
0.9653,
|
||||
-0.1517,
|
||||
1.5508,
|
||||
0.4134,
|
||||
-0.0715,
|
||||
0.5517,
|
||||
-0.3632,
|
||||
-0.1922,
|
||||
-0.9497,
|
||||
0.2503,
|
||||
-0.2921,
|
||||
]
|
||||
VAE_STD = [
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
2.6558,
|
||||
1.2196,
|
||||
1.7708,
|
||||
2.6052,
|
||||
2.0743,
|
||||
3.2687,
|
||||
2.1526,
|
||||
2.8652,
|
||||
1.5579,
|
||||
1.6382,
|
||||
1.1253,
|
||||
2.8251,
|
||||
1.9160,
|
||||
]
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""3D convolution with causal temporal padding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple,
|
||||
stride: int | tuple = 1,
|
||||
padding: int | tuple = 0,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
# Causal padding: match reference formula dilation*(k-1) + (1-stride)
|
||||
# With dilation=1: k-stride (pads left only, no future context)
|
||||
self._causal_pad_t = kernel_size[0] - stride[0]
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# MLX Conv3d: weight shape [O, D, H, W, I]
|
||||
self.weight = mx.zeros(
|
||||
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
|
||||
)
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
||||
"""x: [B, C, T, H, W] (channel-first)"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
causal_pad = self._causal_pad_t
|
||||
if cache_x is not None and causal_pad > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=2)
|
||||
causal_pad = max(0, causal_pad - cache_x.shape[2])
|
||||
|
||||
if causal_pad > 0:
|
||||
pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype)
|
||||
x = mx.concatenate([pad_t, x], axis=2)
|
||||
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(
|
||||
x,
|
||||
[
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(0, 0),
|
||||
(self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w),
|
||||
],
|
||||
)
|
||||
|
||||
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
|
||||
out = self._conv3d(x)
|
||||
return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W']
|
||||
|
||||
def _conv3d(self, x: mx.array) -> mx.array:
|
||||
"""3D conv via sliding window + 2D conv per time step.
|
||||
x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
|
||||
"""
|
||||
b, t, h, w, c_in = x.shape
|
||||
kt, kh, kw = self.kernel_size
|
||||
st, sh, sw = self.stride
|
||||
t_out = (t - kt) // st + 1
|
||||
|
||||
# Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
|
||||
w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
|
||||
self.weight.shape[0], kh, kw, kt * c_in
|
||||
)
|
||||
outputs = []
|
||||
for t_i in range(t_out):
|
||||
t_start = t_i * st
|
||||
window = x[:, t_start : t_start + kt]
|
||||
window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
|
||||
out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
|
||||
outputs.append(out_2d)
|
||||
return mx.stack(outputs, axis=1)
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
"""Channel-first L2 normalization matching original Wan VAE.
|
||||
|
||||
Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
|
||||
images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
|
||||
images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
|
||||
super().__init__()
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
if channel_first:
|
||||
broadcastable = (1, 1) if images else (1, 1, 1)
|
||||
self.gamma = mx.ones((dim, *broadcastable))
|
||||
else:
|
||||
self.gamma = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
norm_dim = 1 if self.channel_first else -1
|
||||
# L2 normalize along channel dim (matches F.normalize)
|
||||
norm = mx.sqrt(
|
||||
mx.clip(
|
||||
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
|
||||
)
|
||||
)
|
||||
return (x / norm) * self.scale * self.gamma
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
"""Residual block with causal 3D convolutions.
|
||||
|
||||
Uses `residual` list with None gaps to match original PyTorch
|
||||
nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
|
||||
[4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.residual = [
|
||||
RMS_norm(in_dim, images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
|
||||
RMS_norm(out_dim, images=False), # [3]
|
||||
None, # [4] SiLU
|
||||
None, # [5] Dropout
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
|
||||
]
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
h = x if self.shortcut is None else self.shortcut(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
# First conv: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.residual[0](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.residual[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
# Second conv: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.residual[3](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.residual[6](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = nn.silu(self.residual[0](x))
|
||||
x = self.residual[2](x)
|
||||
x = nn.silu(self.residual[3](x))
|
||||
x = self.residual[6](x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Single-head spatial self-attention."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.norm = RMS_norm(dim, images=True)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
identity = x
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
# [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
|
||||
x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(0, 2, 3, 1) # [BT, H, W, C]
|
||||
|
||||
qkv = self.to_qkv(x) # [BT, H, W, 3C]
|
||||
qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q[:, None, :, :] # [BT, 1, HW, C]
|
||||
k = k[:, None, :, :]
|
||||
v = v[:, None, :, :]
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
|
||||
out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
|
||||
out = self.proj(out) # [BT, H, W, C]
|
||||
out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W]
|
||||
return out + identity
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Resample block matching original Wan VAE structure.
|
||||
|
||||
Supports both upsampling (decoder) and downsampling (encoder).
|
||||
Uses list-based param storage to match original nn.Sequential key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str):
|
||||
super().__init__()
|
||||
assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d")
|
||||
self.mode = mode
|
||||
self.dim = dim
|
||||
|
||||
if mode.startswith("upsample"):
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||
if mode == "upsample3d":
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
|
||||
)
|
||||
else:
|
||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
|
||||
if mode == "downsample3d":
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
if self.mode == "upsample3d":
|
||||
# Temporal upsample via learned conv
|
||||
x_t = self.time_conv(x) # [B, 2C, T, H, W]
|
||||
x_t = x_t.reshape(b, 2, c, t, h, w)
|
||||
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
|
||||
t = t * 2
|
||||
|
||||
if self.mode.startswith("upsample"):
|
||||
# Per-frame spatial upsample: nearest 2x + Conv2d
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.repeat(x, 2, axis=1)
|
||||
x = mx.repeat(x, 2, axis=2)
|
||||
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
|
||||
c_out = x.shape[-1]
|
||||
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||
else:
|
||||
# Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2)
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1)
|
||||
x = self.resample[1](x) # Conv2d stride=2
|
||||
c_out = x.shape[-1]
|
||||
h_out, w_out = x.shape[1], x.shape[2]
|
||||
x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
# First chunk: save x, skip time_conv
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
# Subsequent chunks: use cached frame as temporal context
|
||||
cache_x = x[:, :, -1:]
|
||||
x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.time_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
"""3D VAE Decoder matching Wan2.1 architecture.
|
||||
|
||||
Uses flat `middle` and `upsamples` lists to match original
|
||||
PyTorch nn.Sequential weight key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: list = None,
|
||||
num_res_blocks: int = 2,
|
||||
temporal_upsample: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dim_mult is None:
|
||||
dim_mult = [1, 2, 4, 4]
|
||||
if temporal_upsample is None:
|
||||
temporal_upsample = [True, True, False]
|
||||
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# Middle: [ResBlock, AttentionBlock, ResBlock]
|
||||
self.middle = [
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
]
|
||||
|
||||
# Flat upsample list matching original nn.Sequential indexing
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
if i in (1, 2, 3):
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim))
|
||||
in_dim = out_dim
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
self.upsamples = upsamples
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.middle:
|
||||
x = layer(x)
|
||||
|
||||
for layer in self.upsamples:
|
||||
x = layer(x)
|
||||
|
||||
x = nn.silu(self.head[0](x))
|
||||
x = self.head[2](x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
"""3D VAE Encoder matching Wan2.1 architecture.
|
||||
|
||||
Mirror of Decoder3d with downsampling instead of upsampling.
|
||||
Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: list = None,
|
||||
num_res_blocks: int = 2,
|
||||
temporal_downsample: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dim_mult is None:
|
||||
dim_mult = [1, 2, 4, 4]
|
||||
if temporal_downsample is None:
|
||||
temporal_downsample = [False, True, True]
|
||||
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# Flat downsample list matching original nn.Sequential indexing
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim))
|
||||
in_dim = out_dim
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
self.downsamples = downsamples
|
||||
|
||||
# Middle: [ResBlock, AttentionBlock, ResBlock]
|
||||
self.middle = [
|
||||
ResidualBlock(dims[-1], dims[-1]),
|
||||
AttentionBlock(dims[-1]),
|
||||
ResidualBlock(dims[-1], dims[-1]),
|
||||
]
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False),
|
||||
None, # SiLU
|
||||
CausalConv3d(dims[-1], z_dim, 3, padding=1),
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]"""
|
||||
if feat_cache is not None:
|
||||
# conv1 with caching
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.conv1(x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
for layer in self.middle:
|
||||
if feat_cache is not None and isinstance(layer, ResidualBlock):
|
||||
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
# Head: norm -> silu -> [cache] -> conv
|
||||
x = nn.silu(self.head[0](x))
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:]
|
||||
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
|
||||
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
|
||||
x = self.head[2](x, cache_x=feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = nn.silu(self.head[0](x))
|
||||
x = self.head[2](x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
"""Wan2.1 VAE wrapper with per-channel normalization.
|
||||
|
||||
Supports both encode (for I2V) and decode (for all models).
|
||||
"""
|
||||
|
||||
def __init__(self, z_dim: int = 16, encoder: bool = False):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.mean = mx.array(VAE_MEAN)
|
||||
self.std = mx.array(VAE_STD)
|
||||
self.inv_std = 1.0 / self.std
|
||||
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
|
||||
|
||||
if encoder:
|
||||
self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
|
||||
def encode(self, x: mx.array) -> mx.array:
|
||||
"""Encode video to normalized latent using chunked encoding.
|
||||
|
||||
Uses chunked encoding with temporal caching to match reference behavior.
|
||||
First frame encoded alone, then 4-frame chunks with cached context.
|
||||
|
||||
Args:
|
||||
x: Video [B, 3, T, H, W] in [-1, 1]
|
||||
|
||||
Returns:
|
||||
Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
|
||||
"""
|
||||
# Count cacheable CausalConv3d slots in encoder
|
||||
num_slots = self._count_encoder_cache_slots()
|
||||
feat_cache = [None] * num_slots
|
||||
|
||||
t = x.shape[2]
|
||||
num_chunks = 1 + (t - 1) // 4
|
||||
|
||||
out = None
|
||||
for i in range(num_chunks):
|
||||
feat_idx = [0]
|
||||
if i == 0:
|
||||
chunk = x[:, :, :1]
|
||||
else:
|
||||
chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i]
|
||||
|
||||
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
|
||||
|
||||
if out is None:
|
||||
out = chunk_out
|
||||
else:
|
||||
out = mx.concatenate([out, chunk_out], axis=2)
|
||||
|
||||
mu, _ = mx.split(self.conv1(out), 2, axis=1)
|
||||
|
||||
# Normalize: (mu - mean) * inv_std
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
return (mu - mean) * inv_std
|
||||
|
||||
def _count_encoder_cache_slots(self) -> int:
|
||||
"""Count CausalConv3d that participate in chunked encoding cache."""
|
||||
count = 1 # encoder.conv1
|
||||
for layer in self.encoder.downsamples:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2 # two convs in residual path
|
||||
elif isinstance(layer, Resample) and layer.mode == "downsample3d":
|
||||
count += 1 # time_conv
|
||||
for layer in self.encoder.middle:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2
|
||||
count += 1 # encoder.head CausalConv3d
|
||||
return count
|
||||
|
||||
def decode(self, z: mx.array) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
Args:
|
||||
z: Normalized latent [B, z_dim, T, H, W]
|
||||
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
z = z / inv_std + mean
|
||||
|
||||
x = self.conv2(z)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
|
||||
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
|
||||
"""Decode latent to video using tiling to reduce memory usage.
|
||||
|
||||
Splits the latent tensor into overlapping spatial/temporal tiles,
|
||||
decodes each tile independently, and blends them with trapezoidal
|
||||
masks. Reuses the LTX-2 tiling infrastructure.
|
||||
|
||||
Args:
|
||||
z: Normalized latent [B, z_dim, T, H, W]
|
||||
tiling_config: Optional TilingConfig. If None, uses default.
|
||||
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
# Check if tiling is actually needed
|
||||
_, _, f, h, w = z.shape
|
||||
needs_tiling = False
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
|
||||
if h > s_tile or w > s_tile:
|
||||
needs_tiling = True
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
|
||||
if f > t_tile:
|
||||
needs_tiling = True
|
||||
|
||||
if not needs_tiling:
|
||||
return self.decode(z)
|
||||
|
||||
# Denormalize once (small tensor), then tile the denormalized latents
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
z_denorm = z / inv_std + mean
|
||||
|
||||
def tile_decode(tile_latents, **kwargs):
|
||||
x = self.conv2(tile_latents)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=tile_decode,
|
||||
latents=z_denorm,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=8, # 3× spatial 2× upsamples = 8×
|
||||
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
|
||||
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
|
||||
)
|
||||
1150
mlx_video/models/wan_2/vae22.py
Normal file
1150
mlx_video/models/wan_2/vae22.py
Normal file
File diff suppressed because it is too large
Load Diff
388
mlx_video/models/wan_2/wan_2.py
Normal file
388
mlx_video/models/wan_2/wan_2.py
Normal file
@@ -0,0 +1,388 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .attention import WanLayerNorm, _linear_dtype
|
||||
from .config import WanModelConfig
|
||||
from .rope import rope_params, rope_precompute_cos_sin
|
||||
from .transformer import WanAttentionBlock
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
||||
"""Compute sinusoidal positional embeddings.
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (must be even).
|
||||
position: Tensor of positions — 1D [L] or 2D [B, L].
|
||||
|
||||
Returns:
|
||||
Embeddings of shape [L, dim] or [B, L, dim].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
pos = position.astype(mx.float32)
|
||||
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
|
||||
sinusoid = pos[..., None] * inv_freq # [..., half]
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
"""Output projection head with learned modulation."""
|
||||
|
||||
def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
proj_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, proj_dim)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
x: [B, L, dim]
|
||||
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
# Compute modulation in float32 (matching reference's autocast(float32))
|
||||
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x)
|
||||
x_mod = x_norm * (1 + e1) + e0
|
||||
return self.head(x_mod)
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
"""Wan2.2 diffusion backbone for text-to-video generation."""
|
||||
|
||||
def __init__(self, config: WanModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
dim = config.dim
|
||||
self.dim = dim
|
||||
self.num_heads = config.num_heads
|
||||
self.out_dim = config.out_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.text_len = config.text_len
|
||||
self.freq_dim = config.freq_dim
|
||||
|
||||
# Patch embedding: Conv3d implemented as a reshaped linear
|
||||
# For kernel (1,2,2) and stride (1,2,2): reshape input then linear
|
||||
patch_dim = config.in_dim * math.prod(config.patch_size)
|
||||
self.patch_embedding_proj = nn.Linear(patch_dim, dim)
|
||||
self._patch_size = config.patch_size
|
||||
|
||||
# Text embedding MLP
|
||||
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
|
||||
self.text_embedding_act = nn.GELU(approx="tanh")
|
||||
self.text_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time embedding MLP
|
||||
self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
|
||||
self.time_embedding_act = nn.SiLU()
|
||||
self.time_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time projection for modulation (6x dim)
|
||||
self.time_projection_act = nn.SiLU()
|
||||
self.time_projection = nn.Linear(dim, dim * 6)
|
||||
|
||||
# Transformer blocks
|
||||
self.blocks = [
|
||||
WanAttentionBlock(
|
||||
dim=dim,
|
||||
ffn_dim=config.ffn_dim,
|
||||
num_heads=config.num_heads,
|
||||
window_size=config.window_size,
|
||||
qk_norm=config.qk_norm,
|
||||
cross_attn_norm=config.cross_attn_norm,
|
||||
eps=config.eps,
|
||||
)
|
||||
for _ in range(config.num_layers)
|
||||
]
|
||||
|
||||
# Output head
|
||||
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
|
||||
|
||||
# Precompute RoPE frequencies — three separate tables concatenated.
|
||||
# Reference computes three rope_params with different dim normalizations
|
||||
# so each axis (temporal/height/width) gets its own full frequency range.
|
||||
d = dim // config.num_heads
|
||||
self.freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding.
|
||||
half = config.freq_dim // 2
|
||||
self._inv_freq = mx.array(
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
|
||||
np.float32
|
||||
)
|
||||
)
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
"""Convert video tensor to patch embeddings.
|
||||
|
||||
Args:
|
||||
x: Video latent [C, F, H, W]
|
||||
|
||||
Returns:
|
||||
(patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
|
||||
"""
|
||||
c, f, h, w = x.shape
|
||||
pt, ph, pw = self._patch_size
|
||||
|
||||
f_out = f // pt
|
||||
h_out = h // ph
|
||||
w_out = w // pw
|
||||
|
||||
# Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
|
||||
# Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
|
||||
x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
|
||||
x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw]
|
||||
x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw]
|
||||
|
||||
# Project and cast to model dtype to prevent float32 cascade from input latents
|
||||
patches = self.patch_embedding_proj(x) # [L, dim]
|
||||
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
|
||||
patches = patches[None, :, :] # [1, L, dim]
|
||||
|
||||
return patches, (f_out, h_out, w_out)
|
||||
|
||||
def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
|
||||
"""Reconstruct video from patch embeddings.
|
||||
|
||||
Args:
|
||||
x: [B, L, out_dim * prod(patch_size)]
|
||||
grid_sizes: List of (F', H', W') per batch element
|
||||
|
||||
Returns:
|
||||
List of tensors [C, F, H, W]
|
||||
"""
|
||||
c = self.out_dim
|
||||
pt, ph, pw = self.patch_size
|
||||
out = []
|
||||
for i, (f, h, w) in enumerate(grid_sizes):
|
||||
seq_len = f * h * w
|
||||
u = x[i, :seq_len] # [L, out_dim * pt * ph * pw]
|
||||
u = u.reshape(f, h, w, pt, ph, pw, c)
|
||||
# Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
|
||||
u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw]
|
||||
u = u.reshape(c, f * pt, h * ph, w * pw)
|
||||
out.append(u)
|
||||
return out
|
||||
|
||||
def embed_text(self, context: list) -> mx.array:
|
||||
"""Precompute text embeddings (call once, reuse across steps).
|
||||
|
||||
Args:
|
||||
context: List of text embeddings [L_text, text_dim]
|
||||
|
||||
Returns:
|
||||
Embedded context [B, text_len, dim] in model dtype
|
||||
"""
|
||||
model_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
context_padded = []
|
||||
for ctx in context:
|
||||
pad_len = self.text_len - ctx.shape[0]
|
||||
if pad_len > 0:
|
||||
ctx = mx.concatenate(
|
||||
[ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
|
||||
axis=0,
|
||||
)
|
||||
context_padded.append(ctx)
|
||||
context_batch = mx.stack(context_padded) # [B, text_len, text_dim]
|
||||
context_batch = self.text_embedding_1(
|
||||
self.text_embedding_act(self.text_embedding_0(context_batch))
|
||||
)
|
||||
return context_batch.astype(model_dtype)
|
||||
|
||||
def prepare_cross_kv(self, context: mx.array) -> list:
|
||||
"""Pre-compute cross-attention K/V for all blocks.
|
||||
|
||||
Call once before the diffusion loop to cache K/V projections,
|
||||
eliminating redundant computation at each denoising step.
|
||||
|
||||
Args:
|
||||
context: Pre-embedded text [B, text_len, dim]
|
||||
|
||||
Returns:
|
||||
List of (k, v) tuples, one per block
|
||||
"""
|
||||
kv_caches = []
|
||||
for block in self.blocks:
|
||||
kv_caches.append(block.cross_attn.prepare_kv(context))
|
||||
return kv_caches
|
||||
|
||||
def prepare_rope(self, grid_sizes: list) -> tuple:
|
||||
"""Pre-compute RoPE cos/sin for constant grid sizes.
|
||||
|
||||
Call once before the diffusion loop when grid sizes don't change
|
||||
across steps. Eliminates per-step broadcast/concat overhead.
|
||||
|
||||
Args:
|
||||
grid_sizes: List of (F, H, W) tuples per batch element
|
||||
|
||||
Returns:
|
||||
(cos_f, sin_f) precomputed frequency tensors
|
||||
"""
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x_list: list,
|
||||
t: mx.array,
|
||||
context: list | mx.array,
|
||||
seq_len: int,
|
||||
cross_kv_caches: list | None = None,
|
||||
y: list | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
) -> list:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x_list: List of video latent tensors [C, F, H, W]
|
||||
t: Timestep tensor [B]
|
||||
context: List of raw text embeddings, OR pre-embedded tensor
|
||||
from embed_text() [B, text_len, dim]
|
||||
seq_len: Maximum sequence length for padding
|
||||
cross_kv_caches: Optional list of (k, v) tuples from
|
||||
prepare_cross_kv(), one per block.
|
||||
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
|
||||
Channel-concatenated with x before patchify.
|
||||
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
|
||||
|
||||
Returns:
|
||||
List of denoised tensors [C, F, H, W]
|
||||
"""
|
||||
# Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
|
||||
# Check BEFORE I2V concat since concat creates new array objects.
|
||||
batch_size = len(x_list)
|
||||
all_same = batch_size > 1 and all(
|
||||
x_list[i] is x_list[0] for i in range(1, batch_size)
|
||||
)
|
||||
if all_same and y is not None:
|
||||
all_same = all(y[i] is y[0] for i in range(1, len(y)))
|
||||
|
||||
# I2V: channel-concatenate conditioning y with noise x
|
||||
if y is not None:
|
||||
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
|
||||
|
||||
if all_same:
|
||||
# Patchify once and broadcast — saves a Linear projection per step
|
||||
p, gs = self._patchify(x_list[0]) # [1, L, dim]
|
||||
grid_sizes = [gs] * batch_size
|
||||
seq_lens_list = [p.shape[1]] * batch_size
|
||||
# Pad and broadcast
|
||||
if p.shape[1] < seq_len:
|
||||
p = mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
)
|
||||
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
|
||||
else:
|
||||
patches = []
|
||||
grid_sizes = []
|
||||
seq_lens_list = []
|
||||
for vid in x_list:
|
||||
p, gs = self._patchify(vid) # [1, L, dim]
|
||||
patches.append(p)
|
||||
grid_sizes.append(gs)
|
||||
seq_lens_list.append(p.shape[1])
|
||||
x = mx.concatenate(
|
||||
[
|
||||
(
|
||||
mx.concatenate(
|
||||
[
|
||||
p,
|
||||
mx.zeros(
|
||||
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
|
||||
),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
)
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding: sinusoidal from precomputed inv_freq.
|
||||
# inv_freq was computed in float64 for precision, stored as float32.
|
||||
# With integer timesteps (matching reference), float32 sin/cos is fine.
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
|
||||
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim)
|
||||
else:
|
||||
# I2V: per-token timesteps [B, L]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, L, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
||||
e0 = e0.reshape(batch_size, -1, 6, self.dim)
|
||||
|
||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||
if isinstance(context, mx.array):
|
||||
# Pre-embedded: expand to batch size if needed
|
||||
context_batch = context
|
||||
if context_batch.shape[0] == 1 and batch_size > 1:
|
||||
context_batch = mx.broadcast_to(
|
||||
context_batch, (batch_size,) + context_batch.shape[1:]
|
||||
)
|
||||
else:
|
||||
context_batch = self.embed_text(context)
|
||||
|
||||
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
||||
attn_mask = None
|
||||
w_dtype = _linear_dtype(self.patch_embedding_proj)
|
||||
if any(sl < seq_len for sl in seq_lens_list):
|
||||
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
||||
for i, sl in enumerate(seq_lens_list):
|
||||
attn_mask[i, :, :, sl:] = -1e9
|
||||
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens_list,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=self.freqs,
|
||||
context=context_batch,
|
||||
context_lens=None,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
# Run transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
|
||||
x = block(x, cross_kv_cache=kv, **kwargs)
|
||||
|
||||
# Output head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
outputs = self.unpatchify(x, grid_sizes)
|
||||
return [u.astype(mx.float32) for u in outputs]
|
||||
@@ -1,14 +1,15 @@
|
||||
import math
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
@@ -17,15 +18,19 @@ def get_model_path(model_repo: str):
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||
if quantization is not None:
|
||||
|
||||
def get_class_predicate(p, m):
|
||||
# Handle custom per layer quantizations
|
||||
if p in quantization:
|
||||
@@ -46,17 +51,15 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||
class_predicate=get_class_predicate,
|
||||
)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
|
||||
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def to_denoised(
|
||||
noisy: mx.array,
|
||||
velocity: mx.array,
|
||||
sigma: mx.array | float
|
||||
noisy: mx.array, velocity: mx.array, sigma: mx.array | float
|
||||
) -> mx.array:
|
||||
"""Convert velocity prediction to denoised output.
|
||||
|
||||
@@ -284,7 +287,9 @@ def prepare_image_for_encoding(
|
||||
if image_np.max() <= 1.0:
|
||||
image_np = (image_np * 255).astype(np.uint8)
|
||||
pil_image = Image.fromarray(image_np)
|
||||
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
||||
pil_image = pil_image.resize(
|
||||
(target_width, target_height), Image.Resampling.LANCZOS
|
||||
)
|
||||
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
||||
|
||||
# Normalize to [-1, 1]
|
||||
|
||||
@@ -22,6 +22,9 @@ dependencies = [
|
||||
"mlx-vlm",
|
||||
"rich>=14.2.0",
|
||||
"librosa>=0.10.0",
|
||||
"imageio>=2.37.2",
|
||||
"imageio-ffmpeg>=0.6.0",
|
||||
"ftfy",
|
||||
]
|
||||
license = {text="MIT"}
|
||||
authors = [
|
||||
@@ -43,7 +46,8 @@ Repository = "https://github.com/Blaizzy/mlx-video"
|
||||
Issues = "https://github.com/Blaizzy/mlx-video/issues"
|
||||
|
||||
[project.scripts]
|
||||
"mlx_video.generate" = "mlx_video.generate:main"
|
||||
"mlx_video.ltx_2.generate" = "mlx_video.models.ltx_2.generate:main"
|
||||
"mlx_video.wan_2.generate" = "mlx_video.models.wan_2.generate:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["mlx_video*"]
|
||||
@@ -55,3 +59,4 @@ version = {attr = "mlx_video.version.__version__"}
|
||||
dev = [
|
||||
"pytest",
|
||||
]
|
||||
|
||||
|
||||
331
scripts/video/compare_videos.py
Normal file
331
scripts/video/compare_videos.py
Normal file
@@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Compare two videos frame-by-frame with quality metrics.
|
||||
|
||||
Useful for validating MLX ports against reference PyTorch implementations.
|
||||
Reports PSNR, SSIM, per-frame differences, temporal coherence, and color
|
||||
fidelity. Optionally saves a side-by-side diff video.
|
||||
|
||||
Usage:
|
||||
# Basic comparison
|
||||
python scripts/video/compare_videos.py reference.mp4 test.mp4
|
||||
|
||||
# Save side-by-side diff visualization
|
||||
python scripts/video/compare_videos.py ref.mp4 test.mp4 --diff-video diff.mp4
|
||||
|
||||
# Compare only first 64 frames
|
||||
python scripts/video/compare_videos.py ref.mp4 test.mp4 --max-frames 64
|
||||
|
||||
# Adjust SSIM window size (default: 7)
|
||||
python scripts/video/compare_videos.py ref.mp4 test.mp4 --ssim-win 11
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_video(path, max_frames=None):
|
||||
"""Load video frames as float32 numpy arrays (0-255 range)."""
|
||||
cap = cv2.VideoCapture(path)
|
||||
if not cap.isOpened():
|
||||
print(f"Error: cannot open {path}")
|
||||
sys.exit(1)
|
||||
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frames = []
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
frames.append(frame.astype(np.float32))
|
||||
if max_frames and len(frames) >= max_frames:
|
||||
break
|
||||
cap.release()
|
||||
return frames, fps
|
||||
|
||||
|
||||
def compute_psnr(a, b):
|
||||
"""Peak Signal-to-Noise Ratio between two frames."""
|
||||
mse = np.mean((a - b) ** 2)
|
||||
if mse == 0:
|
||||
return float("inf")
|
||||
return 10 * np.log10(255.0**2 / mse)
|
||||
|
||||
|
||||
def compute_ssim(a, b, win_size=7):
|
||||
"""Structural Similarity Index (per-channel, averaged).
|
||||
|
||||
Uses the standard SSIM formula with default constants.
|
||||
"""
|
||||
C1 = (0.01 * 255) ** 2
|
||||
C2 = (0.03 * 255) ** 2
|
||||
|
||||
kernel = cv2.getGaussianKernel(win_size, 1.5)
|
||||
window = kernel @ kernel.T
|
||||
|
||||
ssim_channels = []
|
||||
for c in range(a.shape[2]):
|
||||
ac, bc = a[:, :, c], b[:, :, c]
|
||||
mu_a = cv2.filter2D(ac, -1, window)
|
||||
mu_b = cv2.filter2D(bc, -1, window)
|
||||
|
||||
mu_a_sq = mu_a**2
|
||||
mu_b_sq = mu_b**2
|
||||
mu_ab = mu_a * mu_b
|
||||
|
||||
sigma_a_sq = cv2.filter2D(ac**2, -1, window) - mu_a_sq
|
||||
sigma_b_sq = cv2.filter2D(bc**2, -1, window) - mu_b_sq
|
||||
sigma_ab = cv2.filter2D(ac * bc, -1, window) - mu_ab
|
||||
|
||||
num = (2 * mu_ab + C1) * (2 * sigma_ab + C2)
|
||||
den = (mu_a_sq + mu_b_sq + C1) * (sigma_a_sq + sigma_b_sq + C2)
|
||||
ssim_map = num / den
|
||||
ssim_channels.append(np.mean(ssim_map))
|
||||
|
||||
return np.mean(ssim_channels)
|
||||
|
||||
|
||||
def temporal_coherence(frames):
|
||||
"""Mean frame-to-frame difference (lower = smoother)."""
|
||||
if len(frames) < 2:
|
||||
return 0.0
|
||||
diffs = []
|
||||
for i in range(1, len(frames)):
|
||||
diffs.append(np.mean(np.abs(frames[i] - frames[i - 1])))
|
||||
return np.mean(diffs)
|
||||
|
||||
|
||||
def color_histogram_distance(a, b, bins=64):
|
||||
"""Chi-squared distance between color histograms."""
|
||||
dist = 0.0
|
||||
for c in range(3):
|
||||
ha, _ = np.histogram(a[:, :, c], bins=bins, range=(0, 256))
|
||||
hb, _ = np.histogram(b[:, :, c], bins=bins, range=(0, 256))
|
||||
ha = ha.astype(np.float64) / (ha.sum() + 1e-10)
|
||||
hb = hb.astype(np.float64) / (hb.sum() + 1e-10)
|
||||
dist += np.sum((ha - hb) ** 2 / (ha + hb + 1e-10)) / 2
|
||||
return dist / 3
|
||||
|
||||
|
||||
def make_diff_frame(a, b, scale=5.0):
|
||||
"""Create a heatmap visualization of the absolute difference."""
|
||||
diff = np.mean(np.abs(a - b), axis=2)
|
||||
diff_scaled = np.clip(diff * scale, 0, 255).astype(np.uint8)
|
||||
heatmap = cv2.applyColorMap(diff_scaled, cv2.COLORMAP_JET)
|
||||
return heatmap
|
||||
|
||||
|
||||
def analyze(ref_frames, test_frames, ssim_win=7):
|
||||
"""Compute per-frame and aggregate metrics."""
|
||||
n = min(len(ref_frames), len(test_frames))
|
||||
|
||||
psnrs = []
|
||||
ssims = []
|
||||
mean_diffs = []
|
||||
max_diffs = []
|
||||
color_dists = []
|
||||
|
||||
for i in range(n):
|
||||
r, t = ref_frames[i], test_frames[i]
|
||||
psnrs.append(compute_psnr(r, t))
|
||||
ssims.append(compute_ssim(r, t, ssim_win))
|
||||
absdiff = np.abs(r - t)
|
||||
mean_diffs.append(np.mean(absdiff))
|
||||
max_diffs.append(np.max(absdiff))
|
||||
color_dists.append(color_histogram_distance(r, t))
|
||||
|
||||
ref_tc = temporal_coherence(ref_frames[:n])
|
||||
test_tc = temporal_coherence(test_frames[:n])
|
||||
|
||||
return {
|
||||
"num_frames": n,
|
||||
"psnr": np.array(psnrs),
|
||||
"ssim": np.array(ssims),
|
||||
"mean_diff": np.array(mean_diffs),
|
||||
"max_diff": np.array(max_diffs),
|
||||
"color_dist": np.array(color_dists),
|
||||
"ref_temporal_coherence": ref_tc,
|
||||
"test_temporal_coherence": test_tc,
|
||||
}
|
||||
|
||||
|
||||
def print_report(results, ref_path, test_path):
|
||||
"""Print a formatted comparison report."""
|
||||
n = results["num_frames"]
|
||||
psnr = results["psnr"]
|
||||
ssim = results["ssim"]
|
||||
md = results["mean_diff"]
|
||||
mx = results["max_diff"]
|
||||
cd = results["color_dist"]
|
||||
|
||||
print("=" * 72)
|
||||
print("VIDEO COMPARISON REPORT")
|
||||
print("=" * 72)
|
||||
print(f" Reference: {ref_path}")
|
||||
print(f" Test: {test_path}")
|
||||
print(f" Frames compared: {n}")
|
||||
print()
|
||||
|
||||
print("AGGREGATE METRICS")
|
||||
print("-" * 40)
|
||||
print(
|
||||
f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}"
|
||||
)
|
||||
print(
|
||||
f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}"
|
||||
)
|
||||
print(
|
||||
f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}"
|
||||
)
|
||||
print(
|
||||
f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}"
|
||||
)
|
||||
print(
|
||||
f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}"
|
||||
)
|
||||
print()
|
||||
|
||||
print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)")
|
||||
print("-" * 40)
|
||||
print(f" Reference: {results['ref_temporal_coherence']:.2f}")
|
||||
print(f" Test: {results['test_temporal_coherence']:.2f}")
|
||||
ratio = results["test_temporal_coherence"] / (
|
||||
results["ref_temporal_coherence"] + 1e-10
|
||||
)
|
||||
print(
|
||||
f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Identify worst frames
|
||||
print("WORST FRAMES (by PSNR)")
|
||||
print("-" * 40)
|
||||
worst_idx = np.argsort(psnr)[:5]
|
||||
for i in worst_idx:
|
||||
print(
|
||||
f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Quality assessment
|
||||
mean_psnr = np.mean(psnr)
|
||||
mean_ssim = np.mean(ssim)
|
||||
print("QUALITY ASSESSMENT")
|
||||
print("-" * 40)
|
||||
if mean_psnr > 40:
|
||||
grade = "Excellent"
|
||||
elif mean_psnr > 35:
|
||||
grade = "Good"
|
||||
elif mean_psnr > 30:
|
||||
grade = "Fair"
|
||||
elif mean_psnr > 25:
|
||||
grade = "Poor"
|
||||
else:
|
||||
grade = "Very different"
|
||||
print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})")
|
||||
if mean_psnr < 30:
|
||||
print(
|
||||
" ⚠ Videos differ significantly — likely a bug or different generation seed"
|
||||
)
|
||||
print("=" * 72)
|
||||
|
||||
|
||||
def save_diff_video(ref_frames, test_frames, output_path, fps, scale=5.0):
|
||||
"""Save a side-by-side video: reference | test | diff heatmap."""
|
||||
n = min(len(ref_frames), len(test_frames))
|
||||
h, w = ref_frames[0].shape[:2]
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
out = cv2.VideoWriter(output_path, fourcc, fps, (w * 3, h))
|
||||
|
||||
for i in range(n):
|
||||
r = ref_frames[i].astype(np.uint8)
|
||||
t = test_frames[i].astype(np.uint8)
|
||||
d = make_diff_frame(ref_frames[i], test_frames[i], scale)
|
||||
combined = np.hstack([r, t, d])
|
||||
out.write(combined)
|
||||
|
||||
out.release()
|
||||
print(f"Diff video saved to {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compare two videos frame-by-frame with quality metrics"
|
||||
)
|
||||
parser.add_argument("reference", help="Path to reference video")
|
||||
parser.add_argument("test", help="Path to test video")
|
||||
parser.add_argument(
|
||||
"--diff-video", help="Save side-by-side diff visualization to this path"
|
||||
)
|
||||
parser.add_argument("--max-frames", type=int, help="Compare only first N frames")
|
||||
parser.add_argument(
|
||||
"--ssim-win", type=int, default=7, help="SSIM window size (default: 7)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diff-scale",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="Diff heatmap amplification (default: 5.0)",
|
||||
)
|
||||
parser.add_argument("--csv", help="Export per-frame metrics to CSV file")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Loading reference: {args.reference}")
|
||||
ref_frames, ref_fps = load_video(args.reference, args.max_frames)
|
||||
print(
|
||||
f" → {len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}"
|
||||
)
|
||||
|
||||
print(f"Loading test: {args.test}")
|
||||
test_frames, test_fps = load_video(args.test, args.max_frames)
|
||||
print(
|
||||
f" → {len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}"
|
||||
)
|
||||
|
||||
if ref_frames[0].shape != test_frames[0].shape:
|
||||
print(
|
||||
f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}"
|
||||
)
|
||||
print("Resizing test frames to match reference...")
|
||||
h, w = ref_frames[0].shape[:2]
|
||||
test_frames = [
|
||||
cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) for f in test_frames
|
||||
]
|
||||
|
||||
print("Computing metrics...")
|
||||
results = analyze(ref_frames, test_frames, args.ssim_win)
|
||||
print()
|
||||
print_report(results, args.reference, args.test)
|
||||
|
||||
if args.diff_video:
|
||||
save_diff_video(
|
||||
ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale
|
||||
)
|
||||
|
||||
if args.csv:
|
||||
import csv
|
||||
|
||||
with open(args.csv, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(
|
||||
["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"]
|
||||
)
|
||||
for i in range(results["num_frames"]):
|
||||
writer.writerow(
|
||||
[
|
||||
i,
|
||||
f"{results['psnr'][i]:.4f}",
|
||||
f"{results['ssim'][i]:.6f}",
|
||||
f"{results['mean_diff'][i]:.4f}",
|
||||
f"{results['max_diff'][i]:.1f}",
|
||||
f"{results['color_dist'][i]:.6f}",
|
||||
]
|
||||
)
|
||||
print(f"Per-frame metrics saved to {args.csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
384
scripts/video/video_quality.py
Normal file
384
scripts/video/video_quality.py
Normal file
@@ -0,0 +1,384 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Analyze quality of a single generated video.
|
||||
|
||||
Reports sharpness, temporal stability, color distribution, motion smoothness,
|
||||
chunk boundary artifacts, and common generation defects. Useful for quick
|
||||
quality checks during model porting and debugging.
|
||||
|
||||
Usage:
|
||||
# Basic analysis
|
||||
python scripts/video/video_quality.py output.mp4
|
||||
|
||||
# With chunk boundary analysis (e.g., 32 frames/chunk)
|
||||
python scripts/video/video_quality.py output.mp4 --chunk-size 32
|
||||
|
||||
# Detailed per-frame CSV export
|
||||
python scripts/video/video_quality.py output.mp4 --csv metrics.csv
|
||||
|
||||
# Analyze specific frame range
|
||||
python scripts/video/video_quality.py output.mp4 --start 0 --end 64
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_video(path, start=0, end=None):
|
||||
"""Load video frames as float32 numpy arrays (0-255 range)."""
|
||||
cap = cv2.VideoCapture(path)
|
||||
if not cap.isOpened():
|
||||
print(f"Error: cannot open {path}")
|
||||
sys.exit(1)
|
||||
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
if start > 0:
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, start)
|
||||
|
||||
frames = []
|
||||
idx = start
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
frames.append(frame.astype(np.float32))
|
||||
idx += 1
|
||||
if end and idx >= end:
|
||||
break
|
||||
cap.release()
|
||||
return frames, fps, total
|
||||
|
||||
|
||||
def sharpness_laplacian(frame):
|
||||
"""Laplacian variance — higher means sharper."""
|
||||
gray = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_BGR2GRAY)
|
||||
return cv2.Laplacian(gray, cv2.CV_64F).var()
|
||||
|
||||
|
||||
def sharpness_gradient(frame):
|
||||
"""Mean gradient magnitude — higher means more edges/detail."""
|
||||
gray = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float32)
|
||||
gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
|
||||
gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
|
||||
return np.mean(np.sqrt(gx**2 + gy**2))
|
||||
|
||||
|
||||
def color_stats(frame):
|
||||
"""Per-channel mean and std in BGR order."""
|
||||
means = [np.mean(frame[:, :, c]) for c in range(3)]
|
||||
stds = [np.std(frame[:, :, c]) for c in range(3)]
|
||||
return means, stds
|
||||
|
||||
|
||||
def detect_uniform_color(frame, std_threshold=15.0):
|
||||
"""Detect if frame is near-uniform (common failure mode)."""
|
||||
return np.std(frame) < std_threshold
|
||||
|
||||
|
||||
def detect_noise(frame, threshold=200.0):
|
||||
"""High Laplacian variance with low gradient can indicate noise."""
|
||||
lap = sharpness_laplacian(frame)
|
||||
grad = sharpness_gradient(frame)
|
||||
# Noise has high variance but less coherent edges
|
||||
return lap > threshold and grad < 5.0
|
||||
|
||||
|
||||
def frame_difference(a, b):
|
||||
"""Mean absolute pixel difference between frames."""
|
||||
return np.mean(np.abs(a - b))
|
||||
|
||||
|
||||
def optical_flow_magnitude(prev, curr):
|
||||
"""Mean optical flow magnitude (Farneback method)."""
|
||||
prev_gray = cv2.cvtColor(prev.astype(np.uint8), cv2.COLOR_BGR2GRAY)
|
||||
curr_gray = cv2.cvtColor(curr.astype(np.uint8), cv2.COLOR_BGR2GRAY)
|
||||
flow = cv2.calcOpticalFlowFarneback(
|
||||
prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0
|
||||
)
|
||||
mag = np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2)
|
||||
return np.mean(mag), np.max(mag)
|
||||
|
||||
|
||||
def analyze_video(frames, chunk_size=None, compute_flow=False):
|
||||
"""Compute per-frame and aggregate quality metrics."""
|
||||
n = len(frames)
|
||||
|
||||
metrics = {
|
||||
"sharpness_lap": [],
|
||||
"sharpness_grad": [],
|
||||
"brightness": [],
|
||||
"contrast": [],
|
||||
"color_mean_b": [],
|
||||
"color_mean_g": [],
|
||||
"color_mean_r": [],
|
||||
"frame_diff": [],
|
||||
"is_uniform": [],
|
||||
"is_noisy": [],
|
||||
}
|
||||
if compute_flow:
|
||||
metrics["flow_mean"] = []
|
||||
metrics["flow_max"] = []
|
||||
|
||||
for i in range(n):
|
||||
f = frames[i]
|
||||
metrics["sharpness_lap"].append(sharpness_laplacian(f))
|
||||
metrics["sharpness_grad"].append(sharpness_gradient(f))
|
||||
metrics["brightness"].append(np.mean(f))
|
||||
metrics["contrast"].append(np.std(f))
|
||||
means, _ = color_stats(f)
|
||||
metrics["color_mean_b"].append(means[0])
|
||||
metrics["color_mean_g"].append(means[1])
|
||||
metrics["color_mean_r"].append(means[2])
|
||||
metrics["is_uniform"].append(detect_uniform_color(f))
|
||||
metrics["is_noisy"].append(detect_noise(f))
|
||||
|
||||
if i > 0:
|
||||
metrics["frame_diff"].append(frame_difference(frames[i - 1], f))
|
||||
if compute_flow:
|
||||
fm, fmx = optical_flow_magnitude(frames[i - 1], f)
|
||||
metrics["flow_mean"].append(fm)
|
||||
metrics["flow_max"].append(fmx)
|
||||
else:
|
||||
metrics["frame_diff"].append(0.0)
|
||||
if compute_flow:
|
||||
metrics["flow_mean"].append(0.0)
|
||||
metrics["flow_max"].append(0.0)
|
||||
|
||||
# Convert to arrays
|
||||
for k in metrics:
|
||||
metrics[k] = np.array(metrics[k])
|
||||
|
||||
# Chunk boundary analysis
|
||||
if chunk_size and n > chunk_size:
|
||||
boundaries = list(range(chunk_size, n, chunk_size))
|
||||
boundary_metrics = []
|
||||
for b in boundaries:
|
||||
if b < n and b > 0:
|
||||
pre = (
|
||||
metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1]
|
||||
)
|
||||
at = metrics["frame_diff"][b]
|
||||
ratio = at / (pre + 1e-10)
|
||||
brightness_jump = (
|
||||
metrics["brightness"][b] - metrics["brightness"][b - 1]
|
||||
)
|
||||
contrast_jump = (
|
||||
(metrics["contrast"][b] - metrics["contrast"][b - 1])
|
||||
/ (metrics["contrast"][b - 1] + 1e-10)
|
||||
* 100
|
||||
)
|
||||
sharpness_jump = (
|
||||
(metrics["sharpness_lap"][b] - metrics["sharpness_lap"][b - 1])
|
||||
/ (metrics["sharpness_lap"][b - 1] + 1e-10)
|
||||
* 100
|
||||
)
|
||||
boundary_metrics.append(
|
||||
{
|
||||
"frame": b,
|
||||
"diff_ratio": ratio,
|
||||
"brightness_jump": brightness_jump,
|
||||
"contrast_jump_pct": contrast_jump,
|
||||
"sharpness_jump_pct": sharpness_jump,
|
||||
}
|
||||
)
|
||||
metrics["boundaries"] = boundary_metrics
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def print_report(metrics, path, fps, total_frames, frames_analyzed):
|
||||
"""Print a formatted quality report."""
|
||||
sl = metrics["sharpness_lap"]
|
||||
sg = metrics["sharpness_grad"]
|
||||
br = metrics["brightness"]
|
||||
ct = metrics["contrast"]
|
||||
fd = metrics["frame_diff"]
|
||||
|
||||
print("=" * 72)
|
||||
print("VIDEO QUALITY REPORT")
|
||||
print("=" * 72)
|
||||
print(f" File: {path}")
|
||||
print(
|
||||
f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}"
|
||||
)
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
print(f" Duration: {duration:.1f}s")
|
||||
print()
|
||||
|
||||
# Defect detection
|
||||
n_uniform = int(np.sum(metrics["is_uniform"]))
|
||||
n_noisy = int(np.sum(metrics["is_noisy"]))
|
||||
if n_uniform > 0 or n_noisy > 0:
|
||||
print("⚠ DEFECTS DETECTED")
|
||||
print("-" * 40)
|
||||
if n_uniform:
|
||||
frames_list = np.where(metrics["is_uniform"])[0][:10]
|
||||
print(
|
||||
f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}"
|
||||
)
|
||||
if n_noisy:
|
||||
frames_list = np.where(metrics["is_noisy"])[0][:10]
|
||||
print(
|
||||
f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}"
|
||||
)
|
||||
print()
|
||||
|
||||
print("SHARPNESS")
|
||||
print("-" * 40)
|
||||
print(
|
||||
f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}"
|
||||
)
|
||||
print(
|
||||
f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}"
|
||||
)
|
||||
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
|
||||
print(" ⚠ High sharpness variation — possible blur artifacts")
|
||||
print()
|
||||
|
||||
print("BRIGHTNESS & CONTRAST")
|
||||
print("-" * 40)
|
||||
print(
|
||||
f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}"
|
||||
)
|
||||
print(
|
||||
f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}"
|
||||
)
|
||||
if np.std(br) > 3.0:
|
||||
print(" ⚠ Brightness instability — may indicate chunk boundary artifacts")
|
||||
print()
|
||||
|
||||
print("COLOR DISTRIBUTION (BGR)")
|
||||
print("-" * 40)
|
||||
print(
|
||||
f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}"
|
||||
)
|
||||
print(
|
||||
f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}"
|
||||
)
|
||||
print(
|
||||
f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}"
|
||||
)
|
||||
print()
|
||||
|
||||
print("TEMPORAL STABILITY")
|
||||
print("-" * 40)
|
||||
fd_nz = fd[1:] # skip first frame (always 0)
|
||||
if len(fd_nz) > 0:
|
||||
print(
|
||||
f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}"
|
||||
)
|
||||
if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5:
|
||||
print(" ⚠ High diff variance — jitter or discontinuities")
|
||||
if "flow_mean" in metrics:
|
||||
fm = metrics["flow_mean"][1:]
|
||||
print(
|
||||
f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Chunk boundaries
|
||||
if "boundaries" in metrics and metrics["boundaries"]:
|
||||
print("CHUNK BOUNDARIES")
|
||||
print("-" * 40)
|
||||
print(
|
||||
f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}"
|
||||
)
|
||||
for bm in metrics["boundaries"]:
|
||||
print(
|
||||
f" {bm['frame']:6d}"
|
||||
f" {bm['diff_ratio']:10.2f}x"
|
||||
f" {bm['brightness_jump']:+10.1f}"
|
||||
f" {bm['contrast_jump_pct']:+10.1f}%"
|
||||
f" {bm['sharpness_jump_pct']:+11.1f}%"
|
||||
)
|
||||
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
|
||||
if avg_ratio > 2.0:
|
||||
print(
|
||||
f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions"
|
||||
)
|
||||
print()
|
||||
|
||||
# Overall grade
|
||||
print("OVERALL ASSESSMENT")
|
||||
print("-" * 40)
|
||||
issues = []
|
||||
if n_uniform > 0:
|
||||
issues.append("uniform/blank frames")
|
||||
if n_noisy > 0:
|
||||
issues.append("noisy frames")
|
||||
if np.std(br) > 3.0:
|
||||
issues.append("brightness flicker")
|
||||
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
|
||||
issues.append("sharpness variation")
|
||||
if "boundaries" in metrics and metrics["boundaries"]:
|
||||
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
|
||||
if avg_ratio > 2.0:
|
||||
issues.append("chunk boundary artifacts")
|
||||
if issues:
|
||||
print(f" Issues found: {', '.join(issues)}")
|
||||
else:
|
||||
print(" ✓ No significant quality issues detected")
|
||||
print("=" * 72)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Analyze quality of a single generated video"
|
||||
)
|
||||
parser.add_argument("video", help="Path to video file")
|
||||
parser.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
help="Frames per chunk for boundary analysis (e.g., 32)",
|
||||
)
|
||||
parser.add_argument("--start", type=int, default=0, help="Start frame (default: 0)")
|
||||
parser.add_argument("--end", type=int, help="End frame (default: all)")
|
||||
parser.add_argument(
|
||||
"--flow",
|
||||
action="store_true",
|
||||
help="Compute optical flow (slower but more detailed)",
|
||||
)
|
||||
parser.add_argument("--csv", help="Export per-frame metrics to CSV")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Loading: {args.video}")
|
||||
frames, fps, total = load_video(args.video, args.start, args.end)
|
||||
h, w = frames[0].shape[:2]
|
||||
print(f" → {len(frames)} frames, {fps:.1f} fps, {w}x{h}")
|
||||
|
||||
print("Analyzing...")
|
||||
metrics = analyze_video(frames, args.chunk_size, args.flow)
|
||||
print()
|
||||
print_report(metrics, args.video, fps, total, len(frames))
|
||||
|
||||
if args.csv:
|
||||
import csv
|
||||
|
||||
keys = [
|
||||
"sharpness_lap",
|
||||
"sharpness_grad",
|
||||
"brightness",
|
||||
"contrast",
|
||||
"color_mean_b",
|
||||
"color_mean_g",
|
||||
"color_mean_r",
|
||||
"frame_diff",
|
||||
]
|
||||
if args.flow:
|
||||
keys += ["flow_mean", "flow_max"]
|
||||
|
||||
with open(args.csv, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["frame"] + keys)
|
||||
for i in range(len(frames)):
|
||||
row = [i] + [f"{metrics[k][i]:.4f}" for k in keys]
|
||||
writer.writerow(row)
|
||||
print(f"Per-frame metrics saved to {args.csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
4
tests/conftest.py
Normal file
4
tests/conftest.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Tests for LTX-2 dev model generation pipeline."""
|
||||
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from mlx_video.generate_dev import (
|
||||
ltx2_scheduler,
|
||||
create_position_grid,
|
||||
create_audio_position_grid,
|
||||
compute_audio_frames,
|
||||
cfg_delta,
|
||||
DEFAULT_NEGATIVE_PROMPT,
|
||||
AUDIO_SAMPLE_RATE,
|
||||
AUDIO_LATENTS_PER_SECOND,
|
||||
AUDIO_SAMPLE_RATE,
|
||||
DEFAULT_NEGATIVE_PROMPT,
|
||||
cfg_delta,
|
||||
compute_audio_frames,
|
||||
create_audio_position_grid,
|
||||
create_position_grid,
|
||||
ltx2_scheduler,
|
||||
)
|
||||
|
||||
|
||||
@@ -22,12 +22,16 @@ class TestLTX2Scheduler:
|
||||
"""Scheduler should return steps+1 sigma values."""
|
||||
steps = 20
|
||||
sigmas = ltx2_scheduler(steps=steps)
|
||||
assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}"
|
||||
assert sigmas.shape == (
|
||||
steps + 1,
|
||||
), f"Expected ({steps + 1},), got {sigmas.shape}"
|
||||
|
||||
def test_scheduler_starts_at_one(self):
|
||||
"""Sigma schedule should start at 1.0."""
|
||||
sigmas = ltx2_scheduler(steps=20)
|
||||
assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}"
|
||||
assert (
|
||||
abs(sigmas[0].item() - 1.0) < 1e-5
|
||||
), f"Expected 1.0, got {sigmas[0].item()}"
|
||||
|
||||
def test_scheduler_ends_at_zero(self):
|
||||
"""Sigma schedule should end at 0.0."""
|
||||
@@ -39,8 +43,9 @@ class TestLTX2Scheduler:
|
||||
sigmas = ltx2_scheduler(steps=20)
|
||||
sigmas_list = sigmas.tolist()
|
||||
for i in range(len(sigmas_list) - 1):
|
||||
assert sigmas_list[i] >= sigmas_list[i + 1], \
|
||||
f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
||||
assert (
|
||||
sigmas_list[i] >= sigmas_list[i + 1]
|
||||
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
||||
|
||||
def test_scheduler_dtype(self):
|
||||
"""Scheduler should return float32 array."""
|
||||
@@ -84,14 +89,16 @@ class TestCreatePositionGrid:
|
||||
num_patches = num_frames * height * width
|
||||
|
||||
expected_shape = (batch_size, 3, num_patches, 2)
|
||||
assert positions.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {positions.shape}"
|
||||
assert (
|
||||
positions.shape == expected_shape
|
||||
), f"Expected {expected_shape}, got {positions.shape}"
|
||||
|
||||
def test_position_grid_dtype(self):
|
||||
"""Position grid should be float32 for RoPE precision."""
|
||||
positions = create_position_grid(1, 5, 16, 24)
|
||||
assert positions.dtype == mx.float32, \
|
||||
f"Expected float32 for RoPE precision, got {positions.dtype}"
|
||||
assert (
|
||||
positions.dtype == mx.float32
|
||||
), f"Expected float32 for RoPE precision, got {positions.dtype}"
|
||||
|
||||
def test_position_grid_batch_size(self):
|
||||
"""Position grid should respect batch size."""
|
||||
@@ -165,7 +172,9 @@ class TestCFGDelta:
|
||||
mx.eval(delta)
|
||||
|
||||
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
|
||||
assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero"
|
||||
assert (
|
||||
mx.max(mx.abs(delta)).item() < 1e-6
|
||||
), "CFG delta with scale=1.0 should be zero"
|
||||
|
||||
def test_cfg_delta_formula(self):
|
||||
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
|
||||
@@ -204,8 +213,9 @@ class TestDefaultNegativePrompt:
|
||||
|
||||
# Check for common negative quality terms
|
||||
assert "blurry" in prompt_lower, "Should contain 'blurry'"
|
||||
assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \
|
||||
"Should contain quality-related terms"
|
||||
assert (
|
||||
"low quality" in prompt_lower or "low contrast" in prompt_lower
|
||||
), "Should contain quality-related terms"
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
@@ -248,15 +258,16 @@ class TestInputValidation:
|
||||
(30, 33), # 30 -> nearest valid is 33
|
||||
(35, 33), # 35 -> nearest valid is 33
|
||||
(40, 41), # 40 -> nearest valid is 41
|
||||
(1, 1), # 1 is already valid
|
||||
(1, 1), # 1 is already valid
|
||||
(33, 33), # 33 is already valid
|
||||
]
|
||||
|
||||
for input_frames, expected in test_cases:
|
||||
if input_frames % 8 != 1:
|
||||
adjusted = round((input_frames - 1) / 8) * 8 + 1
|
||||
assert adjusted == expected, \
|
||||
f"Expected {expected} for input {input_frames}, got {adjusted}"
|
||||
assert (
|
||||
adjusted == expected
|
||||
), f"Expected {expected} for input {input_frames}, got {adjusted}"
|
||||
|
||||
|
||||
class TestDenoiseWithCFGMocked:
|
||||
@@ -277,14 +288,16 @@ class TestTilingDefault:
|
||||
def test_tiling_default_is_none(self):
|
||||
"""Default tiling should be 'none' for performance."""
|
||||
import inspect
|
||||
|
||||
from mlx_video.generate_dev import generate_video_dev
|
||||
|
||||
sig = inspect.signature(generate_video_dev)
|
||||
|
||||
tiling_param = sig.parameters.get('tiling')
|
||||
tiling_param = sig.parameters.get("tiling")
|
||||
assert tiling_param is not None
|
||||
assert tiling_param.default == "none", \
|
||||
f"Expected default tiling='none', got '{tiling_param.default}'"
|
||||
assert (
|
||||
tiling_param.default == "none"
|
||||
), f"Expected default tiling='none', got '{tiling_param.default}'"
|
||||
|
||||
|
||||
class TestLatentDimensions:
|
||||
@@ -296,8 +309,9 @@ class TestLatentDimensions:
|
||||
|
||||
for height, expected_latent_h in test_cases:
|
||||
latent_h = height // 32
|
||||
assert latent_h == expected_latent_h, \
|
||||
f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
||||
assert (
|
||||
latent_h == expected_latent_h
|
||||
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
||||
|
||||
def test_latent_width_calculation(self):
|
||||
"""Latent width should be width // 32."""
|
||||
@@ -305,8 +319,9 @@ class TestLatentDimensions:
|
||||
|
||||
for width, expected_latent_w in test_cases:
|
||||
latent_w = width // 32
|
||||
assert latent_w == expected_latent_w, \
|
||||
f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
||||
assert (
|
||||
latent_w == expected_latent_w
|
||||
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
||||
|
||||
def test_latent_frames_calculation(self):
|
||||
"""Latent frames should be 1 + (num_frames - 1) // 8."""
|
||||
@@ -314,8 +329,9 @@ class TestLatentDimensions:
|
||||
|
||||
for num_frames, expected_latent_f in test_cases:
|
||||
latent_f = 1 + (num_frames - 1) // 8
|
||||
assert latent_f == expected_latent_f, \
|
||||
f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
||||
assert (
|
||||
latent_f == expected_latent_f
|
||||
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
||||
|
||||
def test_num_tokens_calculation(self):
|
||||
"""Number of tokens should be latent_f * latent_h * latent_w."""
|
||||
@@ -343,14 +359,14 @@ class TestAudioPositionGrid:
|
||||
positions = create_audio_position_grid(batch_size, audio_frames)
|
||||
expected_shape = (batch_size, 1, audio_frames, 2)
|
||||
|
||||
assert positions.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {positions.shape}"
|
||||
assert (
|
||||
positions.shape == expected_shape
|
||||
), f"Expected {expected_shape}, got {positions.shape}"
|
||||
|
||||
def test_audio_position_grid_dtype(self):
|
||||
"""Audio position grid should be float32."""
|
||||
positions = create_audio_position_grid(1, 34)
|
||||
assert positions.dtype == mx.float32, \
|
||||
f"Expected float32, got {positions.dtype}"
|
||||
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
|
||||
|
||||
def test_audio_position_grid_batch_size(self):
|
||||
"""Audio position grid should respect batch size."""
|
||||
@@ -371,8 +387,12 @@ class TestAudioPositionGrid:
|
||||
"""Audio position grid should not contain NaN or Inf."""
|
||||
positions = create_audio_position_grid(1, 34)
|
||||
|
||||
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
|
||||
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
|
||||
assert not mx.any(
|
||||
mx.isnan(positions)
|
||||
).item(), "Audio position grid contains NaN"
|
||||
assert not mx.any(
|
||||
mx.isinf(positions)
|
||||
).item(), "Audio position grid contains Inf"
|
||||
|
||||
|
||||
class TestComputeAudioFrames:
|
||||
@@ -391,8 +411,9 @@ class TestComputeAudioFrames:
|
||||
audio_33 = compute_audio_frames(33, 24.0)
|
||||
audio_65 = compute_audio_frames(65, 24.0)
|
||||
|
||||
assert audio_65 > audio_33, \
|
||||
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||
assert (
|
||||
audio_65 > audio_33
|
||||
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||
|
||||
def test_audio_frames_formula(self):
|
||||
"""Audio frames should match expected formula."""
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx_2.rope import (
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
|
||||
|
||||
def create_video_position_grid(
|
||||
@@ -20,7 +18,7 @@ def create_video_position_grid(
|
||||
h_coords = np.arange(0, height)
|
||||
w_coords = np.arange(0, width)
|
||||
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
|
||||
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||
patch_ends = patch_starts + 1
|
||||
|
||||
@@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
||||
scaled = fractional * 2 - 1 # [-1, 1]
|
||||
|
||||
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
|
||||
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||
freqs = (
|
||||
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||
)
|
||||
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
|
||||
freqs = np.swapaxes(freqs, -1, -2)
|
||||
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
|
||||
freqs = freqs.reshape(
|
||||
freqs.shape[0], freqs.shape[1], -1
|
||||
) # (B, T, num_indices * n_dims)
|
||||
|
||||
cos_ref = np.cos(freqs)
|
||||
sin_ref = np.sin(freqs)
|
||||
@@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
||||
pad_size = expected - cos_ref.shape[-1]
|
||||
if pad_size > 0:
|
||||
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
|
||||
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
|
||||
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
|
||||
cos_ref = np.concatenate(
|
||||
[np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
|
||||
)
|
||||
sin_ref = np.concatenate(
|
||||
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
|
||||
)
|
||||
|
||||
B, T, _ = cos_ref.shape
|
||||
dim_per_head = dim // num_heads
|
||||
@@ -124,10 +130,12 @@ class TestRoPEPositionPrecision:
|
||||
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
||||
|
||||
# Verify cos/sin are in valid range [-1, 1]
|
||||
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
|
||||
"cos_freq values out of [-1, 1] range"
|
||||
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
|
||||
"sin_freq values out of [-1, 1] range"
|
||||
assert (
|
||||
mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
|
||||
), "cos_freq values out of [-1, 1] range"
|
||||
assert (
|
||||
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
|
||||
), "sin_freq values out of [-1, 1] range"
|
||||
|
||||
def test_bfloat16_positions_cause_precision_loss(self):
|
||||
"""bfloat16 positions should produce different (less precise) results than float32.
|
||||
@@ -175,7 +183,9 @@ class TestRoPEPositionPrecision:
|
||||
# The threshold here is intentionally low to catch the issue
|
||||
precision_threshold = 1e-6
|
||||
|
||||
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
has_precision_loss = (
|
||||
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
)
|
||||
|
||||
# Document the precision loss (this is expected behavior)
|
||||
if has_precision_loss:
|
||||
@@ -184,8 +194,9 @@ class TestRoPEPositionPrecision:
|
||||
print(f" Max sin difference: {max_sin_diff:.6e}")
|
||||
|
||||
# This assertion documents the issue - bfloat16 positions cause precision loss
|
||||
assert has_precision_loss, \
|
||||
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
assert (
|
||||
has_precision_loss
|
||||
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
|
||||
def test_double_precision_converts_to_float32_internally(self):
|
||||
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||
@@ -215,20 +226,26 @@ class TestRoPEPositionPrecision:
|
||||
# Recommended: create positions in float32
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
assert positions.dtype == mx.float32, \
|
||||
"Position grids should be created in float32 for RoPE precision"
|
||||
assert (
|
||||
positions.dtype == mx.float32
|
||||
), "Position grids should be created in float32 for RoPE precision"
|
||||
|
||||
# Verify the position values are reasonable
|
||||
# Temporal positions should be small (seconds)
|
||||
temporal_positions = positions[:, 0, :, :]
|
||||
assert mx.max(temporal_positions).item() < 100, \
|
||||
"Temporal positions should be in seconds (small values)"
|
||||
assert (
|
||||
mx.max(temporal_positions).item() < 100
|
||||
), "Temporal positions should be in seconds (small values)"
|
||||
|
||||
# Spatial positions should be larger (pixels)
|
||||
spatial_h = positions[:, 1, :, :]
|
||||
spatial_w = positions[:, 2, :, :]
|
||||
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
||||
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
||||
assert (
|
||||
mx.max(spatial_h).item() > 0
|
||||
), "Spatial height positions should be positive"
|
||||
assert (
|
||||
mx.max(spatial_w).item() > 0
|
||||
), "Spatial width positions should be positive"
|
||||
|
||||
def test_float32_positions_match_numpy_float64_reference(self):
|
||||
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
|
||||
@@ -259,7 +276,9 @@ class TestRoPEPositionPrecision:
|
||||
)
|
||||
|
||||
# NumPy float64 reference
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
cos_ref, sin_ref = _numpy_reference_rope(
|
||||
positions_np, dim, theta, max_pos, num_heads
|
||||
)
|
||||
|
||||
cos_mlx_np = np.array(cos_mlx)
|
||||
sin_mlx_np = np.array(sin_mlx)
|
||||
@@ -270,16 +289,21 @@ class TestRoPEPositionPrecision:
|
||||
# Cosine similarity (flatten for single scalar)
|
||||
cos_flat = cos_mlx_np.flatten()
|
||||
ref_flat = cos_ref.flatten()
|
||||
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
|
||||
cosine_sim = np.dot(cos_flat, ref_flat) / (
|
||||
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
|
||||
)
|
||||
|
||||
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
|
||||
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert cosine_sim > 0.9999, \
|
||||
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||
assert (
|
||||
max_cos_diff < 0.01
|
||||
), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert (
|
||||
max_sin_diff < 0.01
|
||||
), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert (
|
||||
cosine_sim > 0.9999
|
||||
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||
|
||||
def test_high_frequency_amplification_regression(self):
|
||||
"""Regression test for the specific failure mode: high-frequency index amplification.
|
||||
@@ -309,16 +333,20 @@ class TestRoPEPositionPrecision:
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
cos_ref, sin_ref = _numpy_reference_rope(
|
||||
positions_np, dim, theta, max_pos, num_heads
|
||||
)
|
||||
|
||||
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
|
||||
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
|
||||
|
||||
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||
assert (
|
||||
max_cos_diff < 0.01
|
||||
), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||
assert (
|
||||
max_sin_diff < 0.01
|
||||
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||
|
||||
|
||||
class TestRoPEInterleaved:
|
||||
@@ -359,9 +387,13 @@ class TestRoPEInputCasting:
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
@@ -383,9 +415,13 @@ class TestRoPEInputCasting:
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=True,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
@@ -405,9 +441,13 @@ class TestRoPEInputCasting:
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions_f16,
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
assert cos_freq.dtype == mx.float32
|
||||
@@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig:
|
||||
def test_ltx2_forces_double_precision_rope_false(self):
|
||||
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2 should force double_precision_rope=False regardless of input"
|
||||
assert (
|
||||
config.double_precision_rope is False
|
||||
), "LTX-2 should force double_precision_rope=False regardless of input"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_true(self):
|
||||
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
|
||||
assert config.double_precision_rope is True, \
|
||||
"LTX-2.3 should preserve double_precision_rope=True"
|
||||
assert (
|
||||
config.double_precision_rope is True
|
||||
), "LTX-2.3 should preserve double_precision_rope=True"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_false(self):
|
||||
"""LTX-2.3 with double_precision_rope=False should stay False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||
assert (
|
||||
config.double_precision_rope is False
|
||||
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||
|
||||
def test_ltx2_default_double_precision_rope(self):
|
||||
"""LTX-2 default (double_precision_rope not set) should be False."""
|
||||
@@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig:
|
||||
|
||||
def test_config_from_dict_ltx2(self):
|
||||
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": False,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
config = LTXModelConfig.from_dict(
|
||||
{
|
||||
"has_prompt_adaln": False,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
}
|
||||
)
|
||||
assert config.double_precision_rope is False
|
||||
|
||||
def test_config_from_dict_ltx23(self):
|
||||
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": True,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
config = LTXModelConfig.from_dict(
|
||||
{
|
||||
"has_prompt_adaln": True,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
}
|
||||
)
|
||||
assert config.double_precision_rope is True
|
||||
|
||||
|
||||
@@ -496,10 +543,12 @@ class TestRoPESplit:
|
||||
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
||||
dim_per_head = dim // num_heads
|
||||
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
||||
assert cos_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert sin_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
assert (
|
||||
cos_freq.shape == expected_shape
|
||||
), f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert (
|
||||
sin_freq.shape == expected_shape
|
||||
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Tests for VAE streaming and chunked conv features."""
|
||||
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
@@ -50,7 +50,7 @@ class TestChunkedConv:
|
||||
np.array(out_chunked),
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
err_msg="Chunked conv output differs from regular output"
|
||||
err_msg="Chunked conv output differs from regular output",
|
||||
)
|
||||
|
||||
def test_chunked_conv_small_input_passthrough(self):
|
||||
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
|
||||
frames_received = []
|
||||
|
||||
def on_frames_ready(frames: mx.array, start_idx: int):
|
||||
frames_received.append({
|
||||
'shape': frames.shape,
|
||||
'start_idx': start_idx,
|
||||
})
|
||||
frames_received.append(
|
||||
{
|
||||
"shape": frames.shape,
|
||||
"start_idx": start_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# Create a mock decoder that just returns scaled input
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
def mock_decoder(
|
||||
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||
):
|
||||
# Simulate VAE output: upsample 8x temporal, 32x spatial
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
|
||||
|
||||
# All received frames should have correct channel count
|
||||
for received in frames_received:
|
||||
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
|
||||
assert (
|
||||
received["shape"][1] == 3
|
||||
), f"Expected 3 channels, got {received['shape'][1]}"
|
||||
|
||||
def test_on_frames_ready_covers_all_frames(self):
|
||||
"""Verify all frames are emitted via callbacks."""
|
||||
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
|
||||
for i in range(num_frames):
|
||||
all_frame_indices.add(start_idx + i)
|
||||
|
||||
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
|
||||
def mock_decoder(
|
||||
x, causal=False, timestep=None, debug=False, chunked_conv=False
|
||||
):
|
||||
b, c, f, h, w = x.shape
|
||||
out_f = 1 + (f - 1) * 8
|
||||
out_h = h * 32
|
||||
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
|
||||
expected_frames = 1 + (12 - 1) * 8 # 89 frames
|
||||
|
||||
# All frames should have been emitted
|
||||
assert len(all_frame_indices) == expected_frames, \
|
||||
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||
assert all_frame_indices == set(range(expected_frames)), \
|
||||
"Not all frame indices were covered"
|
||||
assert (
|
||||
len(all_frame_indices) == expected_frames
|
||||
), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
|
||||
assert all_frame_indices == set(
|
||||
range(expected_frames)
|
||||
), "Not all frame indices were covered"
|
||||
|
||||
|
||||
class TestAutoChunkedConv:
|
||||
"""Tests for auto-enabling chunked_conv based on tiling mode."""
|
||||
|
||||
@pytest.mark.parametrize("tiling_mode,should_enable", [
|
||||
("conservative", True),
|
||||
("none", True),
|
||||
("auto", True),
|
||||
("default", True),
|
||||
("spatial", True),
|
||||
("aggressive", False),
|
||||
("temporal", False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"tiling_mode,should_enable",
|
||||
[
|
||||
("conservative", True),
|
||||
("none", True),
|
||||
("auto", True),
|
||||
("default", True),
|
||||
("spatial", True),
|
||||
("aggressive", False),
|
||||
("temporal", False),
|
||||
],
|
||||
)
|
||||
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
|
||||
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
|
||||
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
|
||||
|
||||
use_chunked_conv = tiling_mode in expected_modes
|
||||
|
||||
assert use_chunked_conv == should_enable, \
|
||||
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||
assert (
|
||||
use_chunked_conv == should_enable
|
||||
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
|
||||
|
||||
|
||||
class TestTrapezoidalMask:
|
||||
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
|
||||
|
||||
# Right ramp should be decreasing
|
||||
right_ramp = mask_np[-8:]
|
||||
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
|
||||
assert np.all(
|
||||
np.diff(right_ramp) <= 0
|
||||
), "Right ramp not monotonically decreasing"
|
||||
|
||||
def test_temporal_mask_starts_from_zero(self):
|
||||
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""
|
||||
|
||||
399
tests/test_wan_attention.py
Normal file
399
tests/test_wan_attention.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""Tests for Wan attention components and RoPE."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRoPE:
|
||||
"""Tests for 3-way factorized RoPE."""
|
||||
|
||||
def test_rope_params_shape(self):
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs = rope_params(1024, 64)
|
||||
mx.eval(freqs)
|
||||
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
||||
|
||||
def test_rope_params_different_dims(self):
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
for dim in [32, 64, 128]:
|
||||
freqs = rope_params(512, dim)
|
||||
mx.eval(freqs)
|
||||
assert freqs.shape == (512, dim // 2, 2)
|
||||
|
||||
def test_rope_params_cos_sin_range(self):
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs = rope_params(256, 64)
|
||||
mx.eval(freqs)
|
||||
cos_vals = np.array(freqs[:, :, 0])
|
||||
sin_vals = np.array(freqs[:, :, 1])
|
||||
assert np.all(cos_vals >= -1.0) and np.all(cos_vals <= 1.0)
|
||||
assert np.all(sin_vals >= -1.0) and np.all(sin_vals <= 1.0)
|
||||
|
||||
def test_rope_params_position_zero(self):
|
||||
"""At position 0, cos should be 1 and sin should be 0."""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs = rope_params(10, 64)
|
||||
mx.eval(freqs)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
||||
|
||||
def test_rope_apply_output_shape(self):
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
grid_sizes = [(2, 3, 4)] # F*H*W = 24 = L
|
||||
out = rope_apply(x, grid_sizes, freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, N, D)
|
||||
|
||||
def test_rope_apply_preserves_norm(self):
|
||||
"""RoPE rotation should preserve vector norms."""
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 3, 4
|
||||
L = F * H * W
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
|
||||
out = rope_apply(x, [(F, H, W)], freqs)
|
||||
mx.eval(x, out)
|
||||
|
||||
x_np = np.array(x[0])
|
||||
out_np = np.array(out[0])
|
||||
for i in range(L):
|
||||
for h in range(N):
|
||||
norm_in = np.linalg.norm(x_np[i, h])
|
||||
norm_out = np.linalg.norm(out_np[i, h])
|
||||
np.testing.assert_allclose(norm_in, norm_out, rtol=1e-4)
|
||||
|
||||
def test_rope_apply_with_padding(self):
|
||||
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 2, 2
|
||||
seq_len = F * H * W # 8
|
||||
pad = 4
|
||||
L = seq_len + pad
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
|
||||
out = rope_apply(x, [(F, H, W)], freqs)
|
||||
mx.eval(x, out)
|
||||
# Padded tokens should be unchanged
|
||||
np.testing.assert_allclose(
|
||||
np.array(out[0, seq_len:]),
|
||||
np.array(x[0, seq_len:]),
|
||||
atol=1e-6,
|
||||
)
|
||||
|
||||
def test_rope_apply_batch(self):
|
||||
"""Test with batch_size > 1 and different grid sizes."""
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 2, 2, 16
|
||||
grids = [(2, 3, 4), (2, 3, 4)]
|
||||
L = 2 * 3 * 4
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
|
||||
out = rope_apply(x, grids, freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, N, D)
|
||||
|
||||
def test_rope_frequency_split(self):
|
||||
"""Verify the 3-way frequency dimension split matches Wan2.2 convention."""
|
||||
D = 128 # head_dim for 14B model
|
||||
half_d = D // 2
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
assert d_t + d_h + d_w == half_d
|
||||
# Temporal gets more capacity
|
||||
assert d_t >= d_h
|
||||
assert d_t >= d_w
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attention Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanRMSNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_zero_mean_variance(self):
|
||||
"""RMS norm should make RMS ≈ 1 before scaling."""
|
||||
from mlx_video.models.wan_2.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((1, 5, 64)) * 10.0
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0])
|
||||
for i in range(5):
|
||||
rms = np.sqrt(np.mean(out_np[i] ** 2))
|
||||
# After RMS norm with weight=1, RMS should be ~1
|
||||
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
|
||||
|
||||
def test_dtype_preservation(self):
|
||||
"""RMSNorm weight is float32, so output is promoted to float32."""
|
||||
from mlx_video.models.wan_2.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(32)
|
||||
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# Weight is float32, so multiplication promotes result to float32
|
||||
assert out.dtype == mx.float32
|
||||
|
||||
|
||||
class TestWanLayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_without_affine(self):
|
||||
from mlx_video.models.wan_2.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64, elementwise_affine=False)
|
||||
x = mx.random.normal((1, 4, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
# Mean should be ~0, variance should be ~1
|
||||
out_np = np.array(out[0])
|
||||
for i in range(4):
|
||||
np.testing.assert_allclose(np.mean(out_np[i]), 0.0, atol=0.05)
|
||||
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
|
||||
|
||||
def test_with_affine(self):
|
||||
from mlx_video.models.wan_2.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(32, elementwise_affine=True)
|
||||
assert hasattr(norm, "weight")
|
||||
assert hasattr(norm, "bias")
|
||||
x = mx.random.normal((1, 4, 32))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 32)
|
||||
|
||||
|
||||
class TestWanSelfAttention:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
out = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, self.dim)
|
||||
|
||||
def test_with_qk_norm(self):
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
||||
assert attn.norm_q is not None
|
||||
assert attn.norm_k is not None
|
||||
|
||||
def test_without_qk_norm(self):
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
assert attn.norm_q is None
|
||||
assert attn.norm_k is None
|
||||
|
||||
def test_masking(self):
|
||||
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
# Full sequence
|
||||
out_full = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||
# Shorter sequence (mask last 4 tokens)
|
||||
out_masked = attn(x, seq_lens=[L - 4], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||
mx.eval(out_full, out_masked)
|
||||
|
||||
# Outputs should differ when masking is applied
|
||||
assert not np.allclose(np.array(out_full), np.array(out_masked), atol=1e-5)
|
||||
|
||||
|
||||
class TestWanCrossAttention:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 24, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
context = mx.random.normal((B, L_kv, self.dim))
|
||||
out = attn(x, context)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L_q, self.dim)
|
||||
|
||||
def test_with_context_mask(self):
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 12, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
context = mx.random.normal((B, L_kv, self.dim))
|
||||
out = attn(x, context, context_lens=[10])
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L_q, self.dim)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bfloat16 Autocast Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBFloat16Autocast:
|
||||
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
|
||||
for efficient matmul, matching official PyTorch autocast behavior."""
|
||||
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.num_heads = 4
|
||||
|
||||
@staticmethod
|
||||
def _to_bf16(params):
|
||||
"""Recursively cast all arrays in params to bfloat16."""
|
||||
if isinstance(params, dict):
|
||||
return {k: TestBFloat16Autocast._to_bf16(v) for k, v in params.items()}
|
||||
elif isinstance(params, list):
|
||||
return [TestBFloat16Autocast._to_bf16(v) for v in params]
|
||||
elif isinstance(params, mx.array):
|
||||
return params.astype(mx.bfloat16)
|
||||
return params
|
||||
|
||||
def test_self_attn_casts_to_weight_dtype(self):
|
||||
"""Self-attention should cast input to weight dtype for QKV projections."""
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, self.dim)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_cross_attn_casts_to_weight_dtype(self):
|
||||
"""Cross-attention should cast input to weight dtype."""
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
ctx = mx.random.normal((1, 4, self.dim))
|
||||
out = attn(x, ctx)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, self.dim)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
||||
"""prepare_kv should cast context to weight dtype."""
|
||||
from mlx_video.models.wan_2.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
ctx = mx.random.normal((1, 4, self.dim))
|
||||
k, v = attn.prepare_kv(ctx)
|
||||
mx.eval(k, v)
|
||||
assert k.dtype == mx.bfloat16
|
||||
assert v.dtype == mx.bfloat16
|
||||
|
||||
def test_ffn_casts_to_weight_dtype(self):
|
||||
"""FFN should cast input to weight dtype for linear layers."""
|
||||
from mlx_video.models.wan_2.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(self.dim, 128)
|
||||
ffn.update(self._to_bf16(ffn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
out = ffn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, self.dim)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_self_attn_rope_in_float32(self):
|
||||
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
||||
from mlx_video.models.wan_2.attention import WanSelfAttention
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
assert freqs.dtype == mx.float32
|
||||
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
|
||||
mx.eval(out)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_block_float32_residual_with_bf16_weights(self):
|
||||
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
||||
block.update(self._to_bf16(block.parameters()))
|
||||
|
||||
B, L = 1, 8
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
e = mx.random.normal((B, L, 6, self.dim))
|
||||
ctx = mx.random.normal((B, 4, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
|
||||
mx.eval(out)
|
||||
assert out.dtype == mx.float32
|
||||
assert np.isfinite(np.array(out)).all()
|
||||
135
tests/test_wan_config.py
Normal file
135
tests/test_wan_config.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Tests for Wan model configuration."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanModelConfig:
|
||||
"""Tests for WanModelConfig dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.dim == 5120
|
||||
assert config.ffn_dim == 13824
|
||||
assert config.num_heads == 40
|
||||
assert config.num_layers == 40
|
||||
assert config.in_dim == 16
|
||||
assert config.out_dim == 16
|
||||
assert config.patch_size == (1, 2, 2)
|
||||
assert config.vae_stride == (4, 8, 8)
|
||||
assert config.vae_z_dim == 16
|
||||
assert config.boundary == 0.875
|
||||
assert config.sample_shift == 12.0
|
||||
assert config.sample_steps == 40
|
||||
assert config.sample_guide_scale == (3.0, 4.0)
|
||||
assert config.num_train_timesteps == 1000
|
||||
assert config.qk_norm is True
|
||||
assert config.cross_attn_norm is True
|
||||
assert config.text_len == 512
|
||||
|
||||
def test_head_dim_property(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.head_dim == 128 # 5120 // 40
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
d = config.to_dict()
|
||||
assert isinstance(d, dict)
|
||||
assert d["dim"] == 5120
|
||||
assert d["patch_size"] == (1, 2, 2)
|
||||
assert d["boundary"] == 0.875
|
||||
|
||||
def test_t5_config_values(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.t5_vocab_size == 256384
|
||||
assert config.t5_dim == 4096
|
||||
assert config.t5_dim_attn == 4096
|
||||
assert config.t5_dim_ffn == 10240
|
||||
assert config.t5_num_heads == 64
|
||||
assert config.t5_num_layers == 24
|
||||
assert config.t5_num_buckets == 32
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Config:
|
||||
"""Tests for Wan2.1 config presets."""
|
||||
|
||||
def test_wan21_14b_factory(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
assert config.dim == 5120
|
||||
assert config.ffn_dim == 13824
|
||||
assert config.num_heads == 40
|
||||
assert config.num_layers == 40
|
||||
assert config.head_dim == 128
|
||||
assert config.sample_guide_scale == 5.0
|
||||
assert config.sample_shift == 5.0
|
||||
assert config.sample_steps == 50
|
||||
assert config.boundary == 0.0
|
||||
|
||||
def test_wan21_1_3b_factory(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
assert config.dim == 1536
|
||||
assert config.ffn_dim == 8960
|
||||
assert config.num_heads == 12
|
||||
assert config.num_layers == 30
|
||||
assert config.head_dim == 128 # 1536 // 12
|
||||
assert config.sample_guide_scale == 5.0
|
||||
|
||||
def test_wan22_14b_factory(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
assert config.dim == 5120
|
||||
assert config.sample_guide_scale == (3.0, 4.0)
|
||||
assert config.sample_shift == 12.0
|
||||
assert config.sample_steps == 40
|
||||
assert config.boundary == 0.875
|
||||
|
||||
def test_wan21_config_to_dict(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
assert d["model_version"] == "2.1"
|
||||
assert d["dual_model"] is False
|
||||
assert d["sample_guide_scale"] == 5.0
|
||||
|
||||
def test_wan21_1_3b_config_to_dict(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
d = config.to_dict()
|
||||
assert d["dim"] == 1536
|
||||
assert d["num_layers"] == 30
|
||||
|
||||
def test_default_config_is_wan22(self):
|
||||
"""Default WanModelConfig() should be Wan2.2 14B."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
324
tests/test_wan_convert.py
Normal file
324
tests/test_wan_convert.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""Tests for Wan weight conversion utilities."""
|
||||
|
||||
import logging
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transformer Weight Conversion Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeTransformerWeights:
|
||||
def test_patch_embedding_reshape(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "patch_embedding_proj.weight" in out
|
||||
assert "patch_embedding_proj.bias" in out
|
||||
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
|
||||
|
||||
def test_text_embedding_rename(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"text_embedding.0.bias": mx.zeros((64,)),
|
||||
"text_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"text_embedding.2.bias": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "text_embedding_0.weight" in out
|
||||
assert "text_embedding_0.bias" in out
|
||||
assert "text_embedding_1.weight" in out
|
||||
assert "text_embedding_1.bias" in out
|
||||
|
||||
def test_time_embedding_rename(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "time_embedding_0.weight" in out
|
||||
assert "time_embedding_1.weight" in out
|
||||
|
||||
def test_time_projection_rename(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
"time_projection.1.bias": mx.zeros((384,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "time_projection.weight" in out
|
||||
assert "time_projection.bias" in out
|
||||
|
||||
def test_ffn_rename(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.0.bias": mx.zeros((128,)),
|
||||
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
||||
"blocks.0.ffn.2.bias": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "blocks.0.ffn.fc1.weight" in out
|
||||
assert "blocks.0.ffn.fc1.bias" in out
|
||||
assert "blocks.0.ffn.fc2.weight" in out
|
||||
assert "blocks.0.ffn.fc2.bias" in out
|
||||
|
||||
def test_freqs_skipped(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
"blocks.0.norm1.weight": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
assert "freqs" not in out
|
||||
assert "blocks.0.norm1.weight" in out
|
||||
|
||||
def test_passthrough_keys(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.self_attn.v.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.self_attn.o.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"head.modulation": mx.zeros((1, 2, 64)),
|
||||
}
|
||||
out = sanitize_wan_transformer_weights(weights)
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"text_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
|
||||
sanitize_wan_transformer_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeT5Weights:
|
||||
def test_gate_rename(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
}
|
||||
out = sanitize_wan_t5_weights(weights)
|
||||
assert "blocks.0.ffn.gate_proj.weight" in out
|
||||
assert "blocks.0.ffn.fc1.weight" in out
|
||||
assert "blocks.0.ffn.fc2.weight" in out
|
||||
|
||||
def test_passthrough(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
out = sanitize_wan_t5_weights(weights)
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
|
||||
sanitize_wan_t5_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeVAEWeights:
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
|
||||
|
||||
def test_non_conv_passthrough(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["decoder.norm.weight"].shape == (64,)
|
||||
assert out["decoder.bias"].shape == (16,)
|
||||
|
||||
def test_mixed_weights(self):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
||||
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
|
||||
"linear.weight": mx.zeros((8, 4)), # 2D
|
||||
"norm.weight": mx.zeros((8,)), # 1D
|
||||
}
|
||||
out = sanitize_wan_vae_weights(weights)
|
||||
assert out["conv3d.weight"].shape == (8, 3, 3, 3, 4)
|
||||
assert out["conv2d.weight"].shape == (8, 3, 3, 4)
|
||||
assert out["linear.weight"].shape == (8, 4)
|
||||
assert out["norm.weight"].shape == (8,)
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
||||
"decoder.norm.weight": mx.zeros((64,)),
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
|
||||
sanitize_wan_vae_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Conversion Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Convert:
|
||||
"""Tests for Wan2.1 conversion support."""
|
||||
|
||||
def test_auto_detect_wan21(self, tmp_path):
|
||||
"""Auto-detect single-model directory as Wan2.1."""
|
||||
# Create a Wan2.1-style directory (no low_noise_model subdir)
|
||||
(tmp_path / "dummy.safetensors").touch()
|
||||
# The auto-detect logic: no low_noise_model dir → 2.1
|
||||
|
||||
low = tmp_path / "low_noise_model"
|
||||
assert not low.exists()
|
||||
# Simulates auto detection
|
||||
version = "2.2" if low.exists() else "2.1"
|
||||
assert version == "2.1"
|
||||
|
||||
def test_auto_detect_wan22(self, tmp_path):
|
||||
"""Auto-detect dual-model directory as Wan2.2."""
|
||||
(tmp_path / "low_noise_model").mkdir()
|
||||
(tmp_path / "high_noise_model").mkdir()
|
||||
|
||||
low = tmp_path / "low_noise_model"
|
||||
assert low.exists()
|
||||
version = "2.2" if low.exists() else "2.1"
|
||||
assert version == "2.2"
|
||||
|
||||
def test_wan21_config_saved_correctly(self):
|
||||
"""Verify config dict has correct fields for Wan2.1."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
assert d["model_version"] == "2.1"
|
||||
assert d["dual_model"] is False
|
||||
assert d["sample_steps"] == 50
|
||||
assert d["sample_shift"] == 5.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder Weight Sanitization Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeEncoderWeights:
|
||||
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
||||
|
||||
def test_exclude_encoder_by_default(self):
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "conv2.weight" in out
|
||||
assert not any("encoder" in k or k.startswith("conv1") for k in out)
|
||||
|
||||
def test_include_encoder(self):
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
out = sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "encoder.conv1.weight" in out
|
||||
assert "conv1.weight" in out
|
||||
assert "conv2.weight" in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
248
tests/test_wan_generate.py
Normal file
248
tests/test_wan_generate.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Tests for end-to-end generation and I2V mask construction."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: end-to-end tiny model forward pass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end test with tiny model (no real weights needed)."""
|
||||
|
||||
def test_tiny_model_denoise_step(self):
|
||||
"""Simulate one denoising step with tiny model."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=3.0)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
# One step
|
||||
t = sched.timesteps[0]
|
||||
pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
latents_next = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents_next)
|
||||
|
||||
assert latents_next.shape == (C, F, H, W)
|
||||
# Should differ from original noise
|
||||
assert not np.allclose(np.array(latents_next), np.array(latents), atol=1e-5)
|
||||
|
||||
def test_tiny_model_full_loop(self):
|
||||
"""Run a complete (tiny) diffusion loop."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(123)
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 3
|
||||
sched.set_timesteps(num_steps, shift=3.0)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
for i in range(num_steps):
|
||||
t = sched.timesteps[i]
|
||||
pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item(), "NaN in output"
|
||||
assert not mx.any(mx.isinf(latents)).item(), "Inf in output"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# I2V Mask Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestI2VMask:
|
||||
"""Tests for _build_i2v_mask."""
|
||||
|
||||
def test_mask_shapes(self):
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4) # C, T, H, W
|
||||
patch_size = (1, 2, 2)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
assert mask.shape == z_shape
|
||||
# Tokens: T=5, H/2=2, W/2=2 → 5*2*2 = 20
|
||||
assert mask_tokens.shape == (1, 20)
|
||||
|
||||
def test_first_frame_zero(self):
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
|
||||
mx.eval(mask, mask_tokens)
|
||||
# First temporal position should be 0
|
||||
assert float(mask[:, 0, :, :].max()) == 0.0
|
||||
# Rest should be 1
|
||||
assert float(mask[:, 1:, :, :].min()) == 1.0
|
||||
# First-frame tokens (T=0) should be 0 in mask_tokens
|
||||
# With T=5, H'=2, W'=2: first 4 tokens are frame 0
|
||||
assert float(mask_tokens[0, :4].max()) == 0.0
|
||||
assert float(mask_tokens[0, 4:].min()) == 1.0
|
||||
|
||||
|
||||
class TestI2VMaskAlignment:
|
||||
"""Tests that I2V mask works correctly with various aligned dimensions."""
|
||||
|
||||
def test_mask_with_ti2v_dimensions(self):
|
||||
"""Mask should work with TI2V-5B typical dimensions."""
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
||||
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
||||
z_shape = (48, 21, 44, 80)
|
||||
patch_size = (1, 2, 2)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
mx.eval(mask, mask_tokens)
|
||||
|
||||
assert mask.shape == z_shape
|
||||
assert float(mask[:, 0].max()) == 0.0
|
||||
assert float(mask[:, 1:].min()) == 1.0
|
||||
|
||||
expected_tokens = 21 * 22 * 40 # T * (H/ph) * (W/pw)
|
||||
assert mask_tokens.shape == (1, expected_tokens)
|
||||
first_frame_tokens = 1 * 22 * 40 # pt=1
|
||||
assert float(mask_tokens[0, :first_frame_tokens].max()) == 0.0
|
||||
assert float(mask_tokens[0, first_frame_tokens:].min()) == 1.0
|
||||
|
||||
def test_mask_per_token_timestep(self):
|
||||
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (4, 3, 4, 4)
|
||||
patch_size = (1, 2, 2)
|
||||
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
mx.eval(mask_tokens)
|
||||
|
||||
timestep_val = 0.8
|
||||
t_tokens = mask_tokens * timestep_val
|
||||
mx.eval(t_tokens)
|
||||
|
||||
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
|
||||
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
|
||||
np.testing.assert_allclose(
|
||||
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension Alignment Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDimensionAlignment:
|
||||
"""Tests for automatic dimension alignment in generate_wan."""
|
||||
|
||||
def test_already_aligned(self):
|
||||
"""Dimensions already divisible by alignment factor should be unchanged."""
|
||||
# patch_size=(1,2,2), vae_stride=(4,16,16) → align = 32
|
||||
align_h = 2 * 16 # 32
|
||||
align_w = 2 * 16 # 32
|
||||
h, w = 704, 1280
|
||||
assert h % align_h == 0
|
||||
assert w % align_w == 0
|
||||
h_aligned = (h // align_h) * align_h
|
||||
w_aligned = (w // align_w) * align_w
|
||||
assert h_aligned == h
|
||||
assert w_aligned == w
|
||||
|
||||
def test_720p_rounds_down(self):
|
||||
"""720p (1280x720) should round height to 704."""
|
||||
align_h = 32
|
||||
align_w = 32
|
||||
h, w = 720, 1280
|
||||
assert h % align_h != 0 # 720 not divisible by 32
|
||||
h_aligned = (h // align_h) * align_h
|
||||
w_aligned = (w // align_w) * align_w
|
||||
assert h_aligned == 704
|
||||
assert w_aligned == 1280
|
||||
|
||||
def test_1080p_rounds_down(self):
|
||||
"""1080p (1920x1080) should round height to 1056."""
|
||||
align = 32
|
||||
h, w = 1080, 1920
|
||||
assert h % align != 0
|
||||
assert (h // align) * align == 1056
|
||||
assert (w // align) * align == 1920
|
||||
|
||||
def test_odd_sizes(self):
|
||||
"""Odd sizes should be safely rounded down."""
|
||||
align = 32
|
||||
for size in [100, 255, 513, 1023]:
|
||||
aligned = (size // align) * align
|
||||
assert aligned % align == 0
|
||||
assert aligned <= size
|
||||
assert aligned + align > size # closest lower multiple
|
||||
|
||||
def test_patchify_valid_after_alignment(self):
|
||||
"""After alignment, patchify should succeed without reshape errors."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
# Simulate 720p-like scenario with tiny config
|
||||
vae_stride = config.vae_stride # (4, 8, 8)
|
||||
patch_size = config.patch_size # (1, 2, 2)
|
||||
align_h = patch_size[1] * vae_stride[1]
|
||||
align_w = patch_size[2] * vae_stride[2]
|
||||
|
||||
# Pick a height not divisible by alignment
|
||||
raw_h = align_h * 3 + 5 # e.g. 53 for align=16
|
||||
raw_w = align_w * 4
|
||||
h = (raw_h // align_h) * align_h # rounds down
|
||||
w = (raw_w // align_w) * align_w
|
||||
|
||||
C = config.in_dim
|
||||
t_latent = 1
|
||||
h_latent = h // vae_stride[1]
|
||||
w_latent = w // vae_stride[2]
|
||||
|
||||
vid = mx.random.normal((C, t_latent, h_latent, w_latent))
|
||||
patches, grid_size = model._patchify(vid)
|
||||
mx.eval(patches)
|
||||
assert patches.ndim == 3 # [1, L, dim]
|
||||
assert grid_size == (
|
||||
t_latent,
|
||||
h_latent // patch_size[1],
|
||||
w_latent // patch_size[2],
|
||||
)
|
||||
|
||||
def test_alignment_with_ti2v_config(self):
|
||||
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_ti2v_5b()
|
||||
align_h = config.patch_size[1] * config.vae_stride[1]
|
||||
align_w = config.patch_size[2] * config.vae_stride[2]
|
||||
assert align_h == 32
|
||||
assert align_w == 32
|
||||
# 720 not divisible
|
||||
assert 720 % align_h != 0
|
||||
# 704 is
|
||||
assert 704 % align_h == 0
|
||||
587
tests/test_wan_i2v.py
Normal file
587
tests/test_wan_i2v.py
Normal file
@@ -0,0 +1,587 @@
|
||||
"""Tests for Wan2.2 I2V-14B support."""
|
||||
|
||||
import mlx.core as mx
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
def _make_tiny_i2v_config():
|
||||
"""Create a tiny I2V-14B config for testing."""
|
||||
config = _make_tiny_config()
|
||||
config.model_type = "i2v"
|
||||
config.in_dim = 9 # 4 noise + 4 image + 1 mask (scaled down from 16+16+4=36)
|
||||
config.out_dim = 4
|
||||
config.vae_z_dim = 4
|
||||
config.vae_stride = (4, 8, 8)
|
||||
config.dual_model = True
|
||||
config.boundary = 0.900
|
||||
config.sample_shift = 5.0
|
||||
config.sample_guide_scale = (3.5, 3.5)
|
||||
return config
|
||||
|
||||
|
||||
class TestI2VConfig:
|
||||
"""Test I2V-14B config preset."""
|
||||
|
||||
def test_wan22_i2v_14b_preset(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
assert config.model_type == "i2v"
|
||||
assert config.in_dim == 36
|
||||
assert config.out_dim == 16
|
||||
assert config.dim == 5120
|
||||
assert config.num_layers == 40
|
||||
assert config.dual_model is True
|
||||
assert config.boundary == 0.900
|
||||
assert config.sample_shift == 5.0
|
||||
assert config.sample_guide_scale == (3.5, 3.5)
|
||||
assert config.vae_stride == (4, 8, 8)
|
||||
assert config.vae_z_dim == 16
|
||||
|
||||
def test_i2v_vs_t2v_differences(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
i2v = WanModelConfig.wan22_i2v_14b()
|
||||
t2v = WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
assert i2v.model_type == "i2v"
|
||||
assert t2v.model_type == "t2v"
|
||||
assert i2v.in_dim == 36 and t2v.in_dim == 16
|
||||
assert i2v.boundary == 0.900 and t2v.boundary == 0.875
|
||||
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
|
||||
|
||||
def test_i2v_serialization_roundtrip(self):
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
d = config.to_dict()
|
||||
restored = WanModelConfig.from_dict(d)
|
||||
assert restored.model_type == "i2v"
|
||||
assert restored.in_dim == 36
|
||||
assert restored.boundary == 0.900
|
||||
|
||||
|
||||
class TestModelYParameter:
|
||||
"""Test y parameter channel concatenation in WanModel."""
|
||||
|
||||
def test_forward_without_y(self):
|
||||
"""Standard T2V forward pass (no y) still works."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_forward_with_y(self):
|
||||
"""I2V forward pass with y channel concatenation."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C_noise = 4 # noise channels
|
||||
C_y = 5 # mask (1) + image (4)
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C_noise, F, H, W))]
|
||||
y_list = [mx.random.normal((C_y, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out = model(x_list, t, context, seq_len, y=y_list)
|
||||
mx.eval(out[0])
|
||||
# Output should match noise channels (out_dim), not concatenated in_dim
|
||||
assert out[0].shape == (config.out_dim, F, H, W)
|
||||
|
||||
def test_y_none_is_noop(self):
|
||||
"""Passing y=None should be identical to not passing y."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
mx.random.seed(42)
|
||||
x = mx.random.normal((C, F, H, W))
|
||||
t = mx.array([500.0])
|
||||
ctx = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out1 = model([x], t, ctx, seq_len)[0]
|
||||
out2 = model([x], t, ctx, seq_len, y=None)[0]
|
||||
mx.eval(out1, out2)
|
||||
assert mx.allclose(out1, out2, atol=1e-5).item()
|
||||
|
||||
def test_batched_cfg_with_y(self):
|
||||
"""Batched CFG (B=2) with y should work."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C_noise, C_y = 4, 5
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
y = mx.random.normal((C_y, F, H, W))
|
||||
t = mx.array([500.0, 500.0])
|
||||
ctx = [
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
]
|
||||
|
||||
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
|
||||
mx.eval(out[0], out[1])
|
||||
assert len(out) == 2
|
||||
assert out[0].shape == (config.out_dim, F, H, W)
|
||||
assert out[1].shape == (config.out_dim, F, H, W)
|
||||
|
||||
|
||||
class TestVAEEncoder:
|
||||
"""Test Wan2.1 VAE encoder."""
|
||||
|
||||
def test_encoder3d_instantiation(self):
|
||||
from mlx_video.models.wan_2.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(
|
||||
dim=32, z_dim=8
|
||||
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
||||
assert enc.conv1 is not None
|
||||
assert len(enc.downsamples) > 0
|
||||
assert len(enc.middle) == 3
|
||||
|
||||
def test_encoder3d_output_shape(self):
|
||||
"""Encoder should downsample spatially by 8x and temporally by 4x."""
|
||||
from mlx_video.models.wan_2.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=32, z_dim=8)
|
||||
# Random input: [B=1, 3, T=5, H=32, W=32]
|
||||
x = mx.random.normal((1, 3, 5, 32, 32))
|
||||
out = enc(x)
|
||||
mx.eval(out)
|
||||
# With default dim_mult=[1,2,4,4] and temporal_downsample=[True,True,False]:
|
||||
# Spatial: 32 -> 16 -> 8 -> 4 (3 spatial downsamples)
|
||||
# Temporal: 5 -> 3 -> 2 (2 temporal downsamples: downsample3d stride 2)
|
||||
assert out.shape[0] == 1
|
||||
assert out.shape[1] == 8 # z_dim
|
||||
assert out.shape[3] == 32 // 8 # spatial /8
|
||||
assert out.shape[4] == 32 // 8
|
||||
|
||||
def test_wan_vae_encode(self):
|
||||
"""WanVAE with encoder=True should produce normalized latents."""
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
# Input: [B=1, 3, T=5, H=32, W=32]
|
||||
x = mx.random.normal((1, 3, 5, 32, 32))
|
||||
z = vae.encode(x)
|
||||
mx.eval(z)
|
||||
assert z.shape[0] == 1
|
||||
assert z.shape[1] == 16 # z_dim
|
||||
|
||||
def test_wan_vae_encoder_flag(self):
|
||||
"""WanVAE without encoder flag should not have encoder attribute."""
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
||||
assert not hasattr(vae_no_enc, "encoder")
|
||||
|
||||
vae_enc = WanVAE(z_dim=4, encoder=True)
|
||||
assert hasattr(vae_enc, "encoder")
|
||||
|
||||
|
||||
class TestResampleDownsample:
|
||||
"""Test downsample modes in Resample."""
|
||||
|
||||
def test_downsample2d(self):
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 8, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
# Spatial /2, temporal unchanged, channels same
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_downsample3d(self):
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample3d")
|
||||
x = mx.random.normal((1, 16, 4, 8, 8))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
# Spatial /2, temporal /2, channels same
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_upsample2d_still_works(self):
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 2, 8, 8)
|
||||
|
||||
def test_upsample3d_still_works(self):
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample3d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
out = r(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 8, 4, 8, 8)
|
||||
|
||||
|
||||
class TestI2VMaskConstruction:
|
||||
"""Test mask construction for I2V-14B."""
|
||||
|
||||
def test_mask_shape(self):
|
||||
"""I2V-14B mask should have 4 channels with correct temporal structure."""
|
||||
num_frames = 81
|
||||
h_latent, w_latent = 10, 18 # example latent dims
|
||||
t_latent = (num_frames - 1) // 4 + 1 # = 21
|
||||
|
||||
# Build mask following reference logic
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
assert msk.shape == (4, t_latent, h_latent, w_latent)
|
||||
|
||||
def test_mask_values(self):
|
||||
"""First temporal position should be 1, rest 0."""
|
||||
num_frames = 9
|
||||
h_latent, w_latent = 4, 4
|
||||
t_latent = (num_frames - 1) // 4 + 1 # = 3
|
||||
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0]
|
||||
|
||||
mx.eval(msk)
|
||||
# First temporal position: all 4 channels should be 1
|
||||
assert mx.all(msk[:, 0] == 1.0).item()
|
||||
# Rest: all should be 0
|
||||
assert mx.all(msk[:, 1:] == 0.0).item()
|
||||
|
||||
def test_y_tensor_shape(self):
|
||||
"""y = concat([mask_4ch, encoded_video_16ch]) should be 20 channels."""
|
||||
mask = mx.zeros((4, 5, 10, 18))
|
||||
encoded = mx.zeros((16, 5, 10, 18))
|
||||
y = mx.concatenate([mask, encoded], axis=0)
|
||||
assert y.shape == (20, 5, 10, 18)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: I2V end-to-end pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestI2VEndToEndPipeline:
|
||||
"""Full I2V pipeline: image → preprocess → VAE encode → y tensor → denoise → VAE decode."""
|
||||
|
||||
def test_full_i2v_pipeline(self):
|
||||
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
mx.random.seed(0)
|
||||
|
||||
# --- Tiny I2V model config (z_dim=16 to match VAE normalization stats) ---
|
||||
config = _make_tiny_i2v_config()
|
||||
config.vae_z_dim = 16
|
||||
config.out_dim = 16 # must match VAE z_dim for decode
|
||||
config.in_dim = (
|
||||
16 + 4 + 16
|
||||
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||
model = WanModel(config)
|
||||
|
||||
# --- Tiny VAE (with encoder) ---
|
||||
vae = WanVAE(z_dim=config.vae_z_dim, encoder=True)
|
||||
|
||||
# --- Synthetic image: [B=1, 3, T=1, H=32, W=32] in [-1, 1] ---
|
||||
height, width = 32, 32
|
||||
num_frames = 5 # small temporal extent
|
||||
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
||||
video = mx.concatenate(
|
||||
[
|
||||
img,
|
||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
|
||||
# --- VAE encode ---
|
||||
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
||||
mx.eval(z_video)
|
||||
assert z_video.ndim == 5
|
||||
assert z_video.shape[1] == config.vae_z_dim
|
||||
|
||||
z_video = z_video[0] # [z_dim, T_lat, H_lat, W_lat]
|
||||
t_latent = z_video.shape[1]
|
||||
h_latent = z_video.shape[2]
|
||||
w_latent = z_video.shape[3]
|
||||
|
||||
# --- Build I2V mask (4 channels) ---
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
# --- Build y tensor: [mask(4ch) + encoded(z_dim ch)] ---
|
||||
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
||||
mx.eval(y_i2v)
|
||||
assert y_i2v.shape[0] == 4 + config.vae_z_dim
|
||||
|
||||
# --- Denoising loop (2 steps) ---
|
||||
C_noise = config.out_dim # noise channels
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 2
|
||||
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
||||
|
||||
latents = mx.random.normal((C_noise, t_latent, h_latent, w_latent))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
for i in range(num_steps):
|
||||
t_val = sched.timesteps[i].item()
|
||||
pred = model(
|
||||
[latents],
|
||||
mx.array([t_val]),
|
||||
[context],
|
||||
seq_len,
|
||||
y=[y_i2v],
|
||||
)[0]
|
||||
latents = sched.step(pred[None], t_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C_noise, t_latent, h_latent, w_latent)
|
||||
assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents"
|
||||
assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents"
|
||||
|
||||
# --- VAE decode ---
|
||||
decoded = vae.decode(latents[None]) # [1, 3, T_out, H_out, W_out]
|
||||
mx.eval(decoded)
|
||||
assert decoded.ndim == 5
|
||||
assert decoded.shape[0] == 1
|
||||
assert decoded.shape[1] == 3 # RGB output
|
||||
assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video"
|
||||
assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video"
|
||||
# VAE decode clips to [-1, 1]
|
||||
assert float(decoded.max()) <= 1.0
|
||||
assert float(decoded.min()) >= -1.0
|
||||
|
||||
|
||||
class TestDualModelSwitching:
|
||||
"""Test dual-model selection logic: high_noise vs low_noise based on boundary."""
|
||||
|
||||
def test_model_selection_by_timestep(self):
|
||||
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(1)
|
||||
config = _make_tiny_i2v_config()
|
||||
assert config.dual_model is True
|
||||
|
||||
high_noise_model = WanModel(config)
|
||||
low_noise_model = WanModel(config)
|
||||
|
||||
boundary = config.boundary * config.num_train_timesteps # 0.9 * 1000 = 900
|
||||
|
||||
C_noise = config.out_dim # 4
|
||||
C_y = config.in_dim - config.out_dim # 9 - 4 = 5
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 5
|
||||
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
||||
|
||||
guide_scale = config.sample_guide_scale # (3.5, 3.5)
|
||||
assert isinstance(guide_scale, tuple) and len(guide_scale) == 2
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
y_i2v = mx.random.normal((C_y, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
high_used_steps = []
|
||||
low_used_steps = []
|
||||
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
for i in range(num_steps):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
gs = guide_scale[1]
|
||||
high_used_steps.append(i)
|
||||
else:
|
||||
model = low_noise_model
|
||||
gs = guide_scale[0]
|
||||
low_used_steps.append(i)
|
||||
|
||||
# CFG pass: cond + uncond
|
||||
preds = model(
|
||||
[latents, latents],
|
||||
mx.array([timestep_val, timestep_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||
0
|
||||
)
|
||||
mx.eval(latents)
|
||||
|
||||
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
||||
assert len(high_used_steps) > 0, "High-noise model was never selected"
|
||||
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
||||
# High-noise steps should come before low-noise steps (timesteps decrease)
|
||||
if high_used_steps and low_used_steps:
|
||||
assert max(high_used_steps) < min(low_used_steps) or min(
|
||||
high_used_steps
|
||||
) < max(low_used_steps), "Model switching should happen during the loop"
|
||||
|
||||
assert latents.shape == (C_noise, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
|
||||
def test_guide_scale_tuple_applied_per_model(self):
|
||||
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(2)
|
||||
config = _make_tiny_i2v_config()
|
||||
config.sample_guide_scale = (2.0, 5.0) # distinct values
|
||||
|
||||
model = WanModel(config)
|
||||
boundary = config.boundary * config.num_train_timesteps
|
||||
|
||||
C_noise = config.out_dim
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=config.sample_shift)
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
guide_scale = config.sample_guide_scale
|
||||
C_y = config.in_dim - config.out_dim # y channels
|
||||
y_i2v = mx.random.normal((C_y, F, H, W))
|
||||
|
||||
# Track which guide scale was used at each step
|
||||
gs_per_step = []
|
||||
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
for i in range(5):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
if timestep_val >= boundary:
|
||||
gs = guide_scale[1] # high_gs = 5.0
|
||||
else:
|
||||
gs = guide_scale[0] # low_gs = 2.0
|
||||
gs_per_step.append(gs)
|
||||
|
||||
pred = model(
|
||||
[latents, latents],
|
||||
mx.array([timestep_val, timestep_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||
0
|
||||
)
|
||||
mx.eval(latents)
|
||||
|
||||
# Verify both guide scales were used
|
||||
assert 5.0 in gs_per_step, "High guide scale (5.0) was never used"
|
||||
assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used"
|
||||
# High gs should appear first (high timesteps come first)
|
||||
first_high = gs_per_step.index(5.0)
|
||||
last_low = len(gs_per_step) - 1 - gs_per_step[::-1].index(2.0)
|
||||
assert first_high < last_low, "High gs steps should precede low gs steps"
|
||||
|
||||
def test_single_model_fallback_with_tuple_guide_scale(self):
|
||||
"""When dual_model=False, guide_scale tuple should use first element."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(3)
|
||||
config = _make_tiny_config()
|
||||
config.dual_model = False
|
||||
config.sample_guide_scale = (3.0, 5.0)
|
||||
|
||||
model = WanModel(config)
|
||||
guide_scale = config.sample_guide_scale
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(3, shift=3.0)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
# Mimic generate_wan.py single-model logic:
|
||||
# gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
assert gs == 3.0, "Single model should use first element of guide_scale tuple"
|
||||
|
||||
for i in range(3):
|
||||
t_val = sched.timesteps[i].item()
|
||||
pred = model(
|
||||
[latents, latents],
|
||||
mx.array([t_val, t_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
362
tests/test_wan_lora.py
Normal file
362
tests/test_wan_lora.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""Tests for LoRA loading and application."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLoRATypes:
|
||||
"""Test LoRA data structures."""
|
||||
|
||||
def test_lora_weights_scale(self):
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.zeros((16, 64)),
|
||||
lora_B=mx.zeros((128, 16)),
|
||||
rank=16,
|
||||
alpha=32.0,
|
||||
module_name="test",
|
||||
)
|
||||
assert w.scale == 2.0
|
||||
|
||||
def test_lora_weights_scale_default(self):
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.zeros((16, 64)),
|
||||
lora_B=mx.zeros((128, 16)),
|
||||
rank=16,
|
||||
alpha=16.0,
|
||||
module_name="test",
|
||||
)
|
||||
assert w.scale == 1.0
|
||||
|
||||
def test_applied_lora_delta(self):
|
||||
from mlx_video.lora.types import AppliedLoRA, LoRAWeights
|
||||
|
||||
lora_a = mx.ones((2, 4))
|
||||
lora_b = mx.ones((8, 2))
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
applied = AppliedLoRA(weights=w, strength=0.5)
|
||||
delta = applied.compute_delta()
|
||||
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
|
||||
expected = 0.5 * mx.ones((8, 4)) * 2.0
|
||||
assert mx.allclose(delta, expected).item()
|
||||
|
||||
|
||||
class TestLoRALoader:
|
||||
"""Test LoRA weight loading from safetensors."""
|
||||
|
||||
def _make_lora_file(
|
||||
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
|
||||
):
|
||||
"""Helper to create a mock LoRA safetensors file."""
|
||||
weights = {}
|
||||
for name in module_names:
|
||||
if key_format == "AB":
|
||||
weights[f"{name}.lora_A.weight"] = mx.random.normal((rank, in_dim))
|
||||
weights[f"{name}.lora_B.weight"] = mx.random.normal((out_dim, rank))
|
||||
else:
|
||||
weights[f"{name}.lora_down.weight"] = mx.random.normal((rank, in_dim))
|
||||
weights[f"{name}.lora_up.weight"] = mx.random.normal((out_dim, rank))
|
||||
path = Path(tmp_dir) / "test_lora.safetensors"
|
||||
mx.save_safetensors(str(path), weights)
|
||||
return path
|
||||
|
||||
def test_load_lora_a_b_format(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(tmp, ["blocks.0.self_attn.q"], key_format="AB")
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert "blocks.0.self_attn.q" in lora_weights
|
||||
w = lora_weights["blocks.0.self_attn.q"]
|
||||
assert w.rank == 4
|
||||
assert w.alpha == 4.0 # default: alpha == rank
|
||||
assert w.lora_A.shape == (4, 64)
|
||||
assert w.lora_B.shape == (128, 4)
|
||||
|
||||
def test_load_lora_down_up_format(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(
|
||||
tmp, ["blocks.0.self_attn.q"], key_format="down_up"
|
||||
)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert "blocks.0.self_attn.q" in lora_weights
|
||||
|
||||
def test_load_multiple_modules(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
modules = [
|
||||
"blocks.0.self_attn.q",
|
||||
"blocks.0.self_attn.k",
|
||||
"blocks.0.ffn.fc1",
|
||||
]
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(tmp, modules)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert len(lora_weights) == 3
|
||||
for name in modules:
|
||||
assert name in lora_weights
|
||||
|
||||
def test_load_with_alpha(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
weights = {
|
||||
"test.lora_A.weight": mx.random.normal((8, 64)),
|
||||
"test.lora_B.weight": mx.random.normal((128, 8)),
|
||||
"test.alpha": mx.array(16.0),
|
||||
}
|
||||
path = Path(tmp) / "lora.safetensors"
|
||||
mx.save_safetensors(str(path), weights)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert lora_weights["test"].alpha == 16.0
|
||||
assert lora_weights["test"].rank == 8
|
||||
assert lora_weights["test"].scale == 2.0
|
||||
|
||||
def test_file_not_found(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_lora_weights(Path("/nonexistent/lora.safetensors"))
|
||||
|
||||
|
||||
class TestWanKeyNormalization:
|
||||
"""Test Wan2.2 LoRA key normalization."""
|
||||
|
||||
def _wan_model_keys(self):
|
||||
"""Simulate typical Wan2.2 MLX model weight keys."""
|
||||
keys = set()
|
||||
for i in range(2):
|
||||
for layer in [
|
||||
"self_attn.q",
|
||||
"self_attn.k",
|
||||
"self_attn.v",
|
||||
"self_attn.o",
|
||||
"cross_attn.q",
|
||||
"cross_attn.k",
|
||||
"cross_attn.v",
|
||||
"cross_attn.o",
|
||||
]:
|
||||
keys.add(f"blocks.{i}.{layer}.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc1.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc2.weight")
|
||||
keys.add("text_embedding_0.weight")
|
||||
keys.add("text_embedding_1.weight")
|
||||
keys.add("time_embedding_0.weight")
|
||||
keys.add("time_embedding_1.weight")
|
||||
keys.add("time_projection.weight")
|
||||
keys.add("patch_embedding_proj.weight")
|
||||
return keys
|
||||
|
||||
def test_direct_match(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert (
|
||||
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
|
||||
== "blocks.0.self_attn.q"
|
||||
)
|
||||
|
||||
def test_strip_diffusion_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("diffusion_model.blocks.0.self_attn.q", keys)
|
||||
assert result == "blocks.0.self_attn.q"
|
||||
|
||||
def test_strip_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key(
|
||||
"model.diffusion_model.blocks.0.self_attn.k", keys
|
||||
)
|
||||
assert result == "blocks.0.self_attn.k"
|
||||
|
||||
def test_ffn_key_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("blocks.0.ffn.0", keys) == "blocks.0.ffn.fc1"
|
||||
assert _normalize_wan_lora_key("blocks.0.ffn.2", keys) == "blocks.0.ffn.fc2"
|
||||
|
||||
def test_text_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("text_embedding.0", keys) == "text_embedding_0"
|
||||
assert _normalize_wan_lora_key("text_embedding.2", keys) == "text_embedding_1"
|
||||
|
||||
def test_time_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("time_embedding.0", keys) == "time_embedding_0"
|
||||
assert _normalize_wan_lora_key("time_embedding.2", keys) == "time_embedding_1"
|
||||
|
||||
def test_time_projection_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("time_projection.1", keys) == "time_projection"
|
||||
|
||||
def test_patch_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert (
|
||||
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||
)
|
||||
|
||||
def test_combined_prefix_and_ffn(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("diffusion_model.blocks.1.ffn.0", keys)
|
||||
assert result == "blocks.1.ffn.fc1"
|
||||
|
||||
|
||||
class TestApplyLoRA:
|
||||
"""Test LoRA delta application to weights."""
|
||||
|
||||
def test_preserves_bfloat16_dtype(self):
|
||||
"""LoRA delta must not promote bfloat16 weights to float32."""
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4), dtype=mx.bfloat16)
|
||||
# LoRA weights in float32 (typical when loaded from safetensors)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
|
||||
|
||||
def test_preserves_float16_dtype(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4), dtype=mx.float16)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
|
||||
|
||||
def test_apply_single_lora(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4))
|
||||
lora_a = mx.ones((2, 4)) * 0.1
|
||||
lora_b = mx.ones((8, 2)) * 0.1
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
|
||||
expected = original + 0.02 * mx.ones((8, 4))
|
||||
assert mx.allclose(result, expected, atol=1e-6).item()
|
||||
|
||||
def test_apply_multiple_loras(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.zeros((8, 4))
|
||||
w1 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)),
|
||||
lora_B=mx.ones((8, 2)),
|
||||
rank=2,
|
||||
alpha=2.0,
|
||||
module_name="a",
|
||||
)
|
||||
w2 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)) * 2,
|
||||
lora_B=mx.ones((8, 2)) * 2,
|
||||
rank=2,
|
||||
alpha=4.0,
|
||||
module_name="b",
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
|
||||
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
|
||||
# w2 delta: 2.0 * 0.5 * (2*ones(8,2) @ 2*ones(2,4)) = 1.0 * 8*ones(8,4) = 8
|
||||
delta1 = mx.ones((8, 4)) * 2.0
|
||||
delta2 = mx.ones((8, 4)) * 8.0
|
||||
expected = delta1 + delta2
|
||||
assert mx.allclose(result, expected, atol=1e-5).item()
|
||||
|
||||
def test_apply_loras_to_weights_dict(self):
|
||||
from mlx_video.lora.apply import apply_loras_to_weights
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
model_weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.ones((256, 64)),
|
||||
}
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.ones((4, 64)) * 0.01,
|
||||
lora_B=mx.ones((128, 4)) * 0.01,
|
||||
rank=4,
|
||||
alpha=4.0,
|
||||
module_name="blocks.0.self_attn.q",
|
||||
)
|
||||
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
|
||||
result = apply_loras_to_weights(model_weights, module_to_loras)
|
||||
# Only q should be modified
|
||||
assert not mx.array_equal(
|
||||
result["blocks.0.self_attn.q.weight"],
|
||||
model_weights["blocks.0.self_attn.q.weight"],
|
||||
).item()
|
||||
assert mx.array_equal(
|
||||
result["blocks.0.self_attn.k.weight"],
|
||||
model_weights["blocks.0.self_attn.k.weight"],
|
||||
).item()
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end LoRA loading and application."""
|
||||
|
||||
def test_load_and_apply_loras(self):
|
||||
from mlx_video.models.wan_2.convert import load_and_apply_loras
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
# Create mock LoRA safetensors
|
||||
rank = 4
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.lora_A.weight": mx.random.normal((rank, 64)),
|
||||
"blocks.0.self_attn.q.lora_B.weight": mx.random.normal((128, rank)),
|
||||
}
|
||||
lora_path = Path(tmp) / "test.safetensors"
|
||||
mx.save_safetensors(str(lora_path), weights)
|
||||
|
||||
# Create mock model weights
|
||||
model_weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
}
|
||||
|
||||
result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
|
||||
|
||||
# q weight should be modified, k unchanged
|
||||
assert not mx.array_equal(
|
||||
result["blocks.0.self_attn.q.weight"],
|
||||
model_weights["blocks.0.self_attn.q.weight"],
|
||||
).item()
|
||||
assert mx.array_equal(
|
||||
result["blocks.0.self_attn.k.weight"],
|
||||
model_weights["blocks.0.self_attn.k.weight"],
|
||||
).item()
|
||||
357
tests/test_wan_model.py
Normal file
357
tests/test_wan_model.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Tests for Wan model components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sinusoidal Embedding Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSinusoidalEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.arange(10).astype(mx.float32)
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
mx.eval(emb)
|
||||
assert emb.shape == (10, 256)
|
||||
|
||||
def test_position_zero(self):
|
||||
"""Position 0 should have cos=1 for all dims and sin=0."""
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0])
|
||||
emb = sinusoidal_embedding_1d(64, pos)
|
||||
mx.eval(emb)
|
||||
emb_np = np.array(emb[0])
|
||||
# First half is cos, should be 1 at position 0
|
||||
np.testing.assert_allclose(emb_np[:32], 1.0, atol=1e-5)
|
||||
# Second half is sin, should be 0 at position 0
|
||||
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
|
||||
|
||||
def test_different_positions_differ(self):
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0, 100.0, 999.0])
|
||||
emb = sinusoidal_embedding_1d(128, pos)
|
||||
mx.eval(emb)
|
||||
emb_np = np.array(emb)
|
||||
assert not np.allclose(emb_np[0], emb_np[1])
|
||||
assert not np.allclose(emb_np[1], emb_np[2])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Head Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHead:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.wan_2 import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
B, L = 1, 24
|
||||
x = mx.random.normal((B, L, 64))
|
||||
e = mx.random.normal((B, 64)) # time embedding: [B, dim]
|
||||
out = head(x, e)
|
||||
mx.eval(out)
|
||||
expected_proj_dim = 16 * 1 * 2 * 2 # 64
|
||||
assert out.shape == (B, L, expected_proj_dim)
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan_2.wan_2 import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
assert head.modulation.shape == (1, 2, 64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WanModel (Tiny) Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanModel:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
|
||||
assert num_params > 0
|
||||
|
||||
def test_patchify_shape(self):
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
# Input: [C=4, F=1, H=4, W=4]
|
||||
x = mx.random.normal((4, 1, 4, 4))
|
||||
patches, grid_size = model._patchify(x)
|
||||
mx.eval(patches)
|
||||
# Patch size (1,2,2): F'=1, H'=2, W'=2
|
||||
assert grid_size == (1, 2, 2)
|
||||
assert patches.shape == (1, 1 * 2 * 2, config.dim)
|
||||
|
||||
def test_patchify_various_sizes(self):
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
|
||||
x = mx.random.normal((config.in_dim, f, h, w))
|
||||
patches, (gf, gh, gw) = model._patchify(x)
|
||||
mx.eval(patches)
|
||||
pt, ph, pw = config.patch_size
|
||||
assert gf == f // pt
|
||||
assert gh == h // ph
|
||||
assert gw == w // pw
|
||||
assert patches.shape[1] == gf * gh * gw
|
||||
|
||||
def test_unpatchify_inverse(self):
|
||||
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 2, 4, 6
|
||||
pt, ph, pw = config.patch_size
|
||||
F_out, H_out, W_out = F // pt, H // ph, W // pw
|
||||
L = F_out * H_out * W_out
|
||||
proj_dim = config.out_dim * pt * ph * pw
|
||||
# Simulated head output
|
||||
x = mx.random.normal((1, L, proj_dim))
|
||||
out = model.unpatchify(x, [(F_out, H_out, W_out)])
|
||||
mx.eval(out[0])
|
||||
assert out[0].shape == (config.out_dim, F, H, W)
|
||||
|
||||
def test_forward_pass(self):
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((6, config.text_dim))]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
assert len(out) == 1
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_forward_batch(self):
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0, 200.0])
|
||||
context = [
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
mx.random.normal((4, config.text_dim)),
|
||||
]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0], out[1])
|
||||
assert len(out) == 2
|
||||
for o in out:
|
||||
assert o.shape == (C, F, H, W)
|
||||
|
||||
def test_output_is_float32(self):
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
out = model(
|
||||
[mx.random.normal((C, F, H, W))],
|
||||
mx.array([100.0]),
|
||||
[mx.random.normal((4, config.text_dim))],
|
||||
seq_len,
|
||||
)
|
||||
mx.eval(out[0])
|
||||
assert out[0].dtype == mx.float32
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Model:
|
||||
"""Test tiny Wan2.1-style model (single model mode)."""
|
||||
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def _make_tiny_wan21_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
# Override to tiny values
|
||||
config.dim = 64
|
||||
config.ffn_dim = 128
|
||||
config.num_heads = 4
|
||||
config.num_layers = 2
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.freq_dim = 32
|
||||
config.text_dim = 32
|
||||
config.text_len = 8
|
||||
return config
|
||||
|
||||
def _make_tiny_wan21_1_3b_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
||||
config.dim = 48
|
||||
config.ffn_dim = 96
|
||||
config.num_heads = 4
|
||||
config.num_layers = 2
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.freq_dim = 24
|
||||
config.text_dim = 24
|
||||
config.text_len = 8
|
||||
return config
|
||||
|
||||
def test_wan21_tiny_model_forward(self):
|
||||
"""Forward pass with Wan2.1 tiny config."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = self._make_tiny_wan21_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
t = mx.array([500.0])
|
||||
|
||||
out = model([latents], t, [context], seq_len)
|
||||
mx.eval(out)
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_wan21_1_3b_tiny_model_forward(self):
|
||||
"""Forward pass with Wan2.1 1.3B tiny config."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = self._make_tiny_wan21_1_3b_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
t = mx.array([500.0])
|
||||
|
||||
out = model([latents], t, [context], seq_len)
|
||||
mx.eval(out)
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_wan21_single_model_loop(self):
|
||||
"""Full diffusion loop with single model (Wan2.1 style)."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
config = self._make_tiny_wan21_config()
|
||||
model = WanModel(config)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(config.sample_steps, shift=config.sample_shift)
|
||||
|
||||
# Use only 3 steps for speed
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
context_null = mx.zeros((4, config.text_dim))
|
||||
gs = config.sample_guide_scale # Should be float for Wan2.1
|
||||
|
||||
assert isinstance(gs, float), "Wan2.1 guide_scale should be float"
|
||||
|
||||
for i in range(3):
|
||||
t = sched.timesteps[i]
|
||||
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
pred_uncond = model(
|
||||
[latents], mx.array([t.item()]), [context_null], seq_len
|
||||
)[0]
|
||||
pred = pred_uncond + gs * (pred_cond - pred_uncond)
|
||||
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
|
||||
def test_wan21_vs_wan22_config_differences(self):
|
||||
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
c21 = WanModelConfig.wan21_t2v_14b()
|
||||
c22 = WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
# Same architecture
|
||||
assert c21.dim == c22.dim
|
||||
assert c21.num_heads == c22.num_heads
|
||||
assert c21.num_layers == c22.num_layers
|
||||
|
||||
# Different pipeline settings
|
||||
assert c21.dual_model is False
|
||||
assert c22.dual_model is True
|
||||
assert isinstance(c21.sample_guide_scale, float)
|
||||
assert isinstance(c22.sample_guide_scale, tuple)
|
||||
assert c21.sample_shift != c22.sample_shift
|
||||
assert c21.sample_steps != c22.sample_steps
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-Token Timestep Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPerTokenTimestep:
|
||||
"""Tests for per-token sinusoidal embedding."""
|
||||
|
||||
def test_1d_unchanged(self):
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0, 100.0, 500.0])
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
assert emb.shape == (3, 256)
|
||||
|
||||
def test_2d_per_token(self):
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
assert emb.shape == (2, 3, 256)
|
||||
|
||||
def test_consistency(self):
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
pos_1d = mx.array([0.0, 100.0])
|
||||
emb_1d = sinusoidal_embedding_1d(256, pos_1d)
|
||||
pos_2d = mx.array([[0.0, 100.0]])
|
||||
emb_2d = sinusoidal_embedding_1d(256, pos_2d)
|
||||
assert mx.array_equal(emb_1d[0], emb_2d[0, 0])
|
||||
assert mx.array_equal(emb_1d[1], emb_2d[0, 1])
|
||||
338
tests/test_wan_quantization.py
Normal file
338
tests/test_wan_quantization.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""Tests for Wan model quantization pipeline."""
|
||||
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
import numpy as np
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantize Predicate Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizePredicate:
|
||||
def test_matches_self_attention_layers(self):
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
path = f"blocks.0.self_attn.{suffix}"
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_cross_attention_layers(self):
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
path = f"blocks.0.cross_attn.{suffix}"
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_ffn_layers(self):
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
|
||||
|
||||
def test_rejects_embeddings(self):
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for path in [
|
||||
"patch_embedding_proj",
|
||||
"text_embedding_fc1",
|
||||
"time_embedding.fc1",
|
||||
]:
|
||||
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
|
||||
|
||||
def test_rejects_norms(self):
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
|
||||
|
||||
def test_rejects_non_quantizable_modules(self):
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
# Even if path matches, module must have to_quantized
|
||||
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
|
||||
|
||||
def test_all_10_patterns_covered(self):
|
||||
"""Verify exactly 10 layer patterns are targeted."""
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
patterns = [
|
||||
"blocks.0.self_attn.q",
|
||||
"blocks.0.self_attn.k",
|
||||
"blocks.0.self_attn.v",
|
||||
"blocks.0.self_attn.o",
|
||||
"blocks.0.cross_attn.q",
|
||||
"blocks.0.cross_attn.k",
|
||||
"blocks.0.cross_attn.v",
|
||||
"blocks.0.cross_attn.o",
|
||||
"blocks.0.ffn.fc1",
|
||||
"blocks.0.ffn.fc2",
|
||||
]
|
||||
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
|
||||
assert len(matched) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantize Round-Trip Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizeRoundTrip:
|
||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||
"""Helper: create model, quantize, save to tmp_path."""
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
# Write config.json
|
||||
cfg = {"quantization": {"bits": bits, "group_size": group_size}}
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump(cfg, f)
|
||||
|
||||
return model_path, weights_dict
|
||||
|
||||
def test_4bit_roundtrip(self, tmp_path):
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
config,
|
||||
quantization={"bits": 4, "group_size": 64},
|
||||
)
|
||||
|
||||
# Verify quantized layers have scales
|
||||
has_scales = any("scales" in k for k in saved_weights)
|
||||
assert has_scales, "Quantized model should have .scales tensors"
|
||||
|
||||
# Verify a self-attention layer is QuantizedLinear
|
||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
assert isinstance(loaded.blocks[0].ffn.fc1, nn.QuantizedLinear)
|
||||
|
||||
def test_8bit_roundtrip(self, tmp_path):
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
|
||||
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
config,
|
||||
quantization={"bits": 8, "group_size": 64},
|
||||
)
|
||||
|
||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
assert isinstance(loaded.blocks[0].cross_attn.k, nn.QuantizedLinear)
|
||||
|
||||
def test_non_quantized_layers_remain_linear(self, tmp_path):
|
||||
config = _make_tiny_config()
|
||||
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
config,
|
||||
quantization={"bits": 4, "group_size": 64},
|
||||
)
|
||||
|
||||
# Head should NOT be quantized (it's not in the predicate patterns)
|
||||
assert not isinstance(loaded.head, nn.QuantizedLinear)
|
||||
|
||||
def test_loading_without_quantization_flag(self, tmp_path):
|
||||
"""Loading a non-quantized model should have standard Linear layers."""
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(model_path, config, quantization=None)
|
||||
|
||||
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
|
||||
assert not isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quantized Inference Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizedInference:
|
||||
def _make_quantized_model(self, config, bits=4):
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=64,
|
||||
bits=bits,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
def test_forward_pass_4bit(self):
|
||||
config = _make_tiny_config()
|
||||
model = self._make_quantized_model(config, bits=4)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((4, config.text_dim))]
|
||||
|
||||
out = model(x, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
|
||||
assert len(out) == 1
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_forward_pass_8bit(self):
|
||||
config = _make_tiny_config()
|
||||
model = self._make_quantized_model(config, bits=8)
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
x = [mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0])
|
||||
context = [mx.random.normal((4, config.text_dim))]
|
||||
|
||||
out = model(x, t, context, seq_len)
|
||||
mx.eval(out[0])
|
||||
|
||||
assert len(out) == 1
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_quantized_output_differs_from_unquantized(self):
|
||||
"""Sanity check: quantization should change the weights."""
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
mx.random.seed(42)
|
||||
|
||||
# Get unquantized weights
|
||||
model = WanModel(config)
|
||||
mx.eval(model.parameters())
|
||||
orig_weight = np.array(model.blocks[0].self_attn.q.weight)
|
||||
|
||||
# Quantize
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=64,
|
||||
bits=4,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# QuantizedLinear stores weight differently (uint32 packed)
|
||||
assert isinstance(model.blocks[0].self_attn.q, nn.QuantizedLinear)
|
||||
assert hasattr(model.blocks[0].self_attn.q, "scales")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config Metadata Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuantizationConfig:
|
||||
def test_config_metadata_written(self, tmp_path):
|
||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||
from mlx_video.models.wan_2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
|
||||
# Save unquantized model + config
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump({"dim": config.dim}, f)
|
||||
|
||||
# Run quantization
|
||||
_quantize_saved_model(tmp_path, config, is_dual=False, bits=4, group_size=64)
|
||||
|
||||
# Verify metadata
|
||||
with open(tmp_path / "config.json") as f:
|
||||
cfg = json.load(f)
|
||||
assert "quantization" in cfg
|
||||
assert cfg["quantization"]["bits"] == 4
|
||||
assert cfg["quantization"]["group_size"] == 64
|
||||
|
||||
def test_config_metadata_8bit(self, tmp_path):
|
||||
from mlx_video.models.wan_2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump({}, f)
|
||||
|
||||
_quantize_saved_model(tmp_path, config, is_dual=False, bits=8, group_size=32)
|
||||
|
||||
with open(tmp_path / "config.json") as f:
|
||||
cfg = json.load(f)
|
||||
assert cfg["quantization"]["bits"] == 8
|
||||
assert cfg["quantization"]["group_size"] == 32
|
||||
|
||||
def test_dual_model_quantization(self, tmp_path):
|
||||
"""Verify dual-model quantization writes both model files."""
|
||||
from mlx_video.models.wan_2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
model = WanModel(config)
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
mx.save_safetensors(str(tmp_path / name), weights_dict)
|
||||
|
||||
with open(tmp_path / "config.json", "w") as f:
|
||||
json.dump({}, f)
|
||||
|
||||
_quantize_saved_model(tmp_path, config, is_dual=True, bits=4, group_size=64)
|
||||
|
||||
# Both files should now contain quantized weights (have .scales keys)
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
weights = mx.load(str(tmp_path / name))
|
||||
has_scales = any("scales" in k for k in weights)
|
||||
assert has_scales, f"{name} should have quantized layers"
|
||||
384
tests/test_wan_rope_freqs.py
Normal file
384
tests/test_wan_rope_freqs.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""Tests for Wan RoPE frequency construction (Bug 6 regression tests).
|
||||
|
||||
These tests verify that the RoPE frequency table is built correctly by
|
||||
concatenating three separate rope_params calls with different dimension
|
||||
normalizations, matching the reference implementation.
|
||||
|
||||
Background: The reference Wan model constructs RoPE frequencies as:
|
||||
d = dim // num_heads (128 for all Wan models)
|
||||
freqs = cat([
|
||||
rope_params(1024, d - 4*(d//6)), # temporal (dim=44, 22 freqs)
|
||||
rope_params(1024, 2*(d//6)), # height (dim=42, 21 freqs)
|
||||
rope_params(1024, 2*(d//6)), # width (dim=42, 21 freqs)
|
||||
])
|
||||
|
||||
A previous incorrect fix used a single rope_params(1024, 128) call, which
|
||||
gave height/width axes only medium/high frequencies instead of full-range.
|
||||
This destroyed spatial position encoding and caused grey/artifact output.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestRoPEFrequencyConstruction:
|
||||
"""Verify WanModel builds RoPE frequencies matching the reference."""
|
||||
|
||||
def _get_model_freqs(self, dim=64, num_heads=4):
|
||||
"""Instantiate a tiny WanModel and return its .freqs tensor."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = WanModelConfig()
|
||||
config.dim = dim
|
||||
config.ffn_dim = dim * 2
|
||||
config.num_heads = num_heads
|
||||
config.num_layers = 1
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.freq_dim = 32
|
||||
config.text_dim = 32
|
||||
config.text_len = 8
|
||||
model = WanModel(config)
|
||||
mx.eval(model.freqs)
|
||||
return model.freqs, dim // num_heads
|
||||
|
||||
def test_freqs_shape(self):
|
||||
"""Freqs should be [1024, head_dim//2, 2] regardless of construction."""
|
||||
freqs, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
assert freqs.shape == (1024, head_dim // 2, 2)
|
||||
|
||||
def test_three_call_vs_single_call_differ(self):
|
||||
"""Three separate rope_params calls must differ from single call."""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128 # head_dim for all Wan models
|
||||
# Reference: three separate calls
|
||||
correct = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
# Wrong: single call
|
||||
wrong = rope_params(1024, d)
|
||||
mx.eval(correct, wrong)
|
||||
|
||||
assert correct.shape == wrong.shape
|
||||
diff = np.abs(np.array(correct) - np.array(wrong)).max()
|
||||
assert (
|
||||
diff > 0.1
|
||||
), f"Three-call and single-call should differ significantly, got max diff {diff}"
|
||||
|
||||
def test_each_axis_starts_at_frequency_one(self):
|
||||
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
|
||||
|
||||
This verifies each axis gets its own independent frequency range
|
||||
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
|
||||
"""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
half_d = d // 2 # 64
|
||||
d_t = half_d - 2 * (half_d // 3) # 22
|
||||
d_h = half_d // 3 # 21
|
||||
|
||||
# At position 0, cos=1 and sin=0 for ALL frequency components
|
||||
np.testing.assert_allclose(f[0, :, 0], 1.0, atol=1e-6, err_msg="cos at pos 0")
|
||||
np.testing.assert_allclose(f[0, :, 1], 0.0, atol=1e-6, err_msg="sin at pos 0")
|
||||
|
||||
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
|
||||
# Temporal axis first freq
|
||||
np.testing.assert_allclose(
|
||||
f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
|
||||
)
|
||||
# Height axis first freq (starts at index d_t)
|
||||
np.testing.assert_allclose(
|
||||
f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
|
||||
)
|
||||
# Width axis first freq (starts at index d_t + d_h)
|
||||
np.testing.assert_allclose(
|
||||
f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
|
||||
)
|
||||
|
||||
def test_height_width_frequencies_identical(self):
|
||||
"""Height and width axes should have identical frequency tables.
|
||||
|
||||
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
|
||||
"""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
d_h_dim = 2 * (d // 6) # 42
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, d_h_dim),
|
||||
rope_params(1024, d_h_dim),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
half_d = d // 2
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
|
||||
height_freqs = f[:, d_t : d_t + d_h]
|
||||
width_freqs = f[:, d_t + d_h :]
|
||||
np.testing.assert_array_equal(height_freqs, width_freqs)
|
||||
|
||||
def test_frequency_range_per_axis(self):
|
||||
"""Each axis should span a full frequency range, not a slice of one range.
|
||||
|
||||
With three-call construction, the inverse frequency at index 0 of each
|
||||
axis should be 1.0 (theta^0). A single-call approach would give height
|
||||
starting at ~0.04 and width at ~0.002 instead of 1.0.
|
||||
"""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
f = np.array(freqs)
|
||||
|
||||
half_d = d // 2
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
|
||||
# At position 1, the first frequency component of each axis should
|
||||
# have significant magnitude (cos ≈ 0.54), not near-zero
|
||||
pos1_t = f[1, 0, 0] # temporal first freq
|
||||
pos1_h = f[1, d_t, 0] # height first freq
|
||||
pos1_w = f[1, d_t + d_h, 0] # width first freq
|
||||
|
||||
assert (
|
||||
pos1_t > 0.5
|
||||
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
|
||||
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
|
||||
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
|
||||
|
||||
def test_model_freqs_match_manual_construction(self):
|
||||
"""WanModel.freqs should match manually constructed three-call freqs."""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
d = head_dim # 16
|
||||
freqs_manual = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs_model, freqs_manual)
|
||||
np.testing.assert_array_equal(
|
||||
np.array(freqs_model),
|
||||
np.array(freqs_manual),
|
||||
err_msg="WanModel.freqs should use three-call construction",
|
||||
)
|
||||
|
||||
def test_model_freqs_14b_dimensions(self):
|
||||
"""Verify freq dimensions for 14B-scale head_dim=128."""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(freqs)
|
||||
|
||||
assert freqs.shape == (1024, 64, 2)
|
||||
# Verify the split dimensions used by rope_apply
|
||||
half_d = 64
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
assert (d_t, d_h, d_w) == (22, 21, 21)
|
||||
assert d_t + d_h + d_w == half_d
|
||||
|
||||
|
||||
class TestRoPEFrequencyMatchesReference:
|
||||
"""Cross-validate MLX RoPE against PyTorch reference implementation."""
|
||||
|
||||
@pytest.fixture
|
||||
def has_torch(self):
|
||||
try:
|
||||
pass
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
pytest.skip("PyTorch not installed")
|
||||
|
||||
def test_freqs_match_pytorch_reference(self, has_torch):
|
||||
"""Numerically compare MLX and PyTorch frequency tables."""
|
||||
import torch
|
||||
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
|
||||
# PyTorch reference (from wan/modules/model.py)
|
||||
def pt_rope_params(max_seq_len, dim, theta=10000):
|
||||
freqs = torch.outer(
|
||||
torch.arange(max_seq_len),
|
||||
1.0
|
||||
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
ref = torch.cat(
|
||||
[
|
||||
pt_rope_params(1024, d - 4 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
pt_rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# MLX
|
||||
ours = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
mx.eval(ours)
|
||||
|
||||
our_cos = np.array(ours[:, :, 0])
|
||||
our_sin = np.array(ours[:, :, 1])
|
||||
ref_cos = ref.real.float().numpy()
|
||||
ref_sin = ref.imag.float().numpy()
|
||||
|
||||
np.testing.assert_allclose(
|
||||
our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
|
||||
)
|
||||
|
||||
|
||||
class TestRoPEApplyWithCorrectFreqs:
|
||||
"""Test that rope_apply produces correct rotations with three-call freqs."""
|
||||
|
||||
def test_different_spatial_positions_get_different_rotations(self):
|
||||
"""Adjacent height/width positions must produce different RoPE rotations.
|
||||
|
||||
This is the key property that was broken by the single-call bug:
|
||||
height/width frequencies were too low to distinguish nearby positions.
|
||||
"""
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
B, N = 1, 4
|
||||
F, H, W = 1, 4, 4
|
||||
L = F * H * W
|
||||
# Use a constant input so differences come purely from RoPE
|
||||
x = mx.ones((B, L, N, d))
|
||||
out = rope_apply(x, [(F, H, W)], freqs)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0])
|
||||
|
||||
# Position (0,0,0) vs (0,1,0) — different height
|
||||
pos_00 = out_np[0 * H * W + 0 * W + 0] # (f=0, h=0, w=0)
|
||||
pos_10 = out_np[0 * H * W + 1 * W + 0] # (f=0, h=1, w=0)
|
||||
height_diff = np.abs(pos_00 - pos_10).max()
|
||||
|
||||
# Position (0,0,0) vs (0,0,1) — different width
|
||||
pos_01 = out_np[0 * H * W + 0 * W + 1] # (f=0, h=0, w=1)
|
||||
width_diff = np.abs(pos_00 - pos_01).max()
|
||||
|
||||
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
|
||||
# and width was ~0.002. With correct freqs, both are ~1.3.
|
||||
assert (
|
||||
height_diff > 0.5
|
||||
), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
|
||||
assert (
|
||||
width_diff > 0.5
|
||||
), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
|
||||
# Height and width should have identical frequency tables → same diffs
|
||||
np.testing.assert_allclose(
|
||||
height_diff,
|
||||
width_diff,
|
||||
rtol=1e-5,
|
||||
err_msg="Height and width should use identical frequency tables",
|
||||
)
|
||||
|
||||
def test_precomputed_matches_online(self):
|
||||
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
||||
from mlx_video.models.wan_2.rope import (
|
||||
rope_apply,
|
||||
rope_params,
|
||||
rope_precompute_cos_sin,
|
||||
)
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
B, N = 2, 4
|
||||
F, H, W = 2, 3, 4
|
||||
L = F * H * W
|
||||
grids = [(F, H, W), (F, H, W)]
|
||||
|
||||
x = mx.random.normal((B, L, N, d))
|
||||
|
||||
# Online (no precomputed)
|
||||
out_online = rope_apply(x, grids, freqs)
|
||||
# Precomputed
|
||||
cos_sin = rope_precompute_cos_sin(grids, freqs)
|
||||
out_precomp = rope_apply(x, grids, freqs, precomputed_cos_sin=cos_sin)
|
||||
mx.eval(out_online, out_precomp)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_online),
|
||||
np.array(out_precomp),
|
||||
atol=1e-5,
|
||||
err_msg="Precomputed and online RoPE should match",
|
||||
)
|
||||
996
tests/test_wan_scheduler.py
Normal file
996
tests/test_wan_scheduler.py
Normal file
@@ -0,0 +1,996 @@
|
||||
"""Tests for Wan scheduler components."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Euler Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlowMatchEulerScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.timesteps is None
|
||||
assert sched.sigmas is None
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (40,)
|
||||
assert sched.sigmas.shape == (41,) # 40 steps + terminal
|
||||
|
||||
def test_timesteps_decreasing(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(40, shift=12.0)
|
||||
mx.eval(sched.timesteps)
|
||||
ts = np.array(sched.timesteps)
|
||||
# Timesteps should be monotonically decreasing
|
||||
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
|
||||
|
||||
def test_sigmas_decreasing(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
sigmas = np.array(sched.sigmas)
|
||||
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
|
||||
|
||||
def test_terminal_sigma_is_zero(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.sigmas)
|
||||
np.testing.assert_allclose(np.array(sched.sigmas[-1]), 0.0, atol=1e-6)
|
||||
|
||||
def test_shift_effect(self):
|
||||
"""Larger shift should push sigmas toward higher values."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched1 = FlowMatchEulerScheduler()
|
||||
sched2 = FlowMatchEulerScheduler()
|
||||
sched1.set_timesteps(20, shift=1.0)
|
||||
sched2.set_timesteps(20, shift=12.0)
|
||||
mx.eval(sched1.sigmas, sched2.sigmas)
|
||||
mean1 = np.mean(np.array(sched1.sigmas[:-1]))
|
||||
mean2 = np.mean(np.array(sched2.sigmas[:-1]))
|
||||
assert mean2 > mean1, "Higher shift should push sigmas higher"
|
||||
|
||||
def test_step_euler(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
mx.eval(sched.sigmas)
|
||||
|
||||
sample = mx.ones((1, 4, 2, 2, 2))
|
||||
velocity = mx.ones((1, 4, 2, 2, 2)) * 0.5
|
||||
timestep = sched.timesteps[0]
|
||||
|
||||
sigma = float(np.array(sched.sigmas[0]))
|
||||
sigma_next = float(np.array(sched.sigmas[1]))
|
||||
|
||||
result = sched.step(velocity, timestep, sample)
|
||||
mx.eval(result)
|
||||
|
||||
# Euler: x_next = x + (sigma_next - sigma) * v
|
||||
expected = 1.0 + (sigma_next - sigma) * 0.5
|
||||
np.testing.assert_allclose(
|
||||
np.array(result).flatten()[0],
|
||||
expected,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
assert sched._step_index == 0
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
vel = mx.zeros((1, 1, 1, 1, 1))
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
assert sched._step_index == 1
|
||||
sched.step(vel, sched.timesteps[1], sample)
|
||||
assert sched._step_index == 2
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
vel = mx.zeros((1, 1, 1, 1, 1))
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
assert sched._step_index == 1
|
||||
sched.reset()
|
||||
assert sched._step_index == 0
|
||||
|
||||
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(steps, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (steps,)
|
||||
assert sched.sigmas.shape == (steps + 1,)
|
||||
|
||||
def test_full_denoise_loop(self):
|
||||
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
for i in range(5):
|
||||
vel = mx.zeros_like(sample)
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
# With zero velocity, sample should remain unchanged
|
||||
np.testing.assert_allclose(np.array(sample), 1.0, atol=1e-5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared Sigma Schedule Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeSigmas:
|
||||
"""Tests for the shared _compute_sigmas helper."""
|
||||
|
||||
def test_length(self):
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert len(sigmas) == 21 # num_steps + terminal
|
||||
|
||||
def test_terminal_zero(self):
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
assert sigmas[-1] == 0.0
|
||||
|
||||
def test_starts_near_one(self):
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
|
||||
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
|
||||
|
||||
def test_decreasing(self):
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(20, shift=5.0)
|
||||
assert np.all(np.diff(sigmas) <= 0)
|
||||
|
||||
def test_matches_official_wan22(self):
|
||||
"""Sigma schedule should match the official Wan2.2 FlowUniPCMultistepScheduler.
|
||||
|
||||
The reference creates the scheduler with shift=1 (identity) in the
|
||||
constructor, then passes the actual shift to set_timesteps. This means
|
||||
sigma_max/sigma_min come from the *unshifted* training schedule, and the
|
||||
shift is applied only once (single-shift).
|
||||
"""
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
steps, shift, N = 50, 5.0, 1000
|
||||
sigmas = _compute_sigmas(steps, shift, N)
|
||||
# Official single-shift: unshifted bounds, then shift once
|
||||
alphas = np.linspace(1.0, 1.0 / N, N)[::-1]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # 0.999
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
official = np.linspace(sigma_max, sigma_min, steps + 1)[:-1]
|
||||
official = shift * official / (1.0 + (shift - 1.0) * official)
|
||||
official = np.append(official, 0.0).astype(np.float32)
|
||||
np.testing.assert_allclose(sigmas, official, atol=1e-6)
|
||||
|
||||
def test_shift_one_is_near_linear(self):
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(10, shift=1.0)
|
||||
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
|
||||
# so schedule is nearly linear from ~0.999 to 0
|
||||
expected = np.linspace(1, 0, 11).astype(np.float32)
|
||||
np.testing.assert_allclose(sigmas, expected, atol=2e-3)
|
||||
|
||||
def test_all_schedulers_same_sigmas(self):
|
||||
"""All three schedulers should produce identical sigma schedules."""
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
FlowUniPCScheduler(1000),
|
||||
]
|
||||
for s in scheds:
|
||||
s.set_timesteps(20, shift=5.0)
|
||||
mx.eval(*[s.sigmas for s in scheds])
|
||||
ref = np.array(scheds[0].sigmas)
|
||||
for s in scheds[1:]:
|
||||
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
|
||||
|
||||
def test_all_schedulers_same_timesteps(self):
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = [
|
||||
FlowMatchEulerScheduler(1000),
|
||||
FlowDPMPP2MScheduler(1000),
|
||||
FlowUniPCScheduler(1000),
|
||||
]
|
||||
for s in scheds:
|
||||
s.set_timesteps(30, shift=12.0)
|
||||
mx.eval(*[s.timesteps for s in scheds])
|
||||
ref = np.array(scheds[0].timesteps)
|
||||
for s in scheds[1:]:
|
||||
np.testing.assert_allclose(np.array(s.timesteps), ref, atol=1e-3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DPM++ 2M Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlowDPMPP2MScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (20,)
|
||||
assert sched.sigmas.shape == (21,)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 4, 1, 2, 2))
|
||||
vel = mx.zeros_like(sample)
|
||||
assert sched._step_index == 0
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
assert sched._step_index == 1
|
||||
sched.step(vel, sched.timesteps[1], sample)
|
||||
assert sched._step_index == 2
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
sched.step(mx.zeros_like(sample), 0, sample)
|
||||
sched.reset()
|
||||
assert sched._step_index == 0
|
||||
assert sched._prev_x0 is None
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
"""Full loop with constant velocity should produce finite output."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
for i in range(10):
|
||||
vel = mx.ones_like(sample) * 0.1
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_first_step_is_first_order(self):
|
||||
"""First step should use 1st-order (no prev_x0 available)."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 2, 4, 4))
|
||||
vel = mx.random.normal(sample.shape)
|
||||
# Before first step, no prev_x0
|
||||
assert sched._prev_x0 is None
|
||||
result = sched.step(vel, sched.timesteps[0], sample)
|
||||
mx.eval(result)
|
||||
# After first step, prev_x0 should be set
|
||||
assert sched._prev_x0 is not None
|
||||
|
||||
def test_second_step_uses_correction(self):
|
||||
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
vel = mx.random.normal(sample.shape)
|
||||
# Step 1
|
||||
sample = sched.step(vel, sched.timesteps[0], sample)
|
||||
mx.eval(sample)
|
||||
x0_after_first = sched._prev_x0
|
||||
# Step 2
|
||||
vel = mx.random.normal(sample.shape)
|
||||
sample = sched.step(vel, sched.timesteps[1], sample)
|
||||
mx.eval(sample)
|
||||
# prev_x0 should have been updated
|
||||
x0_after_second = sched._prev_x0
|
||||
assert x0_after_second is not None
|
||||
# The stored x0 should differ from the first step's
|
||||
assert not np.allclose(
|
||||
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
|
||||
)
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
"""Perfect oracle should denoise to target with any solver."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
latents = mx.random.normal(target.shape)
|
||||
for i in range(20):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-6) # perfect velocity for target=0
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (steps,)
|
||||
assert sched.sigmas.shape == (steps + 1,)
|
||||
|
||||
def test_terminal_sigma_produces_x0(self):
|
||||
"""When sigma_next=0 the scheduler should return x0 directly."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sched = FlowDPMPP2MScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
|
||||
vel = mx.ones_like(sample) * 2.0
|
||||
# Run through all steps; the last step has sigma_next=0
|
||||
for i in range(5):
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
# Final value should be finite
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UniPC Scheduler Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlowUniPCScheduler:
|
||||
def test_initialization(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched.num_train_timesteps == 1000
|
||||
assert sched.solver_order == 2
|
||||
assert sched.lower_order_final is True
|
||||
|
||||
def test_set_timesteps(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(30, shift=12.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (30,)
|
||||
assert sched.sigmas.shape == (31,)
|
||||
|
||||
def test_step_index_increments(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
vel = mx.zeros_like(sample)
|
||||
assert sched._step_index == 0
|
||||
sched.step(vel, 0, sample)
|
||||
assert sched._step_index == 1
|
||||
|
||||
def test_reset(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 1, 1))
|
||||
sched.step(mx.zeros_like(sample), 0, sample)
|
||||
sched.reset()
|
||||
assert sched._step_index == 0
|
||||
assert sched._lower_order_nums == 0
|
||||
assert sched._last_sample is None
|
||||
assert all(m is None for m in sched._model_outputs)
|
||||
|
||||
def test_full_loop_finite(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(10, shift=1.0)
|
||||
sample = mx.ones((1, 2, 1, 2, 2))
|
||||
for i in range(10):
|
||||
vel = mx.ones_like(sample) * 0.1
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_corrector_not_applied_first_step(self):
|
||||
"""First step should skip the corrector (no history)."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||
vel = mx.random.normal(sample.shape)
|
||||
# Before step 0: no last_sample
|
||||
assert sched._last_sample is None
|
||||
sched.step(vel, sched.timesteps[0], sample)
|
||||
# After step 0: last_sample should be set for corrector on step 1
|
||||
assert sched._last_sample is not None
|
||||
|
||||
def test_corrector_applied_after_first_step(self):
|
||||
"""Steps after the first should use the corrector when enabled."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 4, 4))
|
||||
for i in range(3):
|
||||
vel = mx.random.normal(sample.shape)
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
# lower_order_nums should have increased
|
||||
assert sched._lower_order_nums >= 2
|
||||
|
||||
def test_denoise_to_target(self):
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(20, shift=5.0)
|
||||
target = mx.zeros((1, 2, 1, 4, 4))
|
||||
latents = mx.random.normal(target.shape)
|
||||
for i in range(20):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-6)
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_various_step_counts(self, steps):
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
sched.set_timesteps(steps, shift=5.0)
|
||||
mx.eval(sched.timesteps, sched.sigmas)
|
||||
assert sched.timesteps.shape == (steps,)
|
||||
assert sched.sigmas.shape == (steps + 1,)
|
||||
|
||||
def test_disable_corrector(self):
|
||||
"""Disabling corrector on step 0 should still work without error."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
||||
sched.set_timesteps(5, shift=1.0)
|
||||
sample = mx.ones((1, 1, 1, 2, 2))
|
||||
for i in range(5):
|
||||
vel = mx.ones_like(sample) * 0.1
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_solver_order_3(self):
|
||||
"""Order 3 should work without error."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
||||
sched.set_timesteps(10, shift=5.0)
|
||||
sample = mx.random.normal((1, 2, 1, 2, 2))
|
||||
for i in range(10):
|
||||
vel = mx.random.normal(sample.shape)
|
||||
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||
mx.eval(sample)
|
||||
assert np.isfinite(np.array(sample)).all()
|
||||
|
||||
def test_corrector_rhos_c_not_hardcoded(self):
|
||||
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
|
||||
import math
|
||||
|
||||
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
|
||||
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
||||
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
||||
from mlx_video.models.wan_2.scheduler import _compute_sigmas
|
||||
|
||||
sigmas = _compute_sigmas(50, shift=5.0)
|
||||
|
||||
def _lambda(sigma):
|
||||
if sigma >= 1.0:
|
||||
return -math.inf
|
||||
if sigma <= 0.0:
|
||||
return math.inf
|
||||
return math.log(1 - sigma) - math.log(sigma)
|
||||
|
||||
for step_idx in [5, 10, 25, 45]:
|
||||
sigma_s0 = sigmas[step_idx - 1]
|
||||
sigma_t = sigmas[step_idx]
|
||||
lambda_s0 = _lambda(sigma_s0)
|
||||
lambda_t = _lambda(sigma_t)
|
||||
h = lambda_t - lambda_s0
|
||||
hh = -h
|
||||
|
||||
sigma_sk = sigmas[step_idx - 2]
|
||||
lambda_sk = _lambda(sigma_sk)
|
||||
rk = (lambda_sk - lambda_s0) / h
|
||||
rks = np.array([rk, 1.0])
|
||||
|
||||
h_phi_1 = math.expm1(hh)
|
||||
B_h = h_phi_1
|
||||
h_phi_k = h_phi_1 / hh - 1.0
|
||||
factorial_i = 1
|
||||
R_rows, b_vals = [], []
|
||||
for j in range(1, 3):
|
||||
R_rows.append(rks ** (j - 1))
|
||||
b_vals.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= j + 1
|
||||
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||
R = np.stack(R_rows)
|
||||
b = np.array(b_vals)
|
||||
rhos_c = np.linalg.solve(R, b)
|
||||
|
||||
# History weight should be small (~0.07-0.09), not 0.5
|
||||
assert (
|
||||
rhos_c[0] < 0.15
|
||||
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||
assert (
|
||||
rhos_c[0] > 0.0
|
||||
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||
# D1_t weight should be ~0.42-0.45, not 0.5
|
||||
assert (
|
||||
0.3 < rhos_c[1] < 0.5
|
||||
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduler Coherence Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchedulerCoherence:
|
||||
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
|
||||
|
||||
All three schedulers should agree on shared structure (sigma schedules,
|
||||
first-step behavior) and converge to the same result given perfect
|
||||
velocity oracles, even though they use different update rules.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_schedulers(steps=10, shift=5.0):
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
scheds = {
|
||||
"euler": FlowMatchEulerScheduler(),
|
||||
"dpm++": FlowDPMPP2MScheduler(),
|
||||
"unipc": FlowUniPCScheduler(),
|
||||
}
|
||||
for s in scheds.values():
|
||||
s.set_timesteps(steps, shift=shift)
|
||||
return scheds
|
||||
|
||||
def test_identical_sigma_schedules(self):
|
||||
"""All schedulers must use the same sigma schedule."""
|
||||
scheds = self._make_schedulers(20, shift=5.0)
|
||||
ref = np.array(scheds["euler"].sigmas)
|
||||
for name in ("dpm++", "unipc"):
|
||||
np.testing.assert_allclose(
|
||||
np.array(scheds[name].sigmas),
|
||||
ref,
|
||||
atol=1e-6,
|
||||
err_msg=f"{name} sigma schedule differs from Euler",
|
||||
)
|
||||
|
||||
def test_identical_timesteps(self):
|
||||
"""All schedulers must produce the same timestep sequence."""
|
||||
scheds = self._make_schedulers(20, shift=5.0)
|
||||
ref = np.array(scheds["euler"].timesteps)
|
||||
for name in ("dpm++", "unipc"):
|
||||
np.testing.assert_allclose(
|
||||
np.array(scheds[name].timesteps),
|
||||
ref,
|
||||
atol=1e-6,
|
||||
err_msg=f"{name} timesteps differ from Euler",
|
||||
)
|
||||
|
||||
def test_first_step_matches_euler(self):
|
||||
"""Step 0 (1st-order for all solvers) should match Euler exactly."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
scheds = self._make_schedulers(10, shift=5.0)
|
||||
results = {}
|
||||
for name, sched in scheds.items():
|
||||
r = sched.step(vel, sched.timesteps[0], noise)
|
||||
mx.eval(r)
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
err_msg="DPM++ step 0 should match Euler",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
err_msg="UniPC step 0 should match Euler",
|
||||
)
|
||||
|
||||
def test_first_step_matches_across_shifts(self):
|
||||
"""Step 0 should match Euler for different shift values."""
|
||||
mx.random.seed(99)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
for shift in (1.0, 5.0, 12.0):
|
||||
scheds = self._make_schedulers(10, shift=shift)
|
||||
euler_r = scheds["euler"].step(vel, scheds["euler"].timesteps[0], noise)
|
||||
dpm_r = scheds["dpm++"].step(vel, scheds["dpm++"].timesteps[0], noise)
|
||||
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
|
||||
mx.eval(euler_r, dpm_r, unipc_r)
|
||||
np.testing.assert_allclose(
|
||||
np.array(dpm_r),
|
||||
np.array(euler_r),
|
||||
atol=1e-5,
|
||||
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
np.array(unipc_r),
|
||||
np.array(euler_r),
|
||||
atol=1e-5,
|
||||
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
|
||||
)
|
||||
|
||||
def test_oracle_all_converge_to_target(self):
|
||||
"""Given a perfect velocity oracle v=x/sigma, all solvers should
|
||||
denoise to approximately zero (the target)."""
|
||||
mx.random.seed(7)
|
||||
shape = (1, 2, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
for name, sched in self._make_schedulers(20, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(20):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-8)
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
np.testing.assert_allclose(
|
||||
np.array(latents),
|
||||
0.0,
|
||||
atol=1e-3,
|
||||
err_msg=f"{name} did not converge to target with oracle",
|
||||
)
|
||||
|
||||
def test_oracle_higher_order_closer_to_target(self):
|
||||
"""With few steps and a perfect oracle, higher-order solvers should
|
||||
be at least as accurate as Euler."""
|
||||
mx.random.seed(12)
|
||||
shape = (1, 2, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
steps = 5
|
||||
|
||||
errors = {}
|
||||
for name, sched in self._make_schedulers(steps, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(steps):
|
||||
sigma = float(sched.sigmas[i].item())
|
||||
v = latents / max(sigma, 1e-8)
|
||||
latents = sched.step(v, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
errors[name] = float(mx.mean(mx.abs(latents)).item())
|
||||
|
||||
# Higher-order solvers should not be significantly worse than Euler
|
||||
# (add small epsilon to handle near-zero errors from floating point noise)
|
||||
eps = 1e-6
|
||||
assert (
|
||||
errors["dpm++"] <= errors["euler"] * 1.5 + eps
|
||||
), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
assert (
|
||||
errors["unipc"] <= errors["euler"] * 1.5 + eps
|
||||
), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||
|
||||
def test_multistep_trajectory_similar_magnitude(self):
|
||||
"""Over a full denoising loop with constant velocity, all solvers
|
||||
should produce outputs of similar magnitude (not diverging)."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
steps = 20
|
||||
|
||||
final_means = {}
|
||||
for name, sched in self._make_schedulers(steps, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(steps):
|
||||
vel = latents * 0.1
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
final_means[name] = float(mx.mean(mx.abs(latents)).item())
|
||||
|
||||
# All solvers should produce results within the same order of magnitude
|
||||
vals = list(final_means.values())
|
||||
ratio = max(vals) / max(min(vals), 1e-10)
|
||||
assert (
|
||||
ratio < 10.0
|
||||
), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||
|
||||
def test_intermediate_values_finite(self):
|
||||
"""Every intermediate latent value must be finite for all solvers."""
|
||||
mx.random.seed(0)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
for name, sched in self._make_schedulers(15, shift=5.0).items():
|
||||
latents = noise
|
||||
for i in range(15):
|
||||
vel = mx.random.normal(shape)
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
assert np.isfinite(
|
||||
np.array(latents)
|
||||
).all(), f"{name} produced non-finite values at step {i}"
|
||||
|
||||
def test_lambda_boundary_values(self):
|
||||
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
assert (
|
||||
cls._lambda(1.0) == -math.inf
|
||||
), f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||
assert (
|
||||
cls._lambda(0.0) == math.inf
|
||||
), f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||
# Interior values should be finite
|
||||
lam = cls._lambda(0.5)
|
||||
assert (
|
||||
math.isfinite(lam) and lam == 0.0
|
||||
), f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||
|
||||
def test_lambda_monotonically_decreasing(self):
|
||||
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
|
||||
|
||||
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
|
||||
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
|
||||
for i in range(len(lambdas) - 1):
|
||||
assert lambdas[i] > lambdas[i + 1], (
|
||||
f"_lambda not decreasing: _lambda({sigmas[i]})={lambdas[i]} "
|
||||
f"vs _lambda({sigmas[i+1]})={lambdas[i+1]}"
|
||||
)
|
||||
|
||||
def test_step0_is_ddim_formula(self):
|
||||
"""At sigma=1.0, the DPM++/UniPC first step should reduce to the
|
||||
DDIM formula: x_next = sigma_next * x + (1 - sigma_next) * x0."""
|
||||
mx.random.seed(55)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
sample = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
for steps, shift in [(10, 5.0), (20, 12.0)]:
|
||||
scheds = self._make_schedulers(steps, shift=shift)
|
||||
sigma_next = float(scheds["euler"].sigmas[1].item())
|
||||
sigma_cur = float(scheds["euler"].sigmas[0].item())
|
||||
assert abs(sigma_cur - 1.0) < 1e-3, "First sigma should be ~1.0"
|
||||
|
||||
x0 = sample - sigma_cur * vel
|
||||
expected = sigma_next * sample + (1.0 - sigma_next) * x0
|
||||
mx.eval(expected)
|
||||
|
||||
for name in ("dpm++", "unipc"):
|
||||
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
||||
mx.eval(result)
|
||||
np.testing.assert_allclose(
|
||||
np.array(result),
|
||||
np.array(expected),
|
||||
atol=5e-4,
|
||||
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||
def test_coherent_across_step_counts(self, steps):
|
||||
"""All solvers should agree on step 0 regardless of total step count."""
|
||||
mx.random.seed(77)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
scheds = self._make_schedulers(steps, shift=5.0)
|
||||
results = {}
|
||||
for name, sched in scheds.items():
|
||||
r = sched.step(vel, sched.timesteps[0], noise)
|
||||
mx.eval(r)
|
||||
results[name] = np.array(r)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
results["dpm++"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
results["unipc"],
|
||||
results["euler"],
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
def test_dpmpp_unipc_agree_on_step1(self):
|
||||
"""After warmup, DPM++ and UniPC step 1 should be similar
|
||||
(both use 2nd-order corrections based on the same model outputs)."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
scheds = self._make_schedulers(10, shift=5.0)
|
||||
# Run step 0 with same velocity
|
||||
vel0 = mx.random.normal(shape)
|
||||
for sched in scheds.values():
|
||||
sched.step(vel0, sched.timesteps[0], noise)
|
||||
|
||||
# Run step 1 from same sample with same velocity
|
||||
sample1 = scheds["euler"].step(vel0, scheds["euler"].timesteps[0], noise)
|
||||
mx.eval(sample1)
|
||||
vel1 = mx.random.normal(shape)
|
||||
|
||||
r_dpm = scheds["dpm++"].step(vel1, scheds["dpm++"].timesteps[1], sample1)
|
||||
r_unipc = scheds["unipc"].step(vel1, scheds["unipc"].timesteps[1], sample1)
|
||||
mx.eval(r_dpm, r_unipc)
|
||||
|
||||
# They won't be identical (different correction formulas) but should
|
||||
# be in the same ballpark (within 50% of each other's magnitude)
|
||||
mean_dpm = float(mx.mean(mx.abs(r_dpm)).item())
|
||||
mean_unipc = float(mx.mean(mx.abs(r_unipc)).item())
|
||||
ratio = max(mean_dpm, mean_unipc) / max(min(mean_dpm, mean_unipc), 1e-10)
|
||||
assert ratio < 2.0, (
|
||||
f"DPM++ and UniPC step 1 differ too much: "
|
||||
f"DPM++={mean_dpm:.4f}, UniPC={mean_unipc:.4f}"
|
||||
)
|
||||
|
||||
def test_reset_makes_solvers_reproducible(self):
|
||||
"""After reset(), running the same loop should produce identical output."""
|
||||
mx.random.seed(42)
|
||||
shape = (1, 2, 1, 2, 2)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
from mlx_video.models.wan_2.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||
sched = cls()
|
||||
sched.set_timesteps(5, shift=5.0)
|
||||
|
||||
# First run
|
||||
latents = noise
|
||||
for i in range(5):
|
||||
vel = latents * 0.1
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
result1 = np.array(latents)
|
||||
|
||||
# Reset and run again
|
||||
sched.reset()
|
||||
latents = noise
|
||||
for i in range(5):
|
||||
vel = latents * 0.1
|
||||
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||
mx.eval(latents)
|
||||
result2 = np.array(latents)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
result1,
|
||||
result2,
|
||||
atol=1e-5,
|
||||
err_msg=f"{cls.__name__} not reproducible after reset()",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UniPC Corrector Default Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUniPCCorrectorDefault:
|
||||
"""Tests that the UniPC corrector is enabled by default,
|
||||
matching official FlowUniPCMultistepScheduler behavior."""
|
||||
|
||||
def test_corrector_enabled_by_default(self):
|
||||
"""Default construction should have corrector enabled."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
sched = FlowUniPCScheduler()
|
||||
assert sched._use_corrector is True
|
||||
|
||||
def test_corrector_affects_output(self):
|
||||
"""Corrector should produce different results than no corrector after step 1."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
sched_corr = FlowUniPCScheduler(use_corrector=True)
|
||||
sched_corr.set_timesteps(10, shift=5.0)
|
||||
sched_no = FlowUniPCScheduler(use_corrector=False)
|
||||
sched_no.set_timesteps(10, shift=5.0)
|
||||
|
||||
latent_corr = noise
|
||||
latent_no = noise
|
||||
for i in range(3):
|
||||
vel = mx.random.normal(shape) * 0.1
|
||||
latent_corr = sched_corr.step(vel, sched_corr.timesteps[i], latent_corr)
|
||||
latent_no = sched_no.step(vel, sched_no.timesteps[i], latent_no)
|
||||
mx.eval(latent_corr, latent_no)
|
||||
|
||||
diff = float(mx.abs(latent_corr - latent_no).max())
|
||||
assert diff > 1e-6, f"Corrector had no effect (max diff={diff})"
|
||||
|
||||
def test_corrector_does_not_affect_first_step(self):
|
||||
"""Step 0 should be identical regardless of corrector setting."""
|
||||
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
shape = (1, 4, 1, 4, 4)
|
||||
noise = mx.random.normal(shape)
|
||||
vel = mx.random.normal(shape)
|
||||
|
||||
sched_corr = FlowUniPCScheduler(use_corrector=True)
|
||||
sched_corr.set_timesteps(10, shift=5.0)
|
||||
sched_no = FlowUniPCScheduler(use_corrector=False)
|
||||
sched_no.set_timesteps(10, shift=5.0)
|
||||
|
||||
r1 = sched_corr.step(vel, sched_corr.timesteps[0], noise)
|
||||
r2 = sched_no.step(vel, sched_no.timesteps[0], noise)
|
||||
mx.eval(r1, r2)
|
||||
np.testing.assert_allclose(np.array(r1), np.array(r2), atol=1e-6)
|
||||
218
tests/test_wan_t5.py
Normal file
218
tests/test_wan_t5.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Tests for T5 encoder components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# T5 Encoder Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestT5LayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
|
||||
|
||||
norm = T5LayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_rms_normalization(self):
|
||||
"""After T5LayerNorm with weight=1, RMS should be ~1."""
|
||||
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
|
||||
|
||||
norm = T5LayerNorm(128)
|
||||
x = mx.random.normal((1, 5, 128)) * 5.0
|
||||
out = norm(x)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0])
|
||||
for i in range(5):
|
||||
rms = np.sqrt(np.mean(out_np[i] ** 2))
|
||||
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
|
||||
|
||||
|
||||
class TestT5RelativeEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(10, 10)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
|
||||
|
||||
def test_asymmetric_lengths(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||
out = rel_emb(8, 12)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 4, 8, 12)
|
||||
|
||||
def test_symmetry(self):
|
||||
"""Position bias should have structure (not all zeros/random)."""
|
||||
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
|
||||
|
||||
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
|
||||
out = rel_emb(6, 6)
|
||||
mx.eval(out)
|
||||
out_np = np.array(out[0]) # [N, lq, lk]
|
||||
# Diagonal elements (position i attending to position i) should be consistent
|
||||
# (same relative distance = 0 for all diagonal elements)
|
||||
for h in range(2):
|
||||
diag = np.diag(out_np[h])
|
||||
np.testing.assert_allclose(diag, diag[0], atol=1e-5)
|
||||
|
||||
|
||||
class TestT5Attention:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
out = attn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
def test_no_scaling(self):
|
||||
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
# No scale attribute (unlike standard attention)
|
||||
assert not hasattr(attn, "scale")
|
||||
|
||||
def test_with_position_bias(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention, T5RelativeEmbedding
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
rel_emb = T5RelativeEmbedding(32, 4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
pos_bias = rel_emb(10, 10)
|
||||
out = attn(x, pos_bias=pos_bias)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
def test_with_mask(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Attention
|
||||
|
||||
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
mask = mx.ones((1, 10))
|
||||
mask = mx.concatenate([mask[:, :7], mx.zeros((1, 3))], axis=1)
|
||||
out = attn(x, mask=mask)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
|
||||
class TestT5FeedForward:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5FeedForward
|
||||
|
||||
ffn = T5FeedForward(64, 256)
|
||||
x = mx.random.normal((1, 10, 64))
|
||||
out = ffn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 10, 64)
|
||||
|
||||
def test_gated_structure(self):
|
||||
"""T5 FFN is gated: gate(x) * fc1(x)."""
|
||||
from mlx_video.models.wan_2.text_encoder import T5FeedForward
|
||||
|
||||
ffn = T5FeedForward(32, 64)
|
||||
assert hasattr(ffn, "gate_proj")
|
||||
assert hasattr(ffn, "fc1")
|
||||
assert hasattr(ffn, "fc2")
|
||||
|
||||
|
||||
class TestT5Encoder:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
ids = mx.array([[1, 5, 10, 0, 0]])
|
||||
mask = mx.array([[1, 1, 1, 0, 0]])
|
||||
out = encoder(ids, mask=mask)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 5, 64)
|
||||
|
||||
def test_shared_pos(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=True,
|
||||
)
|
||||
assert encoder.pos_embedding is not None
|
||||
for block in encoder.blocks:
|
||||
assert block.pos_embedding is None
|
||||
|
||||
def test_per_layer_pos(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
assert encoder.pos_embedding is None
|
||||
for block in encoder.blocks:
|
||||
assert block.pos_embedding is not None
|
||||
|
||||
def test_param_count(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
|
||||
assert num_params > 0
|
||||
|
||||
def test_without_mask(self):
|
||||
from mlx_video.models.wan_2.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=100,
|
||||
dim=64,
|
||||
dim_attn=64,
|
||||
dim_ffn=128,
|
||||
num_heads=4,
|
||||
num_layers=2,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
)
|
||||
ids = mx.array([[1, 5, 10]])
|
||||
out = encoder(ids)
|
||||
mx.eval(out)
|
||||
assert out.shape == (1, 3, 64)
|
||||
213
tests/test_wan_tiling.py
Normal file
213
tests/test_wan_tiling.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Tests for Wan VAE tiled decoding."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
decode_with_tiling,
|
||||
split_in_spatial,
|
||||
)
|
||||
|
||||
|
||||
class TestNonCausalTemporal:
|
||||
"""Tests for the causal_temporal=False path in decode_with_tiling."""
|
||||
|
||||
def test_split_spatial_for_temporal(self):
|
||||
"""Non-causal temporal should use split_in_spatial (no causal shift)."""
|
||||
intervals = split_in_spatial(8, 2, 20)
|
||||
# No causal adjustment: starts should be evenly spaced
|
||||
assert intervals.starts[0] == 0
|
||||
for i in range(1, len(intervals.starts)):
|
||||
assert intervals.starts[i] == intervals.starts[i - 1] + (8 - 2)
|
||||
|
||||
def test_causal_vs_noncausal_output_size(self):
|
||||
"""Causal temporal gives 1+(T-1)*S frames, non-causal gives T*S."""
|
||||
mx.random.seed(42)
|
||||
b, c, t, h, w = 1, 4, 4, 4, 4
|
||||
latents = mx.random.normal((b, c, t, h, w))
|
||||
scale = 4
|
||||
|
||||
# Simple passthrough decoder: just repeat along dimensions
|
||||
def dummy_decoder_causal(x, **kwargs):
|
||||
b, c, t, h, w = x.shape
|
||||
out_t = 1 + (t - 1) * scale
|
||||
out_h = h * scale
|
||||
out_w = w * scale
|
||||
return mx.ones((b, 3, out_t, out_h, out_w))
|
||||
|
||||
def dummy_decoder_noncausal(x, **kwargs):
|
||||
b, c, t, h, w = x.shape
|
||||
out_t = t * scale
|
||||
out_h = h * scale
|
||||
out_w = w * scale
|
||||
return mx.ones((b, 3, out_t, out_h, out_w))
|
||||
|
||||
config = TilingConfig.spatial_only(tile_size=128, overlap=64)
|
||||
|
||||
# Causal: 1 + (4-1)*4 = 13
|
||||
out_causal = decode_with_tiling(
|
||||
dummy_decoder_causal,
|
||||
latents,
|
||||
config,
|
||||
spatial_scale=scale,
|
||||
temporal_scale=scale,
|
||||
causal_temporal=True,
|
||||
)
|
||||
mx.eval(out_causal)
|
||||
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
|
||||
|
||||
# Non-causal: 4*4 = 16
|
||||
out_noncausal = decode_with_tiling(
|
||||
dummy_decoder_noncausal,
|
||||
latents,
|
||||
config,
|
||||
spatial_scale=scale,
|
||||
temporal_scale=scale,
|
||||
causal_temporal=False,
|
||||
)
|
||||
mx.eval(out_noncausal)
|
||||
assert out_noncausal.shape[2] == t * scale # 16
|
||||
|
||||
|
||||
class TestWan22TiledDecoding:
|
||||
"""Tests for Wan2.2 VAE tiled decoding."""
|
||||
|
||||
def _make_small_wan22_decoder(self):
|
||||
"""Create a small Wan2.2 decoder for testing."""
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
|
||||
|
||||
# Use very small dimensions for fast testing
|
||||
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
def test_decode_tiled_output_shape(self):
|
||||
"""Tiled decode should produce same shape as non-tiled."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan22_decoder()
|
||||
|
||||
# Small input: [B=1, T=3, H=2, W=2, C=48]
|
||||
z = mx.random.normal((1, 3, 2, 2, 48))
|
||||
mx.eval(z)
|
||||
|
||||
# Non-tiled
|
||||
out_regular = vae(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
# Tiled (force tiling with very small tile sizes)
|
||||
# Use spatial tile=32px (2 latent @ scale 16) and temporal=8 frames (2 latent @ scale 4)
|
||||
config = TilingConfig(
|
||||
spatial_config=None, # Don't tile spatially (input is tiny)
|
||||
temporal_config=None, # Don't tile temporally (input is tiny)
|
||||
)
|
||||
# With no tiling config, decode_tiled should fall through to regular decode
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
# Both should produce the same shape
|
||||
assert (
|
||||
out_regular.shape == out_tiled.shape
|
||||
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan22_decoder()
|
||||
|
||||
# Input smaller than any tile size
|
||||
z = mx.random.normal((1, 2, 2, 2, 48))
|
||||
mx.eval(z)
|
||||
|
||||
out_regular = vae(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular),
|
||||
np.array(out_tiled),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
|
||||
class TestWan21TiledDecoding:
|
||||
"""Tests for Wan2.1 VAE tiled decoding."""
|
||||
|
||||
def _make_small_wan21_vae(self):
|
||||
"""Create a small Wan2.1 VAE for testing."""
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
def test_decode_tiled_output_shape(self):
|
||||
"""Tiled decode should produce correct output shape."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan21_vae()
|
||||
|
||||
# [B=1, C=16, T=3, H=4, W=4]
|
||||
z = mx.random.normal((1, 16, 3, 4, 4))
|
||||
mx.eval(z)
|
||||
|
||||
out_regular = vae.decode(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
assert (
|
||||
out_regular.shape == out_tiled.shape
|
||||
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
|
||||
mx.random.seed(42)
|
||||
vae = self._make_small_wan21_vae()
|
||||
|
||||
z = mx.random.normal((1, 16, 2, 4, 4))
|
||||
mx.eval(z)
|
||||
|
||||
out_regular = vae.decode(z)
|
||||
mx.eval(out_regular)
|
||||
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular),
|
||||
np.array(out_tiled),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
|
||||
class TestWan21TemporalScale:
|
||||
"""Verify Wan2.1 decoder temporal output is T*4 (non-causal)."""
|
||||
|
||||
def test_wan21_decoder_temporal_output(self):
|
||||
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
|
||||
from mlx_video.models.wan_2.vae import Decoder3d
|
||||
|
||||
# Small decoder for fast test
|
||||
dec = Decoder3d(
|
||||
dim=16,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temporal_upsample=[True, True, False],
|
||||
)
|
||||
mx.eval(dec.parameters())
|
||||
|
||||
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
|
||||
mx.eval(x)
|
||||
out = dec(x)
|
||||
mx.eval(out)
|
||||
|
||||
# With two temporal 2× upsamples: T=3 → 6 → 12
|
||||
assert out.shape[2] == 3 * 4, f"Expected T=12, got T={out.shape[2]}"
|
||||
182
tests/test_wan_transformer.py
Normal file
182
tests/test_wan_transformer.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Tests for Wan transformer block components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transformer Block Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanFFN:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(64, 256)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = ffn(x)
|
||||
mx.eval(out)
|
||||
assert out.shape == (2, 10, 64)
|
||||
|
||||
def test_gelu_activation(self):
|
||||
"""FFN should use GELU activation (non-linearity)."""
|
||||
from mlx_video.models.wan_2.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(32, 128)
|
||||
x = mx.ones((1, 1, 32)) * 2.0
|
||||
out1 = ffn(x)
|
||||
x2 = mx.ones((1, 1, 32)) * 4.0
|
||||
out2 = ffn(x2)
|
||||
mx.eval(out1, out2)
|
||||
# Non-linear: 2x input should not give 2x output
|
||||
assert not np.allclose(np.array(out2), np.array(out1) * 2.0, rtol=0.1)
|
||||
|
||||
|
||||
class TestWanAttentionBlock:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
self.ffn_dim = 128
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim,
|
||||
self.ffn_dim,
|
||||
self.num_heads,
|
||||
cross_attn_norm=True,
|
||||
)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
e = mx.random.normal((B, L, 6, self.dim))
|
||||
context = mx.random.normal((B, 16, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
out = block(
|
||||
x,
|
||||
e,
|
||||
seq_lens=[L],
|
||||
grid_sizes=[(F, H, W)],
|
||||
freqs=freqs,
|
||||
context=context,
|
||||
)
|
||||
mx.eval(out)
|
||||
assert out.shape == (B, L, self.dim)
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
assert block.modulation.shape == (1, 6, self.dim)
|
||||
|
||||
def test_with_cross_attn_norm(self):
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim,
|
||||
self.ffn_dim,
|
||||
self.num_heads,
|
||||
cross_attn_norm=True,
|
||||
)
|
||||
assert block.norm3 is not None
|
||||
|
||||
def test_without_cross_attn_norm(self):
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim,
|
||||
self.ffn_dim,
|
||||
self.num_heads,
|
||||
cross_attn_norm=False,
|
||||
)
|
||||
assert block.norm3 is None
|
||||
|
||||
def test_residual_connection(self):
|
||||
"""Output should differ from zero even with small random init."""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
B, L = 1, 8
|
||||
F, H, W = 2, 2, 2
|
||||
x = mx.ones((B, L, self.dim))
|
||||
e = mx.zeros((B, L, 6, self.dim))
|
||||
context = mx.random.normal((B, 4, self.dim))
|
||||
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||
|
||||
out = block(x, e, [L], [(F, H, W)], freqs, context)
|
||||
mx.eval(out)
|
||||
# With residual connections, output should be close to input + corrections
|
||||
assert not np.allclose(np.array(out), 0.0, atol=1e-3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Float32 Modulation Precision Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFloat32Modulation:
|
||||
"""Tests that modulation/gate operations are computed in float32,
|
||||
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
|
||||
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
self.dim = 64
|
||||
|
||||
def test_block_modulation_in_float32(self):
|
||||
"""Modulation param starts random but should be usable as float32."""
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
|
||||
assert block.modulation.dtype == mx.float32
|
||||
|
||||
def test_block_output_float32_with_bf16_modulation_input(self):
|
||||
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, 4)
|
||||
B, L = 1, 8
|
||||
x = mx.random.normal((B, L, self.dim))
|
||||
e = mx.random.normal((B, L, 6, self.dim)).astype(mx.bfloat16)
|
||||
ctx = mx.random.normal((B, 4, self.dim))
|
||||
freqs = rope_params(1024, self.dim // 4)
|
||||
|
||||
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
|
||||
mx.eval(out)
|
||||
assert out.dtype == mx.float32
|
||||
assert np.isfinite(np.array(out)).all()
|
||||
|
||||
def test_head_modulation_float32(self):
|
||||
"""Head modulation should be float32 even with bf16 e input."""
|
||||
from mlx_video.models.wan_2.wan_2 import Head
|
||||
|
||||
head = Head(self.dim, 4, (1, 2, 2))
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
|
||||
out = head(x, e)
|
||||
mx.eval(out)
|
||||
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||
|
||||
def test_model_time_embedding_float32(self):
|
||||
"""sinusoidal_embedding_1d output must be float32."""
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
t = mx.array([500.0])
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
mx.eval(emb)
|
||||
assert emb.dtype == mx.float32
|
||||
|
||||
def test_model_per_token_time_embedding_float32(self):
|
||||
"""Per-token time embeddings (I2V) should also be float32."""
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
mx.eval(emb)
|
||||
assert emb.dtype == mx.float32
|
||||
assert emb.shape == (1, 4, 256)
|
||||
1029
tests/test_wan_vae.py
Normal file
1029
tests/test_wan_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
20
tests/wan_test_helpers.py
Normal file
20
tests/wan_test_helpers.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Shared test helpers for Wan test modules."""
|
||||
|
||||
|
||||
def _make_tiny_config():
|
||||
"""Create a tiny WanModelConfig for testing."""
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
# Override to tiny values
|
||||
config.dim = 64
|
||||
config.ffn_dim = 128
|
||||
config.num_heads = 4
|
||||
config.num_layers = 2
|
||||
config.in_dim = 4
|
||||
config.out_dim = 4
|
||||
config.patch_size = (1, 2, 2)
|
||||
config.freq_dim = 32
|
||||
config.text_dim = 32
|
||||
config.text_len = 8
|
||||
return config
|
||||
54
uv.lock
generated
54
uv.lock
generated
@@ -622,6 +622,18 @@ http = [
|
||||
{ name = "aiohttp" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ftfy"
|
||||
version = "6.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "wcwidth" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h11"
|
||||
version = "0.16.0"
|
||||
@@ -720,6 +732,33 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "imageio"
|
||||
version = "2.37.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
{ name = "pillow" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a3/6f/606be632e37bf8d05b253e8626c2291d74c691ddc7bcdf7d6aaf33b32f6a/imageio-2.37.2.tar.gz", hash = "sha256:0212ef2727ac9caa5ca4b2c75ae89454312f440a756fcfc8ef1993e718f50f8a", size = 389600, upload-time = "2025-11-04T14:29:39.898Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl", hash = "sha256:ad9adfb20335d718c03de457358ed69f141021a333c40a53e57273d8a5bd0b9b", size = 317646, upload-time = "2025-11-04T14:29:37.948Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "imageio-ffmpeg"
|
||||
version = "0.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/44/bd/c3343c721f2a1b0c9fc71c1aebf1966a3b7f08c2eea8ed5437a2865611d6/imageio_ffmpeg-0.6.0.tar.gz", hash = "sha256:e2556bed8e005564a9f925bb7afa4002d82770d6b08825078b7697ab88ba1755", size = 25210, upload-time = "2025-01-16T21:34:32.747Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/da/58/87ef68ac83f4c7690961bce288fd8e382bc5f1513860fc7f90a9c1c1c6bf/imageio_ffmpeg-0.6.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.whl", hash = "sha256:9d2baaf867088508d4a3458e61eeb30e945c4ad8016025545f66c4b5aaef0a61", size = 24932969, upload-time = "2025-01-16T21:34:20.464Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/5c/f3d8a657d362cc93b81aab8feda487317da5b5d31c0e1fdfd5e986e55d17/imageio_ffmpeg-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b1ae3173414b5fc5f538a726c4e48ea97edc0d2cdc11f103afee655c463fa742", size = 21113891, upload-time = "2025-01-16T21:34:00.277Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/e7/1925bfbc563c39c1d2e82501d8372734a5c725e53ac3b31b4c2d081e895b/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1d47bebd83d2c5fc770720d211855f208af8a596c82d17730aa51e815cdee6dc", size = 25632706, upload-time = "2025-01-16T21:33:53.475Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c7e46fcec401dd990405049d2e2f475e2b397779df2519b544b8aab515195282", size = 29498237, upload-time = "2025-01-16T21:34:13.726Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/13/59da54728351883c3c1d9fca1710ab8eee82c7beba585df8f25ca925f08f/imageio_ffmpeg-0.6.0-py3-none-win32.whl", hash = "sha256:196faa79366b4a82f95c0f4053191d2013f4714a715780f0ad2a68ff37483cc2", size = 19652251, upload-time = "2025-01-16T21:34:06.812Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/c6/fa760e12a2483469e2bf5058c5faff664acf66cadb4df2ad6205b016a73d/imageio_ffmpeg-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02fa47c83703c37df6bfe4896aab339013f62bf02c5ebf2dce6da56af04ffc0a", size = 31246824, upload-time = "2025-01-16T21:34:28.6Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.3.0"
|
||||
@@ -969,7 +1008,10 @@ wheels = [
|
||||
name = "mlx-video"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "ftfy" },
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "imageio" },
|
||||
{ name = "imageio-ffmpeg" },
|
||||
{ name = "librosa" },
|
||||
{ name = "mlx" },
|
||||
{ name = "mlx-vlm" },
|
||||
@@ -989,7 +1031,10 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "ftfy" },
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "imageio", specifier = ">=2.37.2" },
|
||||
{ name = "imageio-ffmpeg", specifier = ">=0.6.0" },
|
||||
{ name = "librosa", specifier = ">=0.10.0" },
|
||||
{ name = "mlx", specifier = ">=0.22.0" },
|
||||
{ name = "mlx-vlm" },
|
||||
@@ -2482,6 +2527,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/89/f8827ccff89c1586027a105e5630ff6139a64da2515e24dafe860bd9ae4d/uvicorn-0.42.0-py3-none-any.whl", hash = "sha256:96c30f5c7abe6f74ae8900a70e92b85ad6613b745d4879eb9b16ccad15645359", size = 68830, upload-time = "2026-03-16T06:19:48.325Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xxhash"
|
||||
version = "3.6.0"
|
||||
|
||||
Reference in New Issue
Block a user