Wan2.1 and Wan2.2 model support, including LoRA support & more Poodles

This commit is contained in:
Prince Canuma
2026-03-11 19:08:14 +01:00
committed by GitHub
48 changed files with 14206 additions and 35 deletions

167
README.md
View File

@@ -18,18 +18,20 @@ uv pip install git+https://github.com/Blaizzy/mlx-video.git
Supported models: Supported models:
### LTX-2 - [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks
[LTX-2](https://huggingface.co/Lightricks/LTX-Video) is 19B parameter video generation model from Lightricks - [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline)
- [**Wan2.2**](https://github.com/Wan-Video/Wan2.2) — T2V-14B, TI2V-5B, and I2V-14B models (dual-model pipeline)
## Features ## Features
- Text-to-video generation with the LTX-2 19B DiT model - Text-to-video generation with multiple model families
- Two-stage generation pipeline for high-quality output - LTX-2: Two-stage pipeline with 2x spatial upscaling
- 2x spatial upscaling for images and videos - Wan2.1/2.2: Flow-matching diffusion with classifier-free guidance
- Optimized for Apple Silicon using MLX - Optimized for Apple Silicon using MLX
---
## Usage ## LTX-2
> ** Info:** Currently, only the distilled variant is supported. Full LTX-2 feature support is coming soon. > ** Info:** Currently, only the distilled variant is supported. Full LTX-2 feature support is coming soon.
@@ -53,7 +55,7 @@ python -m mlx_video.generate \
--output my_video.mp4 --output my_video.mp4
``` ```
### CLI Options ### LTX-2 CLI Options
| Option | Default | Description | | Option | Default | Description |
|--------|---------|-------------| |--------|---------|-------------|
@@ -67,45 +69,146 @@ python -m mlx_video.generate \
| `--save-frames` | false | Save individual frames as images | | `--save-frames` | false | Save individual frames as images |
| `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository | | `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository |
## How It Works ### How It Works (LTX-2)
The pipeline uses a two-stage generation process: 1. **Stage 1**: Generate at half resolution (e.g., 384×384) with 8 denoising steps
2. **Upsample**: 2× spatial upsampling via LatentUpsampler
1. **Stage 1**: Generate at half resolution (e.g., 384x384) with 8 denoising steps 3. **Stage 2**: Refine at full resolution (e.g., 768×768) with 3 denoising steps
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
3. **Stage 2**: Refine at full resolution (e.g., 768x768) with 3 denoising steps
4. **Decode**: VAE decoder converts latents to RGB video 4. **Decode**: VAE decoder converts latents to RGB video
---
## Wan2.1 / Wan2.2
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
### Step 0: Download and Convert Weights
See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan/README.md) for details.
### Step 1: Generate Video
```bash
# Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0)
python -m mlx_video.generate_wan \
--model-dir wan21_mlx \
--prompt "A cat playing piano in a cozy room"
# Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0)
python -m mlx_video.generate_wan \
--model-dir wan22_mlx \
--prompt "A cat playing piano in a cozy room"
```
With custom settings:
```bash
python -m mlx_video.generate_wan \
--model-dir wan21_mlx \
--prompt "Ocean waves at sunset, cinematic, 4K" \
--negative-prompt "blurry, low quality" \
--width 1280 \
--height 720 \
--num-frames 81 \
--steps 50 \
--guide-scale 5.0 \
--shift 5.0 \
--seed 42 \
--output-path my_video.mp4
```
The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model). You can also override any parameter via CLI flags.
#### Image-to-Video (I2V-14B)
```bash
# Generate video from an input image
python -m mlx_video.generate_wan \
--model-dir wan22_i2v_mlx \
--prompt "The camera slowly zooms in as the subject begins to move" \
--image start.png \
--num-frames 81 \
--output-path my_video.mp4
```
The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and uses channel concatenation (`y` tensor with 4 mask + 16 image latent channels) to condition generation on the first frame.
#### Generation CLI Options
| Option | Default | Description |
|--------|---------|-------------|
| `--model-dir` | (required) | Path to converted MLX model directory |
| `--prompt` | (required) | Text description of the video |
| `--image` | `None` | Input image path (for I2V models) |
| `--negative-prompt` | `""` | Negative prompt for guidance |
| `--width` | 1280 | Video width |
| `--height` | 720 | Video height |
| `--num-frames` | 81 | Number of frames (must be 4n+1) |
| `--steps` | from config | Number of diffusion steps |
| `--guide-scale` | from config | Guidance scale: float or `low,high` pair |
| `--shift` | from config | Noise schedule shift |
| `--seed` | -1 (random) | Random seed for reproducibility |
| `--output-path` | `output.mp4` | Output video path |
## LoRA Support
LoRA's can be used with the `--lora-high` and `--lora-low` command line switches.
For example, for using the the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, use the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
```bash
python -m mlx_video.generate_wan \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \
--height 704 \
--num-frames 41 \
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
--steps 4 \
--guide-scale 1 \
--trim-first-frames 1 \
--seed 2391784614 \
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
```
Which results in
![Poodles](examples/poodles-wan.gif)
## Requirements ## Requirements
- macOS with Apple Silicon - macOS with Apple Silicon
- Python >= 3.11 - Python >= 3.11
- MLX >= 0.22.0 - MLX >= 0.22.0
- For weight conversion: PyTorch (`pip install torch`)
## Model Specifications
- **Transformer**: 48 layers, 32 attention heads, 128 dim per head
- **Latent channels**: 128
- **Text encoder**: Gemma 3 with 3840-dim output
- **RoPE**: Split mode with double precision
## Project Structure ## Project Structure
``` ```
mlx_video/ mlx_video/
├── generate.py # Video generation pipeline ├── generate.py # LTX-2 generation pipeline
├── convert.py # Weight conversion (PyTorch -> MLX) ├── generate_wan.py # Wan2.1/2.2 generation pipeline
├── postprocess.py # Video post-processing utilities ├── convert.py # LTX-2 weight conversion
├── utils.py # Helper functions ├── convert_wan.py # Wan weight conversion (PyTorch → MLX)
├── postprocess.py # Video post-processing utilities
├── utils.py # Helper functions
└── models/ └── models/
── ltx/ ── ltx/ # LTX-2 model
├── ltx.py # Main LTXModel (DiT transformer) ├── ltx.py # DiT transformer
├── config.py # Model configuration ├── config.py # Configuration
├── transformer.py # Transformer blocks ├── transformer.py # Transformer blocks
├── attention.py # Multi-head attention with RoPE ├── attention.py # Multi-head attention with RoPE
├── text_encoder.py # Text encoder ├── text_encoder.py # Gemma 3 text encoder
├── upsampler.py # 2x spatial upsampler ├── upsampler.py # 2x spatial upsampler
└── video_vae/ # VAE encoder/decoder └── video_vae/ # VAE encoder/decoder
└── wan/ # Wan2.1/2.2 model
├── config.py # Configuration (2.1 & 2.2 presets)
├── model.py # WanModel (DiT transformer)
├── transformer.py # Attention blocks with 6-element modulation
├── attention.py # Self/cross attention with QK-norm
├── rope.py # 3-way factorized RoPE
├── text_encoder.py # T5 UMT5-XXL encoder
├── vae.py # 3D causal VAE decoder
└── scheduler.py # Flow-matching Euler scheduler
``` ```
## License ## License

911
docs/PORTING-GUIDE.md Normal file
View File

@@ -0,0 +1,911 @@
# Porting Diffusion Video Models to MLX: Lessons Learned
A practical guide distilled from porting Wan2.1/2.2 (1.3B14B) and Helios 14B DiT
video generation models from PyTorch to MLX on Apple Silicon. These lessons apply
broadly to any diffusion-based video (or image) model port.
---
## Table of Contents
1. [Debugging Methodology](#1-debugging-methodology)
2. [Precision & Dtype Pitfalls](#2-precision--dtype-pitfalls)
3. [MLX-Specific Gotchas](#3-mlx-specific-gotchas)
4. [Autoregressive Chunk Boundaries](#4-autoregressive-chunk-boundaries)
5. [VAE Decoder Artifacts](#5-vae-decoder-artifacts)
6. [Scheduler & Timestep Issues](#6-scheduler--timestep-issues)
7. [Weight Conversion](#7-weight-conversion)
8. [Text Conditioning Failures](#8-text-conditioning-failures)
9. [Position Encodings (RoPE)](#9-position-encodings-rope)
10. [Multi-Stage / Pyramid Pipelines](#10-multi-stage--pyramid-pipelines)
11. [Common Symptoms → Root Causes](#11-common-symptoms--root-causes)
12. [Verification Checklist](#12-verification-checklist)
13. [Diagnostic Tools](#13-diagnostic-tools)
---
## 1. Debugging Methodology
### Component isolation first
Never debug the full pipeline. Test each component in isolation:
1. **Text encoder** — Does it produce embeddings with reasonable statistics? (std > 0.01)
2. **Scheduler** — Do sigma/timestep values match the reference exactly?
3. **Transformer** — Does a single forward pass match the reference? (cosine similarity > 0.999)
4. **VAE decoder** — Feed reference latents into your VAE. Does the output look correct?
If every component matches individually but the pipeline fails, the bug is in
**orchestration** — how components are wired together.
### Statistical fingerprinting
Track per-step statistics through the diffusion loop:
```python
# After each denoising step
print(f"step {i}: mean={latent.mean():.6f} std={latent.std():.6f} "
f"min={latent.min():.4f} max={latent.max():.4f}")
```
**What to look for:**
- **Progressive mean drift** (e.g., -0.002 → -0.040 → -0.123) signals accumulating errors
- **Collapsing std** (std dropping toward 0) signals broken conditioning or wrong noise schedule
- **Exploding values** signal wrong sigma scaling or scheduler formula
### Cross-framework numerical comparison
The most powerful debugging tool: save intermediate tensors from your MLX pipeline,
feed them to the PyTorch reference, compare outputs.
```python
# In MLX pipeline, save inputs before transformer call
mx.save("debug_inputs.npz", {"latent": latent, "timestep": t, "text_emb": text_emb})
# In PyTorch script, load and compare
inputs = np.load("debug_inputs.npz")
mlx_out = np.load("debug_output.npz")["flow"]
pt_out = reference_model(torch.from_numpy(inputs["latent"]), ...)
cos_sim = F.cosine_similarity(pt_out.flatten(), torch.from_numpy(mlx_out).flatten(), dim=0)
# cos_sim > 0.999 = model is correct; bug is elsewhere
# cos_sim < 0.99 = model has a bug; compare per-layer
```
### Ablation testing
When a pipeline has multiple "fixes" or features, disable them one at a time:
- **Frozen history**: Fix history to the same value for all chunks → proves whether
history propagation is the source of drift/zoom
- **Single chunk**: Generate only 1 chunk → isolates per-chunk quality from
multi-chunk interaction bugs
- **Disable post-processing**: Remove cross-fade, blending, corrections → reveals
what the raw model output looks like
### Use reference on same hardware
Run the PyTorch reference on the same device (MPS for Apple Silicon). CUDA and MPS
produce different numerical results due to different float handling. Comparing your
MLX output against a CUDA reference adds noise to the comparison.
```python
# MPS may not support float64 — patch the reference:
original_linspace = torch.linspace
def patched_linspace(*args, **kwargs):
kwargs.pop("dtype", None)
return original_linspace(*args, dtype=torch.float32, **kwargs)
torch.linspace = patched_linspace
```
---
## 2. Precision & Dtype Pitfalls
### The #1 source of subtle bugs
Precision issues caused the most insidious bugs in our port. They don't cause
crashes — they cause progressive quality degradation that's hard to attribute.
### Residual connections MUST be float32
**Bug**: Progressive zoom/shrinking across autoregressive chunks.
**Root cause**: Residual additions (`x = x + attn_out`) in bfloat16. With 7-bit
mantissa, high-frequency spatial detail is systematically truncated. Over 144
residual ops × 6+ model calls per chunk, detail is progressively smoothed away.
**Fix**: Promote to float32 for the addition:
```python
# BAD — bfloat16 accumulation
x = x + attn_out
# GOOD — match reference's .float() pattern
x = (x.astype(mx.float32) + attn_out).astype(weight_dtype)
```
**Rule**: If the reference uses `.float()` anywhere, copy that pattern exactly. It's
there for a reason, even if a quick test seems to work without it.
### Scheduler computations need high precision
Diffusion schedulers involve:
- `x0 = xt - sigma * flow` — catastrophic cancellation near sigma ≈ 1
- `log(sigma)` and `exp()` — sensitive to small precision differences
Some references use float64 for these computations. MLX GPU doesn't support float64,
so use float32 and accept small numerical differences, but **never** use bfloat16
for scheduler math.
### Dtype propagation is invisible
Track dtype through your pipeline. A single bfloat16 intermediate can silently
downcast everything downstream:
```python
# This looks harmless but if model output is bfloat16:
result = noise - sigma * model_output # result is bfloat16!
# Fix: explicit cast
result = (noise.astype(mx.float32) - sigma * model_output.astype(mx.float32))
```
### Type promotion rules differ across frameworks
- PyTorch: bfloat16 + float32 → float32
- MLX: bfloat16 + float32 → float32 (same, but verify)
- NumPy: no bfloat16 support
Always check what your framework does and match the reference's implicit promotations.
### Float32 for VAE decoding
**Bug** (Wan2.2): VAE decode in bfloat16 produced visibly worse quality than reference.
Official Wan2.2 runs VAE decode in `torch.float` (float32), but our converted weights
were bfloat16. The VAE has many sequential layers where precision loss compounds.
**Fix**: Upcast VAE weights to float32 at load time. The VAE runs once per generation,
so the performance impact is negligible compared to the transformer.
### Modulation/gate vectors need float32
**Bug** (Wan2.2): Quality degradation from bfloat16 modulation across 30 blocks × 50 steps.
The official Wan2.2 explicitly uses `torch.amp.autocast('cuda', dtype=torch.float32)`
for time embeddings, modulation parameters, norm outputs before modulation, and gate ops.
**Fix**: Keep modulation in float32, cast to working dtype only when applying to the
hidden state:
```python
# Modulation computed in float32
e0 = self.modulation(time_emb) # float32
scale, shift, gate = e0.split(3, axis=-1)
# Cast to bfloat16 only for the matmul with hidden state
x = (x * (1 + scale.astype(x.dtype)) + shift.astype(x.dtype))
```
### Map PyTorch autocast zones precisely
PyTorch models use nested `torch.amp.autocast` scopes to switch precision. Map these
exactly:
- **Outer scope** (`bfloat16`): attention QKV projections, FFN matmuls
- **Inner scope** (`float32`): modulation, gates, norms, RoPE
- **Residual stream**: float32 (the "backbone" between blocks)
```python
# Wan2.2 dtype flow (matches official):
# Modulation/gates: float32 (explicit)
# QKV/FFN linear projections: bfloat16 (weight dtype)
# RoPE: float32 (official uses float64, MLX lacks float64)
# Attention Q/K: cast back to bfloat16 after RoPE
# Residual stream: float32
```
### Float32 promotion cascades kill performance
**Bug** (Wan2.2): ~2x slowdown from accidental float32 promotion.
A single float32 tensor (e.g., time embedding) flowing into bfloat16 operations
promotes the entire computation graph to float32. In Wan2.2:
- Time embedding MLP output (float32) fed into transformer → all layers float32
- RoPE frequencies (float32) applied to Q/K → all attention float32
**Fix**: Cast intermediate results to model dtype at promotion boundaries:
```python
# After time embedding MLP (float32), cast before feeding to transformer
time_emb = time_mlp(t).astype(model_dtype)
# After RoPE (float32), cast Q/K back to attention dtype
q = rope_apply(q, freqs).astype(v.dtype)
```
---
## 3. MLX-Specific Gotchas
### Underscore-prefixed attributes are invisible
**Bug** (Wan2.2): 87 of 110 VAE weights silently dropped during loading.
MLX's `nn.Module.parameters()` and `nn.Module.load_weights()` **skip** attributes
whose names start with underscore. If you name a layer `self._layer_0`, its weights
will never be loaded or saved.
```python
# BAD — weights silently ignored
self._layer_0 = nn.Linear(...) # nn.Module skips _prefixed attrs
# GOOD
self.layer_0 = nn.Linear(...)
```
This is especially insidious because there's no error — the model loads, runs, and
produces output. The output is just garbage because most weights are random.
### nn.Sequential indexing vs named children
PyTorch's `nn.Sequential` uses integer indices (`sequential.0.weight`), while MLX's
module hierarchy uses named attributes. When mirroring a PyTorch module structure,
you need explicit key sanitization:
```python
def sanitize_key(key):
# PyTorch: "decoder.middle.0.residual.1.weight"
# MLX: "decoder.middle.layer_0.residual.layer_1.weight"
key = re.sub(r'\.(\d+)', lambda m: f'.layer_{m.group(1)}', key)
return key
```
### Reshape axis ordering differs from PyTorch
**Bug** (Wan2.2): Green checkerboard pattern from VAE attention.
`[B,C,T,H,W]` cannot be directly reshaped to `[BT,C,H,W]` because in memory C
comes before T. PyTorch's `reshape` works because it handles non-contiguous tensors.
MLX requires explicit transpose first:
```python
# BAD — mixes channels with time
x = x.reshape(B*T, C, H, W) # Corrupts spatial layout
# GOOD — make B,T adjacent first
x = x.transpose(0, 2, 1, 3, 4) # [B,T,C,H,W]
x = x.reshape(B*T, C, H, W) # Now correct
```
### Patchify channel ordering
**Bug** (Wan2.2): Solid green video output from wrong patchify order.
When converting a Conv3d patchify to a manual reshape+linear, the dimension ordering
in the reshape must match the Conv3d weight layout. Conv3d expects `[C, pt, ph, pw]`
(channels slowest), but a naive reshape produces `[pt, ph, pw, C]` (channels fastest):
```python
# BAD — channel scrambling
patches = x.reshape(B, F', H', W', pt, ph, pw, C)
# GOOD — match Conv3d weight layout
patches = x.reshape(B, F', pt, H', ph, W', pw, C)
patches = patches.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, F', H', W', C, pt, ph, pw]
```
Verify numerically: the fixed version should match Conv3d output to ~1e-6.
### mx.zeros / padding inherits dtype
Use dtype-aware `mx.zeros` for padding and concatenation to avoid promotion:
```python
# BAD — default float32 padding promotes bfloat16 input
pad = mx.zeros((B, pad_len, C)) # float32!
x = mx.concatenate([pad, x], axis=1) # x promoted to float32
# GOOD — match input dtype
pad = mx.zeros((B, pad_len, C), dtype=x.dtype)
x = mx.concatenate([pad, x], axis=1) # stays bfloat16
```
### Use mx.fast kernels
Replace manual implementations with fused MLX kernels where possible:
```python
# Manual RMS norm → mx.fast.rms_norm
# Manual LayerNorm → mx.fast.layer_norm
# Manual attention → mx.fast.scaled_dot_product_attention
```
These are faster and handle precision internally.
---
## 4. Autoregressive Chunk Boundaries
For models that generate long videos by autoregressively extending chunks (Helios,
CogVideoX, etc.), chunk boundaries are the primary source of visual artifacts.
### Don't add post-processing the reference doesn't have
**Bug**: Added pixel cross-fade to smooth boundaries → caused 40% sharpness drop.
The reference pipeline used **no cross-fade at all**. The first frame of each new
chunk is intentionally a sharp reconstruction conditioned on history. Blending it with
the previous chunk's tail (which has different content) creates blur.
**Rule**: Before adding smoothing/blending, verify the reference doesn't do it.
Reference simplicity is usually correct.
### First-frame artifacts are common
The first pixel frame of each non-first chunk is typically a distorted reconstruction
of the conditioning frame. In many models, this is expected behavior:
- **Fix**: Drop the first frame from each chunk
- **Verify frame math**: If 33 raw frames at 16fps → drop 1 → 32 frames = exactly 2 seconds
### History conditioning errors compound
Small errors in how history is prepared, sliced, patchified, or position-encoded
will compound across chunks. The error is invisible in chunk 1, small in chunk 2,
and catastrophic by chunk 5.
**Debug strategy**: Generate with frozen history (same history for every chunk).
If the artifact disappears, the bug is in history handling.
---
## 5. VAE Decoder Artifacts
### Causal temporal convolutions cause boundary warmup
Video VAEs (WanVAE, CogVideoX-VAE) use causal temporal convolutions. When decoding
each chunk independently, the first few frames lack temporal context (only zero
padding), causing:
- **~7% contrast drop** in first frames of each chunk
- **Spatial brightness redistribution** (face darkens, background brightens)
This is inherent to the architecture. The reference has the same effect but at
lower magnitude.
### Post-processing to fix VAE warmup
Two-stage correction applied to first N frames of each non-first chunk:
```python
# Stage 1: Spatially-varying brightness correction
# Downsample reference (previous chunk's last frame) and current frame
ref_small = cv2.resize(ref_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
cur_small = cv2.resize(cur_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
diff_small = ref_small - cur_small
diff_full = cv2.resize(diff_small, (w, h), interpolation=cv2.INTER_LINEAR)
corrected = cur_frame + ramp * diff_full # ramp: 1.0 → 0.0 over N frames
# Stage 2: Per-channel contrast matching
for c in range(3):
ref_std = np.std(ref_frame[:,:,c])
cur_std = np.std(corrected[:,:,c])
scale = 1.0 + ramp * (ref_std / (cur_std + 1e-6) - 1.0)
corrected[:,:,c] = (corrected[:,:,c] - mean) * scale + mean
```
### VAE overlap decode does NOT work
**Attempted**: Prepend previous chunk's last latent frames to give the decoder
temporal context.
**Result**: Made things **worse** (22% contrast drop vs 7%). The causal convolutions
see conflicting content from different chunks and create larger artifacts than
zero-padding.
**Lesson**: Overlap only works when tiles contain the same content from the same
denoising process (e.g., spatial tiling). It fails for temporal chunks with
different content.
### Per-chunk VAE decoding is correct
Decode each chunk's latents independently, not concatenated. Concatenating all chunks
and decoding together lets boundary discontinuities propagate through temporal
convolutions, creating worse artifacts.
### First-frame quality: causal padding strategies
Multiple approaches were tried for the first-frame quality issue in Wan VAE:
| Approach | Result |
|----------|--------|
| Zero padding (default) | First ~4 frames degraded, but matches training |
| Replicate padding | Fixes artifacts but causes color intensity bias (conv applies all kernel weights to same value) |
| Warmup frame prepend | Helps motion but warmup frame itself has artifacts |
| Mirror-reflect warmup | Best compromise — varied context without zeros, no intensity bias |
**Lesson**: Don't assume "replicate padding is better than zero padding." The model
was trained with zero padding; changing it shifts the gain. Instead, prepend warmup
frames and trim them after decoding.
### RMS_norm vs L2 normalization
**Bug** (Wan2.2): Garbled output from incorrect normalization.
A PyTorch class named `RMS_norm` actually uses `F.normalize` (L2 norm: `x / ||x||_2`),
not RMS normalization (`x / sqrt(mean(x²))`). The difference is a factor of `sqrt(C)`,
causing values to explode through the decoder.
**Lesson**: Don't trust class names — read the actual implementation.
### Temporal frame count: causal boundary effects
**Bug** (Wan2.2): VAE produced 12 frames instead of 9 for a 9-frame input.
PyTorch reference processes frames one-by-one with caching, skipping temporal conv for
the first chunk. All-at-once decoding produces extra frames from zero-padded causal
context.
**Fix**: Use `first_chunk=True` flag to trim causal boundary frames, matching the
reference's chunked behavior.
### Chunked VAE encoding for I2V
**Bug** (Wan2.2 I2V-14B): Incorrect latents from non-chunked encoding.
Non-chunked encoding with causal zero-padding produces incorrect latents because
temporal features don't propagate correctly without caching. The reference uses chunked
encoding (1+4+4+... frames) with persistent temporal cache.
**Fix**: Implement chunked encoding with `feat_cache` propagation through CausalConv3d,
ResidualBlock, and Resample layers.
---
## 6. Scheduler & Timestep Issues
### Copy formulas exactly
Even small differences in scheduler formulas compound over many steps:
```python
# Dynamic time shifting — reference uses specific formula
mu = 0.5 + shift * 0.5 # NOT shift * 0.6 or any other constant
# Euler step
x_next = x + (sigma_next - sigma) * flow # order matters: next - current
```
### Verify sigma schedules numerically
Print and compare sigma values at each step:
```python
# Reference
sigmas_ref = [1.0, 0.99375, 0.9875, ...]
# Your implementation
sigmas = scheduler.get_sigmas(steps)
for i, (r, m) in enumerate(zip(sigmas_ref, sigmas)):
assert abs(r - m) < 1e-6, f"Step {i}: ref={r}, mlx={m}"
```
### Timestep embedding precision
Integer vs float timesteps matter. Some models expect `timestep=999` (int), others
expect `timestep=0.999` (float). Wrong type can silently produce wrong embeddings
with reasonable-looking but incorrect statistics.
### Boundary conditions: ±inf at sigma endpoints
**Bug** (Wan2.2): Greenish/yellow constant output from DPM++/UniPC schedulers.
The `lambda(sigma)` function must return `-inf` at `sigma=1.0` (pure noise) and `+inf`
at `sigma=0.0` (clean signal). Our implementation returned `0.0`, causing massive x0
overscaling on the first denoising step.
PyTorch naturally computes `torch.log(0) = -inf`, and `math.expm1(-inf) = -1.0`
handles the formulas correctly. Reproduce this behavior explicitly:
```python
def _lambda(self, sigma):
if sigma >= 1.0:
return float('-inf')
if sigma <= 0.0:
return float('inf')
return -math.log(sigma / (1 - sigma))
```
### UniPC corrector coefficients
**Bug** (Wan2.2): Accumulated artifacts across 47+ steps from wrong polynomial weights.
The UniPC corrector must compute `rhos_c` via `linalg.solve` for order ≥ 2. Hardcoded
`0.5` was 7× too large for the history weight (actual: ~0.08), causing massive
overweighting of history corrections.
---
## 7. Weight Conversion
### Always verify statistically
After converting weights from PyTorch to MLX format:
```python
for name in mlx_weights:
pt = pytorch_weights[map_name(name)]
mx_val = np.array(mlx_weights[name])
pt_val = pt.numpy()
cos_sim = np.dot(mx_val.flat, pt_val.flat) / (
np.linalg.norm(mx_val) * np.linalg.norm(pt_val) + 1e-10
)
if cos_sim < 0.9999:
print(f"MISMATCH: {name} cos_sim={cos_sim:.6f}")
```
### Conv3d → Linear reshaping
When converting 3D convolutions to linear layers (common for MLX which prefers
linear ops), the flattening order must match:
```python
# PyTorch Conv3d weight: (out_ch, in_ch, kT, kH, kW)
# Flatten to Linear: (out_ch, in_ch * kT * kH * kW)
# The reshape order MUST match how the input is patchified
```
### Sanitization functions
Write explicit weight sanitization that maps reference key names to your key names.
Don't rely on automatic matching — key naming conventions differ between frameworks.
### Module structure must mirror reference for direct loading
**Bug** (Wan2.2): Rewrote entire VAE module hierarchy to match PyTorch `nn.Sequential`
structure. ResidualBlock needed `None` gaps at specific indices to match the original
`nn.Sequential(RMSNorm, SiLU, Conv3d, ...)` indexing.
When possible, structure your modules to accept reference weights directly without
sanitization. This eliminates an entire class of bugs.
### Save VAE weights in float32
Even if the model uses bfloat16 for the transformer, save VAE weights in float32.
bfloat16 → float32 roundtrip loses precision that cannot be recovered by load-time
upcast.
### Temporal downsample/upsample order
**Bug** (Wan2.2): `temporal_downsample=[True, True, False]` but reference uses
`[False, True, True]`. Stage 0 created a `time_conv` with random weights (no matching
file key), and Stage 2 missed its `time_conv` (weights silently dropped).
Always verify boolean flags for each stage by inspecting the actual weight file keys.
### Silent weight drops are the worst bugs
When `load_weights()` with `strict=False` silently skips keys that don't match, you
get a model with random weights for those layers. This produces output that looks
"almost right" but is subtly wrong. Always log which keys were loaded vs skipped:
```python
loaded_keys = set()
for key, value in weights:
if key in model_params:
loaded_keys.add(key)
# Check for missing
expected = set(model_params.keys())
missing = expected - loaded_keys
if missing:
print(f"WARNING: {len(missing)} weights not loaded: {list(missing)[:5]}...")
```
---
## 8. Text Conditioning Failures
### Symptom: model predicts noise back to itself
If the model output correlates > 0.8 with its input noise, text conditioning is
likely broken. The model has learned nothing from the prompt and is just returning
its input.
### Check embedding statistics
```python
text_emb = text_encoder(prompt)
print(f"text_emb: mean={text_emb.mean():.4f} std={text_emb.std():.4f}")
# std < 0.01 → embeddings are collapsed → broken encoder or wrong weights
# std > 10.0 → embeddings are exploding → wrong normalization
```
### Verify with ablation
```python
# Generate with real text
output_text = denoise(latent, text_emb=real_embeddings)
# Generate with zeros
output_zero = denoise(latent, text_emb=mx.zeros_like(real_embeddings))
# Compare
text_influence = np.mean(np.abs(output_text - output_zero))
print(f"Text influence: {text_influence:.4f}") # Should be > 0 (typically 30-60% of output)
```
### Text preprocessing must match exactly
**Bug** (Wan2.2): Patchy-blurry output from wrong negative prompt tokenization.
The official Wan2.2 tokenizer applies `ftfy.fix_text` + `html.unescape` + whitespace
normalization before tokenization. Without this, fullwidth Chinese commas (U+FF0C)
tokenize differently from ASCII commas (U+002C), causing **27 different token IDs**
in the negative prompt. This made CFG's unconditional prediction wrong.
**Fix**: Apply the same text cleaning pipeline as the reference:
```python
import ftfy
import html
import re
def clean_text(text):
text = ftfy.fix_text(text)
text = html.unescape(text)
text = re.sub(r'\s+', ' ', text).strip()
return text
```
### T5 encoder precision
**Bug** (Wan2.2): Quality degradation from bfloat16 T5 attention.
T5 uses **no scaling** in attention (no `1/sqrt(d)` factor), so attention logits can
be very large. bfloat16 softmax loses significant precision across 24 encoder layers.
**Fix**: Compute T5 QK^T and softmax in float32. The T5 encoder only runs once per
generation, so the performance impact is negligible.
### Dual-model text embeddings
**Bug** (Wan2.2 I2V-14B): Low/high noise models have different `text_embedding` weights
(~42% relative difference). Using one model's embeddings for both caused incorrect
text conditioning for the high-noise model that handles critical early denoising steps.
**Fix**: Compute separate text embeddings for each model in dual-model setups.
---
## 9. Position Encodings (RoPE)
### Multi-scale consistency
In pyramid/multi-resolution models, RoPE must be computed consistently across scales.
If the model operates at 1/4 resolution in an early stage, the position grid must
reflect the actual spatial dimensions, not the final target dimensions.
### History vs current chunk
When conditioning on history from a previous chunk, the position encoding for
history frames must match what the model saw during training. Mismatches between
history and current-chunk position encodings can cause subtle spatial distortions
that compound across chunks.
### Factorized RoPE
3D video models often use factorized RoPE (separate temporal, height, width
frequencies). Verify each axis independently:
```python
# Compare temporal frequencies
assert np.allclose(mlx_rope_t, ref_rope_t, atol=1e-5)
# Compare spatial frequencies
assert np.allclose(mlx_rope_h, ref_rope_h, atol=1e-5)
assert np.allclose(mlx_rope_w, ref_rope_w, atol=1e-5)
```
### Per-axis frequency construction
**Bug** (Wan2.2): Grey/artifact-filled output from wrong frequency distribution.
The reference uses three separate `rope_params()` calls with different dimension
normalizations (e.g., 44, 42, 42 for Wan) so each axis gets its own full frequency
range. Consolidating into a single `rope_params(head_dim)` call and splitting gave
height frequencies starting at 0.042 and width at 0.002 (should be 1.0 for both).
**Fix** (and subsequent revert): This bug was introduced as a "fix" for a previous
RoPE issue, then had to be reverted. The lesson: RoPE changes have far-reaching effects.
Always verify with actual generation, not just numerical comparison of frequencies.
**Lesson**: Read the reference's frequency construction very carefully. Don't
"simplify" three separate calls into one unless you verify the frequency distribution
matches exactly.
---
## 10. Multi-Stage / Pyramid Pipelines
### Each stage is a potential failure point
Pyramid pipelines (generate at low res, upsample, refine at high res) multiply the
number of things that can go wrong:
- Downsampling method (bilinear vs area) must match reference
- Energy compensation factors (e.g., ×2 after bilinear downsample) must be present
- Alpha/beta noise mixing coefficients are stage-dependent
- Frame indices and history resolution change per stage
### Test single-stage first
If the model works at full resolution for a single stage but fails in the pyramid,
the bug is in stage orchestration — typically in how latents are passed between
stages or how position encodings adapt to different resolutions.
### Integration bugs are the hardest
We verified every Helios component matched the reference individually, but the
pyramid still produced uniform color. The bug was in dtype handling during stage
transitions. Integration bugs only appear when components interact.
---
## 11. Common Symptoms → Root Causes
| Symptom | Likely Root Causes |
|---------|-------------------|
| **Pure noise output** | Wrong sigma schedule, broken text conditioning, incorrect weight mapping |
| **Uniform color** | Model predicting noise back; text embeddings collapsed; wrong timestep format |
| **Progressive zoom/shrink** | bfloat16 residuals truncating high-freq detail; RoPE mismatch across chunks |
| **Brightness jumps at boundaries** | VAE causal warmup; cross-fade blending misaligned content |
| **Color drift across chunks** | Dtype in scheduler step; history normalization missing |
| **Blur at boundaries** | Cross-fade enabled; latent blending; wrong VAE decode order |
| **Grid/checker patterns** | Patchify channel ordering bug; latent blend artifacts; reshape axis error |
| **Green/magenta tint** | VAE weight key mismatch; wrong denormalization constants; cv2 YUV color matrix |
| **Mean drift across steps** | bfloat16 accumulation; wrong scheduler formula; missing energy compensation |
| **Garbled/scrambled output** | Silent weight drops (underscore prefix, wrong key mapping); RMS vs L2 norm |
| **Greenish-yellow constant** | Scheduler boundary condition (log(0) not returning -inf); x0 overscaling |
| **~2x slower than expected** | Float32 promotion cascade from single mistyped intermediate |
| **Extra output frames** | Causal padding producing extra temporal frames; missing `first_chunk` trim |
| **Grey/artifact output** | RoPE frequency construction wrong (per-axis vs single-call) |
| **Patchy-blurry with CFG** | Text preprocessing mismatch (fullwidth vs ASCII chars → wrong tokenization) |
| **I2V temporal mismatch** | Non-chunked VAE encoding vs reference's chunked encoding with temporal cache |
---
## 12. Verification Checklist
Use this checklist when porting a new diffusion video model:
### Model
- [ ] Weight conversion: all keys mapped, cosine similarity > 0.9999
- [ ] No silent weight drops (log loaded vs expected keys)
- [ ] Single forward pass matches reference (cos_sim > 0.999)
- [ ] Residual connections use float32 accumulation
- [ ] Attention computation matches reference precision
- [ ] Modulation/gate vectors in float32 (if reference uses autocast)
- [ ] No underscore-prefixed module attributes (MLX ignores them)
### Scheduler
- [ ] Sigma values match reference at every step (diff < 1e-6)
- [ ] Timestep format correct (int vs float, scale factor)
- [ ] Dynamic shifting formula copied exactly
- [ ] Step function returns correct dtype (float32)
- [ ] Boundary conditions: lambda(-inf) at sigma=1, lambda(+inf) at sigma=0
- [ ] Higher-order coefficients computed (not hardcoded) for UniPC/DPM++
### Text Encoder
- [ ] Embedding statistics reasonable (0.01 < std < 10)
- [ ] Text influence > 0 (ablation test)
- [ ] Tokenization matches (special tokens, padding, max length)
- [ ] Text preprocessing matches (ftfy, html unescape, whitespace normalization)
- [ ] T5/CLIP attention precision (float32 softmax if no 1/sqrt(d) scaling)
- [ ] Separate embeddings for dual-model setups (if applicable)
### VAE
- [ ] Denormalization constants match training pipeline
- [ ] Per-chunk decoding (not concatenated)
- [ ] Temporal frame count correct (account for causal padding)
- [ ] Weight keys mapped correctly (encoder vs decoder)
- [ ] Weights stored/loaded in float32 (not bfloat16)
- [ ] Temporal downsample/upsample order matches reference
- [ ] RMS_norm vs L2_norm: check actual implementation, not class name
- [ ] Chunked encoding for I2V (if applicable)
- [ ] Reshape axis ordering correct ([B,C,T,H,W] → transpose before reshape)
### Pipeline Orchestration
- [ ] Position encodings consistent across stages/chunks
- [ ] History slicing and conditioning correct
- [ ] Noise generation matches (distribution, correlation structure)
- [ ] Multi-chunk output visually consistent (no progressive degradation)
- [ ] Dimension auto-alignment (divisible by patch_size × vae_stride)
- [ ] Dtype-aware padding (mx.zeros with explicit dtype)
### Output
- [ ] Frame count matches expected (account for warmup/trim)
- [ ] FPS correct
- [ ] Color range [0, 255] uint8 for video
- [ ] No first-frame duplication artifacts
- [ ] Video codec correct (imageio/libx264 preferred over cv2/mp4v on macOS)
### Performance
- [ ] No float32 promotion cascades (check with profiler)
- [ ] Using mx.fast kernels (rms_norm, layer_norm, sdpa)
- [ ] Time embedding computed once per sample (not per position)
- [ ] Memory cleanup (delete temporaries before mx.eval)
---
## 13. Diagnostic Tools
### General video diagnostics (`scripts/video/`)
| Script | Purpose |
|--------|---------|
| `compare_videos.py` | PSNR, SSIM, temporal coherence, color fidelity between two videos |
| `video_quality.py` | Sharpness, stability, defect detection, chunk boundary analysis |
```bash
# Quick quality check
python scripts/video/video_quality.py output.mp4 --chunk-size 32
# Compare against reference
python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4
```
### Model-specific diagnostics (`scripts/helios/`)
| Script | Purpose |
|--------|---------|
| `analyze_boundaries.py` | Detailed boundary quality metrics for Helios |
| `run_reference.py` | Run PyTorch reference on MPS |
| `compare_pipelines.py` | Compare scheduler/pipeline mechanics |
| `compare_models.py` | Cross-framework model output comparison |
### Inline debugging pattern
Add temporary debug output to the diffusion loop:
```python
for i, sigma in enumerate(sigmas):
flow = model(latent, sigma, text_emb)
latent = scheduler.step(latent, flow, sigma, sigma_next)
# Debug: track statistics
print(f"[step {i}] sigma={sigma:.4f} "
f"latent: mean={latent.mean():.6f} std={latent.std():.6f} "
f"flow: mean={flow.mean():.6f} std={flow.std():.6f}")
# Debug: save for cross-framework comparison
if os.environ.get("DEBUG"):
mx.save(f"/tmp/debug_step_{i}.npz", {
"latent": latent, "flow": flow, "sigma": mx.array(sigma)
})
```
---
## Key Takeaways
1. **Precision is the #1 bug source** — bfloat16 residuals, scheduler math, type
promotion, modulation vectors. Copy the reference's `.float()` and `autocast` zones.
2. **Don't add what the reference doesn't have** — cross-fade, overlap decode,
temporal blending. If the reference works without it, you probably have a bug
elsewhere.
3. **Silent failures are the hardest bugs** — underscore-prefixed weights, `strict=False`
weight loading, wrong normalization class names. Always verify weight load counts
and output statistics.
4. **Component isolation → integration testing** — verify each part matches, then
debug their interaction.
5. **Statistical comparison beats visual inspection** — mean drift, contrast ratios,
and cosine similarity catch bugs before they're visible.
6. **Autoregressive errors compound** — a 1% error per chunk becomes 10% by chunk 10.
Fix precision first, add corrections second.
7. **MLX has unique pitfalls** — underscore attribute names, reshape axis ordering,
dtype-unaware padding, and float32 promotion cascades. Know your framework.
8. **Text preprocessing matters** — Unicode normalization, fullwidth chars, HTML entities.
A single mismatched comma can break CFG guidance.
9. **VAE is deceptively complex** — causal padding, temporal frame counts, chunked vs
batch processing, norm implementations. Budget significant debugging time for VAE.

BIN
examples/poodles-wan.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.4 MiB

View File

@@ -1,9 +1,12 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights from mlx_video.convert import load_transformer_weights, load_vae_weights
import os import os
__all__ = [ __all__ = [
"LTXModel", "LTXModel",
"LTXModelConfig", "LTXModelConfig",
"WanModel",
"WanModelConfig",
"load_transformer_weights", "load_transformer_weights",
"load_vae_weights", "load_vae_weights",
] ]

773
mlx_video/convert_wan.py Normal file
View File

@@ -0,0 +1,773 @@
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
import gc
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.utils
import numpy as np
logger = logging.getLogger(__name__)
def load_torch_weights(path: str) -> Dict[str, mx.array]:
"""Load PyTorch .pth weights and convert to MLX arrays.
Args:
path: Path to .pth file
Returns:
Dictionary of MLX arrays
"""
try:
import torch
except ImportError:
raise ImportError("PyTorch is required to load .pth weights: pip install torch")
logging.info(f"Loading weights from {path}")
state_dict = torch.load(path, map_location="cpu", weights_only=True)
weights = {}
for key, value in state_dict.items():
if isinstance(value, torch.Tensor):
np_val = value.detach().float().numpy()
weights[key] = mx.array(np_val)
return weights
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
"""Load safetensors weights as MLX arrays.
Args:
path: Path to directory with safetensors files or single file
Returns:
Dictionary of MLX arrays
"""
path = Path(path)
weights = {}
if path.is_file():
weights = mx.load(str(path))
elif path.is_dir():
for sf in sorted(path.glob("*.safetensors")):
weights.update(mx.load(str(sf)))
return weights
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 transformer weight keys to MLX model structure.
Wan2.2 keys follow the pattern:
patch_embedding.weight/bias
text_embedding.{0,2}.weight/bias
time_embedding.{0,2}.weight/bias
time_projection.1.weight/bias
blocks.{i}.norm1.weight
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
blocks.{i}.self_attn.norm_q.weight
blocks.{i}.self_attn.norm_k.weight
blocks.{i}.norm3.weight/bias (if cross_attn_norm)
blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
blocks.{i}.cross_attn.norm_q.weight
blocks.{i}.cross_attn.norm_k.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.{0,2}.weight/bias
blocks.{i}.modulation
head.norm.weight
head.head.weight/bias
head.modulation
freqs (buffer)
MLX model uses:
patch_embedding_proj.weight/bias (after patchify reshape)
text_embedding_0.weight/bias, text_embedding_1.weight/bias
time_embedding_0.weight/bias, time_embedding_1.weight/bias
time_projection.weight/bias
blocks.{i}.norm1.weight
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
etc.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
# Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
# MLX Linear expects [O, I*D*H*W] after we flatten in patchify
if key == "patch_embedding.weight":
# Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
value = value.reshape(value.shape[0], -1)
new_key = "patch_embedding_proj.weight"
sanitized[new_key] = value
consumed.add(key)
continue
if key == "patch_embedding.bias":
new_key = "patch_embedding_proj.bias"
sanitized[new_key] = value
consumed.add(key)
continue
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
if key.startswith("text_embedding.0."):
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
sanitized[new_key] = value
consumed.add(key)
continue
if key.startswith("text_embedding.2."):
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
sanitized[new_key] = value
consumed.add(key)
continue
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
if key.startswith("time_embedding.0."):
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
sanitized[new_key] = value
consumed.add(key)
continue
if key.startswith("time_embedding.2."):
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
sanitized[new_key] = value
consumed.add(key)
continue
# Time projection Sequential: 0=SiLU(no params), 1=Linear
if key.startswith("time_projection.1."):
new_key = key.replace("time_projection.1.", "time_projection.")
sanitized[new_key] = value
consumed.add(key)
continue
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")
# Skip the freqs buffer (we compute it in the model)
if key == "freqs":
consumed.add(key)
continue
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
return sanitized
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.
Wan2.2 T5 keys:
token_embedding.weight
pos_embedding.embedding.weight (if shared_pos)
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate.0.weight (gate linear)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
norm.weight
MLX T5Encoder structure:
token_embedding.weight
blocks.{i}.norm1.weight
blocks.{i}.attn.{q,k,v,o}.weight
blocks.{i}.norm2.weight
blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
blocks.{i}.ffn.fc1.weight
blocks.{i}.ffn.fc2.weight
blocks.{i}.pos_embedding.embedding.weight
norm.weight
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
# Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
return sanitized
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Convert Wan2.2 VAE weight keys to MLX WanVAE structure.
Handles Conv3d and Conv2d weight transpositions for MLX format.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
new_key = key
# Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
if "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
if "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Map decoder keys to MLX decoder structure
# Wan2.2 uses encoder/decoder with downsamples/upsamples
# Need to adapt naming for our simplified structure
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
return sanitized
def _load_lora_configs(
lora_configs: List[Tuple[str, float]],
) -> Dict[str, list]:
"""Load LoRA weights from config tuples, returning module_to_loras dict.
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.lora import LoRAConfig, load_multiple_loras
from mlx_video.generate_wan import Colors
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
configs = []
for lora_path, strength in lora_configs:
try:
config = LoRAConfig(path=lora_path, strength=strength)
configs.append(config)
print(f" - {Path(lora_path).name} (strength: {strength})")
except Exception as e:
print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
raise
module_to_loras = load_multiple_loras(configs)
if not module_to_loras:
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
return module_to_loras
def load_and_apply_loras(
model_weights: Dict[str, mx.array],
lora_configs: Optional[List[Tuple[str, float]]] = None,
verbose: bool = False,
quantization_bits: int = 0,
) -> Dict[str, mx.array]:
"""Load and apply LoRA weights to model weights by merging into weight dict.
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.lora import apply_loras_to_weights
from mlx_video.generate_wan import Colors
if not lora_configs:
return model_weights
module_to_loras = _load_lora_configs(lora_configs)
if not module_to_loras:
return model_weights
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
if verbose:
print(f" Model has {len(model_weights)} weight keys")
modified_weights = apply_loras_to_weights(
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
)
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
return modified_weights
def convert_wan_checkpoint(
checkpoint_dir: str,
output_dir: str,
dtype: str = "bfloat16",
model_version: str = "auto",
quantize: bool = False,
bits: int = 4,
group_size: int = 64,
):
"""Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.
Wan2.2 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
low_noise_model/ (safetensors)
high_noise_model/ (safetensors)
Wan2.1 expected structure:
checkpoint_dir/
models_t5_umt5-xxl-enc-bf16.pth
Wan2.1_VAE.pth
diffusion_pytorch_model*.safetensors (single model)
Args:
checkpoint_dir: Path to Wan checkpoint directory
output_dir: Path to output MLX model directory
dtype: Target dtype
model_version: "2.1", "2.2", or "auto" (detect from directory)
quantize: Whether to quantize the transformer weights
bits: Quantization bits (4 or 8)
group_size: Quantization group size (32, 64, or 128)
"""
import json
checkpoint_dir = Path(checkpoint_dir)
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
dtype_map = {
"float16": mx.float16,
"float32": mx.float32,
"bfloat16": mx.bfloat16,
}
target_dtype = dtype_map.get(dtype, mx.bfloat16)
# Auto-detect version
if model_version == "auto":
if (checkpoint_dir / "low_noise_model").exists():
model_version = "2.2"
elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
model_version = "2.2"
else:
model_version = "2.1"
print(f"Auto-detected Wan{model_version} checkpoint")
is_dual = (checkpoint_dir / "low_noise_model").exists()
if is_dual:
# Wan2.2: Convert dual transformer models
low_noise_path = checkpoint_dir / "low_noise_model"
if low_noise_path.exists():
print("Converting low-noise transformer...")
weights = load_safetensors_weights(str(low_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "low_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
high_noise_path = checkpoint_dir / "high_noise_model"
if high_noise_path.exists():
print("Converting high-noise transformer...")
weights = load_safetensors_weights(str(high_noise_path))
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "high_noise_model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
# Wan2.1: Convert single transformer model
# Try safetensors in the checkpoint dir itself
print("Converting transformer (single model)...")
weights = load_safetensors_weights(str(checkpoint_dir))
if not weights:
# Fallback: look for .pth files
for pth in sorted(checkpoint_dir.glob("*.pth")):
if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
print(f" Loading from {pth.name}...")
weights = load_torch_weights(str(pth))
break
if weights:
weights = sanitize_wan_transformer_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "model.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
else:
print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan.config import WanModelConfig
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
if is_dual:
# Check source config.json for model_type (I2V vs T2V)
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
src_model_type = src_config.get("model_type", "t2v")
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
return WanModelConfig.wan22_i2v_14b()
return WanModelConfig.wan22_t2v_14b()
# Try reading source config.json first (most reliable)
src_cfg_path = checkpoint_dir / "config.json"
src_config = None
if src_cfg_path.exists():
with open(src_cfg_path) as f:
src_config = json.load(f)
if src_config and "dim" in src_config:
src_dim = src_config.get("dim", 5120)
src_in_dim = src_config.get("in_dim", 16)
src_out_dim = src_config.get("out_dim", 16)
src_ffn_dim = src_config.get("ffn_dim", 13824)
src_num_heads = src_config.get("num_heads", 40)
src_num_layers = src_config.get("num_layers", 40)
src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512)
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
# Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072:
return WanModelConfig.wan22_ti2v_5b()
is_22 = model_version == "2.2"
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
vae_z = 48 if is_22 else 16
vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
fps = 24 if is_22 else 16
return WanModelConfig(
model_type=src_model_type,
model_version=model_version,
dim=src_dim,
ffn_dim=src_ffn_dim,
in_dim=src_in_dim,
out_dim=src_out_dim,
num_heads=src_num_heads,
num_layers=src_num_layers,
text_len=src_text_len,
vae_z_dim=vae_z,
vae_stride=vae_s,
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
sample_fps=fps,
)
# Fallback: detect from saved transformer weight shapes
saved_model = output_dir / "model.safetensors"
if saved_model.exists():
det_weights = mx.load(str(saved_model))
dim = None
for k, v in det_weights.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
break
del det_weights
if dim is not None and dim <= 2048:
print(f" Auto-detected 1.3B model (dim={dim})")
return WanModelConfig.wan21_t2v_1_3b()
return WanModelConfig.wan21_t2v_14b()
config = _detect_config()
config_path = output_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config.to_dict(), f, indent=2)
print(f" Saved config to {config_path}")
# Convert T5 encoder
t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
if t5_path.exists():
print("Converting T5 encoder...")
weights = load_torch_weights(str(t5_path))
weights = sanitize_wan_t5_weights(weights)
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
out_path = output_dir / "t5_encoder.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path}")
# Convert VAE (check both naming conventions)
vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
is_wan22_vae = False
if not vae_path.exists():
vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
is_wan22_vae = True
if vae_path.exists():
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
else:
weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
# that cannot be recovered by upcasting at load time.
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
out_path = output_dir / "vae.safetensors"
mx.save_safetensors(str(out_path), weights)
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
# Quantize transformer weights if requested
if quantize:
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}")
def _quantize_predicate(path: str, module) -> bool:
"""Return True for layers that should be quantized.
Targets heavyweight Linear layers in attention and FFN blocks.
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
"""
if not hasattr(module, "to_quantized"):
return False
# Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = (
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
".ffn.fc1", ".ffn.fc2",
)
return any(path.endswith(p) for p in quantize_patterns)
def _quantize_saved_model(
output_dir: Path,
config,
is_dual: bool,
bits: int,
group_size: int,
source_dir: Path = None,
):
"""Load saved bf16 model, quantize, and re-save.
Args:
output_dir: Directory to write quantized weights to.
config: WanModelConfig for creating the model.
is_dual: Whether this is a dual-expert model.
bits: Quantization bits.
group_size: Quantization group size.
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
"""
import json
import mlx.nn as nn
from mlx_video.models.wan.model import WanModel
if source_dir is None:
source_dir = output_dir
model_names = []
if is_dual:
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
if (source_dir / name).exists():
model_names.append(name)
else:
if (source_dir / "model.safetensors").exists():
model_names.append("model.safetensors")
for name in model_names:
print(f" Quantizing {name}...")
model = WanModel(config)
weights = mx.load(str(source_dir / name))
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
del weights
gc.collect()
mx.clear_cache()
# Apply quantization to targeted layers
nn.quantize(
model,
group_size=group_size,
bits=bits,
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
# Save quantized weights
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
# Validate: check for NaN/Inf in bias tensors (corruption canary)
bad_keys = []
for k, v in weights_dict.items():
if k.endswith(".bias") and not k.endswith(".biases"):
mx.eval(v)
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
bad_keys.append(k)
if bad_keys:
raise RuntimeError(
f"Quantization produced corrupted weights in {model_path.name}: "
f"{len(bad_keys)} bias tensors contain NaN/Inf "
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
)
mx.save_safetensors(str(output_dir / name), weights_dict)
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
# Free model before processing next file
del model, weights_dict
gc.collect()
mx.clear_cache()
# Update config.json with quantization metadata
config_path = output_dir / "config.json"
with open(config_path) as f:
cfg = json.load(f)
cfg["quantization"] = {
"group_size": group_size,
"bits": bits,
}
with open(config_path, "w") as f:
json.dump(cfg, f, indent=2)
print(f" Updated config.json with quantization metadata")
def quantize_mlx_model(
mlx_model_dir: str,
output_dir: str,
bits: int = 4,
group_size: int = 64,
):
"""Quantize an already-converted MLX model (skips PyTorch conversion).
Args:
mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
output_dir: Path to output quantized model directory.
bits: Quantization bits (4 or 8).
group_size: Quantization group size (32, 64, or 128).
"""
import json
import shutil
src = Path(mlx_model_dir)
dst = Path(output_dir)
config_path = src / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"No config.json found in {src}")
with open(config_path) as f:
cfg = json.load(f)
if cfg.get("quantization"):
raise ValueError(
f"Model at {src} is already quantized "
f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source."
)
# Detect dual vs single expert
is_dual = (src / "low_noise_model.safetensors").exists() and (
src / "high_noise_model.safetensors"
).exists()
# Build model config
from mlx_video.models.wan.config import WanModelConfig
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**config_dict)
# Copy non-transformer files to output dir (skip large model weights)
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
if dst.resolve() != src.resolve():
dst.mkdir(parents=True, exist_ok=True)
for f in src.iterdir():
if f.is_file() and f.name not in transformer_files:
shutil.copy2(f, dst / f.name)
print(f"Copied non-transformer files from {src} to {dst}")
print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
_quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
print(f"\nQuantization complete! Output: {dst}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
parser.add_argument(
"--checkpoint-dir",
type=str,
required=True,
help="Path to Wan checkpoint directory",
)
parser.add_argument(
"--output-dir",
type=str,
default="wan_mlx_model",
help="Output path for MLX model",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float16", "float32", "bfloat16"],
default="bfloat16",
help="Target dtype",
)
parser.add_argument(
"--model-version",
type=str,
choices=["2.1", "2.2", "auto"],
default="auto",
help="Wan model version (auto-detect by default)",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize transformer weights for faster inference",
)
parser.add_argument(
"--quantize-only",
action="store_true",
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
)
parser.add_argument(
"--bits",
type=int,
choices=[4, 8],
default=4,
help="Quantization bits (default: 4)",
)
parser.add_argument(
"--group-size",
type=int,
choices=[32, 64, 128],
default=64,
help="Quantization group size (default: 64)",
)
args = parser.parse_args()
if args.quantize_only:
quantize_mlx_model(
args.checkpoint_dir, args.output_dir,
bits=args.bits, group_size=args.group_size,
)
else:
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
)

828
mlx_video/generate_wan.py Normal file
View File

@@ -0,0 +1,828 @@
"""Wan2.2 Text-to-Video generation pipeline for MLX."""
import argparse
import gc
import math
import random
import sys
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import (
_clean_text,
encode_text,
load_t5_encoder,
load_vae_decoder,
load_vae_encoder,
load_wan_model,
)
from mlx_video.models.wan.postprocess import save_video
class Colors:
"""ANSI color codes for terminal output."""
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
# Backward-compat alias (tests and external code may use the old name)
_build_i2v_mask = build_i2v_mask
def _best_output_size(w, h, dw, dh, max_area):
"""Compute the best output resolution that fits within max_area while
preserving the input aspect ratio and satisfying alignment constraints.
Matches the reference implementation's best_output_size().
"""
ratio = w / h
ow = (max_area * ratio) ** 0.5
oh = max_area / ow
# Option 1: process width first
ow1 = int(ow // dw * dw)
oh1 = int(max_area / ow1 // dh * dh)
ratio1 = ow1 / oh1
# Option 2: process height first
oh2 = int(oh // dh * dh)
ow2 = int(max_area / oh2 // dw * dw)
ratio2 = ow2 / oh2
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
return ow1, oh1
return ow2, oh2
def generate_video(
model_dir: str,
prompt: str,
negative_prompt: str | None = None,
image: str | None = None,
width: int = 1280,
height: int = 704,
num_frames: int = 81,
steps: int = None,
guide_scale: str | float | tuple = None,
shift: float = None,
seed: int = -1,
output_path: str = "output.mp4",
scheduler: str = "unipc",
loras: list | None = None,
loras_high: list | None = None,
loras_low: list | None = None,
tiling: str = "auto",
no_compile: bool = False,
trim_first_frames: int = 0,
debug_latents: bool = False,
):
"""Generate video using Wan pipeline (supports T2V and I2V).
Args:
model_dir: Path to converted MLX model directory
prompt: Text prompt
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
image: Path to input image for I2V (None = T2V mode)
width: Video width
height: Video height
num_frames: Number of frames (must be 4n+1)
steps: Number of diffusion steps (None = use config default)
guide_scale: Guidance scale: float for single, (low,high) for dual (None = config default)
shift: Noise schedule shift (None = use config default)
seed: Random seed (-1 for random)
output_path: Output video path
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
loras: Optional list of (path, strength) tuples applied to all models
loras_high: Optional list of (path, strength) tuples for high-noise model only
loras_low: Optional list of (path, strength) tuples for low-noise model only
tiling: Tiling mode for VAE decoding. Options:
- "auto": Automatically determine tiling based on video size (default)
- "none": Disable tiling
- "default", "aggressive", "conservative": Preset tiling configs
- "spatial": Spatial tiling only
- "temporal": Temporal tiling only
no_compile: If True, skip mx.compile on models (useful for debugging)
trim_first_frames: Number of temporal latent positions to generate extra
and discard from the start. Each position = 4 pixel frames. Use 1
to fix first-frame artifacts on 14B models (generates 4 extra frames,
discards first 4). Use 2 for more aggressive trimming. Default: 0.
debug_latents: If True, print per-temporal-position latent statistics
after denoising for diagnosing first-frame artifacts.
"""
import json
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
model_dir = Path(model_dir)
# Load config from model dir if available, otherwise auto-detect
config_path = model_dir / "config.json"
quantization = None
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Extract quantization config (not a model config field)
quantization = config_dict.pop("quantization", None)
# Handle tuple fields stored as lists in JSON
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**{
k: v for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
})
else:
# Auto-detect: dual model files → 2.2, single model → 2.1
if (model_dir / "low_noise_model.safetensors").exists():
config = WanModelConfig.wan22_t2v_14b()
else:
# Detect 1.3B vs 14B from weight shapes
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
if dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
else:
config = WanModelConfig.wan21_t2v_14b()
del probe
else:
config = WanModelConfig.wan21_t2v_14b()
is_dual = config.dual_model
is_i2v = image is not None
# Validate config against actual weights (handles mismatched config.json)
if not is_dual:
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
actual_dim = v.shape[0]
if actual_dim != config.dim:
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
if actual_dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
del probe
# Auto-correct Wan2.2 VAE params from stale configs
if config.in_dim == 48 and config.vae_z_dim != 48:
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
config = WanModelConfig(**{
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
})
# Apply defaults from config if not overridden
if steps is None:
steps = config.sample_steps
if shift is None:
shift = config.sample_shift
if guide_scale is None:
guide_scale = config.sample_guide_scale
# Normalize guide_scale
if isinstance(guide_scale, (int, float)):
guide_scale = float(guide_scale)
elif isinstance(guide_scale, str):
parts = [float(x) for x in guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
if isinstance(guide_scale, tuple):
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
else:
cfg_disabled = guide_scale <= 1.0
# Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
gen_frames = num_frames
if trim_first_frames > 0:
gen_frames = num_frames + trim_first_frames * 4
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
# Resolve negative prompt: explicit user value > config default
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
# that prevents oversaturation, artifacts, and comic look. We use it by default.
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
if negative_prompt is None:
neg_prompt_resolved = config.sample_neg_prompt
else:
neg_prompt_resolved = negative_prompt
print(f"{Colors.CYAN}{'='*60}")
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
print(f"{'='*60}{Colors.RESET}")
print(f"{Colors.DIM} Prompt: {prompt}")
if is_i2v:
print(f" Image: {image}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}")
# Seed
if seed < 0:
seed = random.randint(0, 2**32 - 1)
mx.random.seed(seed)
np.random.seed(seed)
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
# Align dimensions to patch_size * vae_stride (required for patchify)
vae_stride = config.vae_stride
patch_size = config.patch_size
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
align_w = patch_size[2] * vae_stride[2]
if height % align_h != 0 or width % align_w != 0:
old_h, old_w = height, width
height = (height // align_h) * align_h
width = (width // align_w) * align_w
if height == 0:
height = align_h
if width == 0:
width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
# Enforce max_area constraint (model-specific resolution limit)
if config.max_area > 0 and height * width > config.max_area:
old_h, old_w = height, width
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
print(
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
)
# Compute target latent shape
z_dim = config.vae_z_dim
t_latent = (gen_frames - 1) // vae_stride[0] + 1
h_latent = height // vae_stride[1]
w_latent = width // vae_stride[2]
target_shape = (z_dim, t_latent, h_latent, w_latent)
# Sequence length for transformer
seq_len = math.ceil(
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
)
print(f"{Colors.DIM} Latent shape: {target_shape}")
print(f" Sequence length: {seq_len}{Colors.RESET}")
# Load T5 encoder
t1 = time.time()
print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
t5_path = model_dir / "t5_encoder.safetensors"
t5_encoder = load_t5_encoder(t5_path, config)
# Load tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
# Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
if cfg_disabled:
context_null = None
mx.eval(context)
else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
mx.eval(context, context_null)
# Free T5 from memory
del t5_encoder
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# I2V: encode image to latent space
z_img = None
i2v_mask = None
i2v_mask_tokens = None
y_i2v = None
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
if is_i2v:
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
t_img = time.time()
vae_path = model_dir / "vae.safetensors"
if is_i2v_channel_concat:
# I2V-14B: encode full video (first frame = image, rest = zeros)
# and construct y tensor with mask + encoded latents
from PIL import Image
img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
], axis=1)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config)
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
mx.eval(z_video)
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
], axis=1)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
y_i2v = mx.concatenate([msk, z_video], axis=0)
mx.eval(y_i2v)
del vae_enc, img_arr, img_chw, video, z_video, msk
else:
# TI2V-5B: encode single image, blend with noise via mask
img_tensor = preprocess_image(image, width, height)
mx.eval(img_tensor)
vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img)
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
del vae_enc, img_tensor
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
# Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization:
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
t2 = time.time()
# Merge per-model LoRAs with shared LoRAs
_loras_low = (loras or []) + (loras_low or []) or None
_loras_high = (loras or []) + (loras_high or []) or None
_loras_single = loras
if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
# Each model has its own text_embedding weights, so dual models need separate embeddings
if cfg_disabled:
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
if is_dual:
context_emb_low = low_noise_model.embed_text([context])
context_emb_high = high_noise_model.embed_text([context])
mx.eval(context_emb_low, context_emb_high)
context_cond_low = context_emb_low[0:1]
context_cond_high = context_emb_high[0:1]
else:
context_emb = single_model.embed_text([context])
mx.eval(context_emb)
context_cond = context_emb[0:1]
else:
if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
# Precompute cross-attention K/V caches (constant across all steps)
if cfg_disabled:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cond)
mx.eval(cross_kv)
else:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv)
# Precompute RoPE frequencies (grid sizes are constant across all steps)
f_grid = t_latent // patch_size[0]
h_grid = h_latent // patch_size[1]
w_grid = w_latent // patch_size[2]
if cfg_disabled:
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
else:
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if is_dual:
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else:
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin)
# Setup scheduler
_schedulers = {
"euler": FlowMatchEulerScheduler,
"dpm++": FlowDPMPP2MScheduler,
"unipc": FlowUniPCScheduler,
}
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
sched.set_timesteps(steps, shift=shift)
# Generate initial noise
noise = mx.random.normal(target_shape)
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
else:
latents = noise
# Boundary for model switching (dual model only)
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
# Diffusion loop
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time()
# Compile model forward for faster denoising
if not no_compile:
models_to_compile = (
[high_noise_model, low_noise_model] if is_dual else [single_model]
)
for m in models_to_compile:
m._compiled = mx.compile(m)
# Pre-convert timesteps to Python list to avoid .item() sync each step
timestep_list = sched.timesteps.tolist()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = timestep_list[i]
# Select model, cached K/V, and precomputed RoPE
if is_dual:
if timestep_val >= boundary:
model = high_noise_model
kv = cross_kv_high
rcs = rope_cos_sin_high
else:
model = low_noise_model
kv = cross_kv_low
rcs = rope_cos_sin_low
else:
model = single_model
kv = cross_kv
rcs = rope_cos_sin
# Use compiled forward when available (faster after first trace)
_call = getattr(model, '_compiled', model)
if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = t_tokens # [1, L]
else:
t_batch = mx.array([timestep_val])
y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
else:
ctx = context_cond
preds = _call(
[latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred = preds[0]
del preds
else:
# CFG: batch cond + uncond into single B=2 forward pass
if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
else:
t_batch = mx.array([timestep_val, timestep_val])
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = _call(
[latents, latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond, preds
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# TI2V-5B: re-apply mask to keep first frame frozen
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution
del noise_pred
mx.eval(latents)
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
# Diagnostic: per-temporal-position latent statistics
if debug_latents:
lat_np = np.array(latents) # [C, T, H, W]
n_t = lat_np.shape[1]
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
for t_pos in range(min(n_t, 8)):
frame = lat_np[:, t_pos, :, :]
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
if n_t > 8:
interior = lat_np[:, 4:, :, :]
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
print()
# Free transformer models and text embeddings
if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
if cfg_disabled:
del context_cond_low, context_cond_high
else:
del context_cfg_low, context_cfg_high
else:
del single_model, cross_kv
if cfg_disabled:
del context_cond
else:
del context_cfg
del model, kv, context
if context_null is not None:
del context_null
gc.collect(); mx.clear_cache()
# Load VAE and decode
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
t4 = time.time()
vae_path = model_dir / "vae.safetensors"
vae = load_vae_decoder(vae_path, config)
is_wan22_vae = config.vae_z_dim == 48
# Temporal extend: prepend reflected latent frames to the VAE input so that
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
# This gives the first real frame a full temporal receptive field of real data.
# Select tiling configuration
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None]
z = denormalize_latents(z)
if tiling_config is not None:
video = vae.decode_tiled(z, tiling_config)
else:
video = vae(z)
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else:
if tiling_config is not None:
video = vae.decode_tiled(latents[None], tiling_config)
else:
video = vae.decode(latents[None])
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [3, T', H, W]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
# Trim first N temporal chunks if requested (avoids first-frame artifacts)
if trim_first_frames > 0:
trim_pixels = trim_first_frames * 4
video = video[trim_pixels:]
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
save_video(video, output_path, fps=config.sample_fps)
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--image", type=str, default=None,
help="Path to input image for I2V (omit for T2V mode)")
parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
parser.add_argument(
"--scheduler", type=str, default="unipc",
choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
)
parser.add_argument(
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
)
parser.add_argument(
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="VAE tiling mode to reduce memory during decoding (default: auto)",
)
parser.add_argument(
"--no-compile", action="store_true",
help="Disable mx.compile on models (for debugging)",
)
parser.add_argument(
"--trim-first-frames", type=int, default=0, metavar="N",
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
"Default: 0 (disabled)",
)
parser.add_argument(
"--debug-latents", action="store_true",
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
)
args = parser.parse_args()
# Parse guide scale
guide_scale = None
if args.guide_scale is not None:
parts = [float(x) for x in args.guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
neg_prompt = args.negative_prompt
if args.no_negative_prompt:
neg_prompt = ""
# Parse LoRA configs: convert [path, strength_str] → (path, float)
def _parse_lora_args(lora_list):
if not lora_list:
return None
return [(path, float(strength)) for path, strength in lora_list]
generate_video(
model_dir=args.model_dir,
prompt=args.prompt,
negative_prompt=neg_prompt,
image=args.image,
width=args.width,
height=args.height,
num_frames=args.num_frames,
steps=args.steps,
guide_scale=guide_scale,
shift=args.shift,
seed=args.seed,
output_path=args.output_path,
scheduler=args.scheduler,
loras=_parse_lora_args(args.lora),
loras_high=_parse_lora_args(args.lora_high),
loras_low=_parse_lora_args(args.lora_low),
tiling=args.tiling,
no_compile=args.no_compile,
trim_first_frames=args.trim_first_frames,
debug_latents=args.debug_latents,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,25 @@
"""LoRA support for mlx-video."""
from mlx_video.lora.apply import (
LoRALinear,
apply_lora_to_linear,
apply_loras_to_model,
apply_loras_to_weights,
)
from mlx_video.lora.loader import (
load_lora_weights,
load_multiple_loras,
)
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
__all__ = [
"LoRAConfig",
"LoRAWeights",
"AppliedLoRA",
"load_lora_weights",
"load_multiple_loras",
"apply_lora_to_linear",
"apply_loras_to_weights",
"apply_loras_to_model",
"LoRALinear",
]

393
mlx_video/lora/apply.py Normal file
View File

@@ -0,0 +1,393 @@
"""Apply LoRA weights to model layers."""
from typing import Dict, List, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.lora.types import LoRAWeights
def apply_lora_to_linear(
linear_weight: mx.array,
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
) -> mx.array:
"""Apply one or more LoRAs to a linear layer weight.
Args:
linear_weight: Original weight matrix [out_features, in_features]
lora_weights_and_strengths: List of (LoRAWeights, strength) tuples
Returns:
Modified weight with LoRA deltas applied (preserves original dtype)
"""
orig_dtype = linear_weight.dtype
modified_weight = linear_weight
for weights, strength in lora_weights_and_strengths:
scale = weights.scale
# Compute delta in float32 for precision, then cast back to avoid
# promoting model weights (e.g. bfloat16 → float32 causes ~1.5x slowdown)
delta = (weights.lora_B @ weights.lora_A) * (scale * strength)
modified_weight = modified_weight + delta.astype(orig_dtype)
return modified_weight
def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match Wan2.2 MLX model weight keys.
Handles:
- Stripping common prefixes (diffusion_model., model., etc.)
- FFN key mapping: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
- Embedding key mapping: text_embedding.0 → text_embedding_0, etc.
- Time projection: time_projection.1 → time_projection
- Patch embedding: patch_embedding → patch_embedding_proj
Args:
lora_key: Original LoRA module name
model_keys: Set of all model weight keys
Returns:
Normalized key that matches model weights
"""
# Try the key as-is first
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
# Common prefixes to strip
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"base_model.model.",
"model.",
]
candidates = [lora_key]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
candidates.append(lora_key[len(prefix):])
for candidate in candidates:
# Try as-is
if f"{candidate}.weight" in model_keys or candidate in model_keys:
return candidate
# Apply Wan2.2 key transformations
transformed = candidate
# FFN: ffn.0 → ffn.fc1, ffn.2 → ffn.fc2
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
if transformed.endswith(".ffn.0"):
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
if transformed.endswith(".ffn.2"):
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
# Text embedding: text_embedding.0 → text_embedding_0
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
if transformed.endswith("text_embedding.0"):
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
if transformed.endswith("text_embedding.2"):
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
# Time embedding: time_embedding.0 → time_embedding_0
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
if transformed.endswith("time_embedding.0"):
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
if transformed.endswith("time_embedding.2"):
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
# Time projection: time_projection.1 → time_projection
transformed = transformed.replace("time_projection.1.", "time_projection.")
if transformed.endswith("time_projection.1"):
transformed = transformed[:-len("time_projection.1")] + "time_projection"
# Patch embedding: patch_embedding → patch_embedding_proj
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Return best attempt with prefix stripped
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
# Also support LTX-style key normalization
def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match LTX MLX model weight keys."""
if f"{lora_key}.weight" in model_keys or lora_key in model_keys:
return lora_key
prefixes_to_strip = [
"model.diffusion_model.",
"diffusion_model.",
"model.",
]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
normalized = lora_key[len(prefix):]
if f"{normalized}.weight" in model_keys or normalized in model_keys:
return normalized
transformed = normalized
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
# Try transformations on the original key
transformed = lora_key
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
return transformed
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key
def _normalize_lora_key(lora_key: str, model_keys: set) -> str:
"""Normalize LoRA module name to match model weight keys.
Auto-detects whether to use Wan2.2 or LTX key normalization based
on the presence of architecture-specific keys in the model.
"""
# Detect model architecture from keys
is_wan = any("self_attn.q.weight" in k for k in model_keys)
if is_wan:
return _normalize_wan_lora_key(lora_key, model_keys)
else:
return _normalize_ltx_lora_key(lora_key, model_keys)
def apply_loras_to_weights(
model_weights: Dict[str, mx.array],
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
verbose: bool = False,
quantization_bits: int = 0,
) -> Dict[str, mx.array]:
"""Apply LoRAs to model weights.
Args:
model_weights: Original model state dictionary
module_to_loras: Dictionary mapping module names to lists of
(LoRAWeights, strength) tuples
verbose: If True, print detailed debug information
quantization_bits: If >0, weights are quantized at this bit width.
Quantized layers are dequantized before LoRA application
and re-quantized after.
Returns:
New state dictionary with LoRA-modified weights
"""
modified_weights = dict(model_weights)
model_keys = set(model_weights.keys())
applied_count = 0
skipped_count = 0
skipped_modules = []
for module_name, loras in module_to_loras.items():
normalized_name = _normalize_lora_key(module_name, model_keys)
weight_key = f"{normalized_name}.weight"
if weight_key not in modified_weights:
if normalized_name not in modified_weights:
skipped_count += 1
skipped_modules.append(module_name)
if verbose and skipped_count <= 5:
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
similar = [
k
for k in list(model_keys)[:1000]
if normalized_name.split(".")[-1] in k
][:3]
if similar:
print(f" Similar keys: {similar}")
continue
weight_key = normalized_name
original_weight = modified_weights[weight_key]
# Handle quantized weights: dequantize → apply delta → re-quantize
scales_key = f"{normalized_name}.scales"
biases_key = f"{normalized_name}.biases"
is_quantized = (
original_weight.dtype == mx.uint32
and scales_key in modified_weights
and biases_key in modified_weights
)
if is_quantized:
scales = modified_weights[scales_key]
biases = modified_weights[biases_key]
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits)
dequantized = mx.dequantize(
original_weight, scales, biases, group_size=group_size, bits=quantization_bits
)
modified = apply_lora_to_linear(dequantized, loras)
# Re-quantize with same parameters
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits)
modified_weights[weight_key] = new_w
modified_weights[scales_key] = new_scales
modified_weights[biases_key] = new_biases
else:
modified_weights[weight_key] = apply_lora_to_linear(original_weight, loras)
applied_count += 1
if applied_count > 0:
print(f" ✓ Applied to {applied_count} modules")
if skipped_count > 0:
print(f" ⚠ Skipped {skipped_count} incompatible modules")
return modified_weights
class LoRALinear(nn.Module):
"""Linear layer with on-the-fly LoRA application.
Wraps nn.Linear or nn.QuantizedLinear, computing LoRA delta at runtime:
output = base_linear(x) + (x @ lora_A.T @ lora_B.T) * scale * strength
"""
def __init__(
self,
linear: nn.Module,
lora_weights_and_strengths: List[Tuple[LoRAWeights, float]],
):
super().__init__()
self.linear = linear
self.lora_weights_and_strengths = lora_weights_and_strengths
def __call__(self, x: mx.array) -> mx.array:
output = self.linear(x)
for weights, strength in self.lora_weights_and_strengths:
scale = weights.scale
lora_out = x @ weights.lora_A.T @ weights.lora_B.T
output = output + (scale * strength * lora_out)
return output
def apply_loras_to_model(
model: nn.Module,
module_to_loras: Dict[str, List[Tuple[LoRAWeights, float]]],
verbose: bool = False,
) -> int:
"""Apply LoRAs to a model by merging into weights.
For QuantizedLinear layers: dequantizes to bf16, merges LoRA delta, and
replaces with a regular nn.Linear (no per-step overhead, no re-quantization
precision loss). Non-LoRA layers stay quantized.
For nn.Linear layers: merges LoRA delta directly into the weight.
Args:
model: The model to apply LoRAs to
module_to_loras: Dictionary mapping module names to (LoRAWeights, strength) lists
verbose: Print debug info
Returns:
Number of modules modified
"""
# Build a set of model module paths for key normalization
module_paths = set()
for name, _ in model.named_modules():
module_paths.add(name)
module_paths.add(f"{name}.weight")
# Map LoRA keys → model module paths
lora_to_module = {}
for lora_key in module_to_loras:
normalized = _normalize_lora_key(lora_key, module_paths)
if normalized.endswith(".weight"):
normalized = normalized[: -len(".weight")]
lora_to_module[lora_key] = normalized
applied_count = 0
dequant_count = 0
skipped = []
for lora_key, loras in module_to_loras.items():
module_path = lora_to_module[lora_key]
parts = module_path.split(".")
# Traverse to the parent module
parent = model
try:
for part in parts[:-1]:
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)]
leaf_name = parts[-1]
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)]
except (AttributeError, IndexError, TypeError):
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{lora_key}' -> '{module_path}' -> module not found")
continue
if isinstance(target, nn.QuantizedLinear):
# Dequantize → merge LoRA → replace with bf16 Linear
weight = mx.dequantize(
target.weight, target.scales, target.biases,
group_size=target.group_size, bits=target.bits,
)
merged = apply_lora_to_linear(weight, loras)
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
new_linear.weight = merged
if "bias" in target:
new_linear.bias = target.bias
if leaf_name.isdigit():
parent[int(leaf_name)] = new_linear
else:
setattr(parent, leaf_name, new_linear)
dequant_count += 1
applied_count += 1
elif isinstance(target, nn.Linear):
# Merge directly into weight
target.weight = apply_lora_to_linear(target.weight, loras)
applied_count += 1
else:
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
continue
if applied_count > 0:
msg = f" ✓ Applied to {applied_count} modules"
if dequant_count > 0:
msg += f" ({dequant_count} dequantized to bf16)"
print(msg)
if skipped:
print(f" ⚠ Skipped {len(skipped)} incompatible modules")
return applied_count

122
mlx_video/lora/loader.py Normal file
View File

@@ -0,0 +1,122 @@
"""LoRA weight loading utilities."""
import re
from pathlib import Path
from typing import Dict, List, Optional
import mlx.core as mx
from mlx_video.lora.types import LoRAConfig, LoRAWeights
def load_lora_weights(lora_path: Path) -> Dict[str, LoRAWeights]:
"""Load LoRA weights from a safetensors file.
Supports both key conventions:
- {module_name}.lora_A.weight / {module_name}.lora_B.weight
- {module_name}.lora_down.weight / {module_name}.lora_up.weight
Args:
lora_path: Path to the LoRA safetensors file
Returns:
Dictionary mapping module names to LoRAWeights objects
Raises:
FileNotFoundError: If the LoRA file doesn't exist
ValueError: If the LoRA file format is invalid
"""
if not lora_path.exists():
raise FileNotFoundError(f"LoRA file not found: {lora_path}")
all_weights = mx.load(str(lora_path))
# Group weights by module name, handling both naming conventions
lora_weights = {}
module_names = set()
for key in all_weights.keys():
# Format 1: {module}.lora_A.weight / {module}.lora_B.weight
match = re.match(r"(.+)\.lora_([AB])\.weight$", key)
if match:
module_names.add(match.group(1))
continue
# Format 2: {module}.lora_down.weight / {module}.lora_up.weight
match = re.match(r"(.+)\.lora_(down|up)\.weight$", key)
if match:
module_names.add(match.group(1))
for module_name in module_names:
# Try both key conventions
key_a = f"{module_name}.lora_A.weight"
key_b = f"{module_name}.lora_B.weight"
if key_a not in all_weights or key_b not in all_weights:
key_a = f"{module_name}.lora_down.weight"
key_b = f"{module_name}.lora_up.weight"
if key_a not in all_weights or key_b not in all_weights:
continue
lora_a = all_weights[key_a]
lora_b = all_weights[key_b]
if lora_a.ndim != 2 or lora_b.ndim != 2:
raise ValueError(
f"Invalid LoRA shape for {module_name}: "
f"lora_A={lora_a.shape}, lora_B={lora_b.shape}"
)
rank = lora_a.shape[0]
if lora_b.shape[1] != rank:
raise ValueError(
f"LoRA rank mismatch for {module_name}: "
f"lora_A rank={rank}, lora_B rank={lora_b.shape[1]}"
)
# Check for per-module alpha stored as a scalar tensor
alpha_key = f"{module_name}.alpha"
if alpha_key in all_weights:
alpha = float(all_weights[alpha_key].item())
else:
alpha = float(rank)
lora_weights[module_name] = LoRAWeights(
lora_A=lora_a,
lora_B=lora_b,
rank=rank,
alpha=alpha,
module_name=module_name,
)
if not lora_weights:
raise ValueError(f"No valid LoRA weights found in {lora_path}")
return lora_weights
def load_multiple_loras(
configs: List[LoRAConfig],
) -> Dict[str, List[tuple]]:
"""Load multiple LoRA configurations.
Args:
configs: List of LoRAConfig objects
Returns:
Dictionary mapping module names to lists of (LoRAWeights, strength) tuples.
"""
module_to_loras: Dict[str, list] = {}
for config in configs:
lora_weights = load_lora_weights(config.path)
for module_name, weights in lora_weights.items():
if config.target_modules is not None:
if module_name not in config.target_modules:
continue
if module_name not in module_to_loras:
module_to_loras[module_name] = []
module_to_loras[module_name].append((weights, config.strength))
return module_to_loras

74
mlx_video/lora/types.py Normal file
View File

@@ -0,0 +1,74 @@
"""Data structures for LoRA support."""
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import mlx.core as mx
@dataclass
class LoRAWeights:
"""Container for LoRA weight matrices.
Attributes:
lora_A: Low-rank matrix A of shape [rank, in_features]
lora_B: Low-rank matrix B of shape [out_features, rank]
rank: Rank of the LoRA decomposition
alpha: LoRA scaling parameter (default: rank)
module_name: Target module name in the model
"""
lora_A: mx.array
lora_B: mx.array
rank: int
alpha: float
module_name: str
@property
def scale(self) -> float:
"""Compute the scale factor: alpha / rank."""
return self.alpha / self.rank
@dataclass
class LoRAConfig:
"""Configuration for a single LoRA.
Attributes:
path: Path to the LoRA safetensors file
strength: Strength/weight to apply this LoRA (typically 0.0-2.0)
target_modules: Optional list of module names to apply LoRA to.
If None, applies to all available modules in the LoRA.
"""
path: Path
strength: float = 1.0
target_modules: Optional[list[str]] = None
def __post_init__(self):
"""Validate and normalize the configuration."""
self.path = Path(self.path)
if not self.path.exists():
raise FileNotFoundError(f"LoRA file not found: {self.path}")
if self.strength < 0:
raise ValueError(f"LoRA strength must be non-negative, got {self.strength}")
@dataclass
class AppliedLoRA:
"""Represents a LoRA applied to a specific module.
Attributes:
weights: The LoRA weight matrices
strength: Application strength for this LoRA
"""
weights: LoRAWeights
strength: float
def compute_delta(self) -> mx.array:
"""Compute the weight delta: strength * scale * (lora_B @ lora_A)."""
scale = self.weights.scale
delta = self.weights.lora_B @ self.weights.lora_A
return scale * self.strength * delta

View File

@@ -1,2 +1,3 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig

View File

@@ -0,0 +1,349 @@
## Wan2.1 / Wan2.2
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
They share the same model architecture — the difference is in the inference pipeline:
| | Wan2.1 | Wan2.2 T2V-14B | Wan2.2 I2V-14B | Wan2.2 TI2V-5B |
|---|--------|--------|--------|--------|
| **Task** | Text-to-Video | Text-to-Video | Image-to-Video | Text+Image-to-Video |
| **Pipeline** | Single model | Dual model | Dual model | Single model |
| **Sizes** | 1.3B, 14B | 14B | 14B | 5B |
| **Resolution** | 480P (1.3B), 720P (14B) | 720P | 720P | 720P |
| **Steps** | 50 | 40 | 40 | 40 |
| **Guidance** | 5.0 (fixed) | 3.0 / 4.0 | 3.5 / 3.5 | 5.0 (fixed) |
| **Shift** | 5.0 | 12.0 | 5.0 | 5.0 |
| **VAE** | Wan2.1 (z=16) | Wan2.1 (z=16) | Wan2.1 (z=16) + encoder | Wan2.2 (z=48) |
### Step 1: Download Weights
Download the original PyTorch checkpoints from HuggingFace using the `huggingface-cli` tool (install with `pip install huggingface_hub`):
**Wan2.1**
```bash
# Text-to-Video 1.3B (fast, fits in ~4 GB)
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir ./Wan2.1-T2V-1.3B
# Text-to-Video 14B
huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
```
**Wan2.2**
```bash
# Text-to-Video 14B
huggingface-cli download Wan-AI/Wan2.2-T2V-A14B --local-dir ./Wan2.2-T2V-A14B
# Image-to-Video 14B
huggingface-cli download Wan-AI/Wan2.2-I2V-A14B --local-dir ./Wan2.2-I2V-A14B
# Text+Image-to-Video 5B (uses a different VAE — z_dim=48)
huggingface-cli download Wan-AI/Wan2.2-TI2V-5B --local-dir ./Wan2.2-TI2V-5B
```
Each downloaded directory will have this structure:
```
Wan2.1-T2V-*/
├── models_t5_umt5-xxl-enc-bf16.pth # T5 text encoder
├── Wan2.1_VAE.pth # 3D VAE
└── diffusion_pytorch_model*.safetensors # transformer (single)
Wan2.2-T2V-A14B/ or Wan2.2-I2V-A14B/
├── models_t5_umt5-xxl-enc-bf16.pth
├── Wan2.1_VAE.pth
├── low_noise_model/ # dual-model low-noise transformer
└── high_noise_model/ # dual-model high-noise transformer
Wan2.2-TI2V-5B/
├── models_t5_umt5-xxl-enc-bf16.pth
├── Wan2.2_VAE.pth # different VAE (z_dim=48)
└── diffusion_pytorch_model*.safetensors # transformer (single)
```
> **Wan2.2 I2V-14B** shares the same directory structure as Wan2.2 T2V. The conversion script auto-detects I2V from the model's `config.json` (`model_type: "i2v"`, `in_dim: 36`).
### Step 2: Convert to MLX Format
The conversion script auto-detects the model version from the directory structure (presence of `low_noise_model/` → Wan2.2 dual model) and the model type from `config.json` (I2V vs T2V).
#### Wan2.1 T2V 1.3B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-1.3B \
--output-dir ./Wan2.1-T2V-1.3B-MLX
```
#### Wan2.1 T2V 14B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-14B \
--output-dir ./Wan2.1-T2V-14B-MLX
```
#### Wan2.2 T2V 14B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-T2V-A14B \
--output-dir ./Wan2.2-T2V-A14B-MLX
```
#### Wan2.2 I2V 14B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-I2V-A14B \
--output-dir ./Wan2.2-I2V-A14B-MLX
```
The I2V model is auto-detected from `config.json`; the output will include a `vae_encoder.safetensors` used to encode the conditioning image.
#### Wan2.2 TI2V 5B
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-TI2V-5B \
--output-dir ./Wan2.2-TI2V-5B-MLX
```
The TI2V model uses a different VAE (`z_dim=48`, `vae_stride=(4,16,16)`) and is auto-detected during conversion.
---
You can also pass `--model-version 2.1` or `--model-version 2.2` to force the version instead of relying on auto-detection.
#### Conversion Options
| Option | Default | Description |
|--------|---------|-------------|
| `--checkpoint-dir` | (required) | Path to original PyTorch checkpoint directory |
| `--output-dir` | `wan_mlx_model` | Output path for MLX model |
| `--dtype` | `bfloat16` | Target dtype (`float16`, `float32`, `bfloat16`) |
| `--model-version` | `auto` | Model version: `2.1`, `2.2`, or `auto` |
| `--quantize` | off | Quantize transformer weights for reduced memory |
| `--bits` | `4` | Quantization bits: `4` or `8` |
| `--group-size` | `64` | Quantization group size: `32`, `64`, or `128` |
The converter produces:
```
wan_mlx/
├── config.json # Model configuration
├── t5_encoder.safetensors # T5 UMT5-XXL text encoder
├── vae.safetensors # 3D VAE decoder
├── vae_encoder.safetensors # 3D VAE encoder (I2V-14B only)
├── model.safetensors # (Wan2.1) Single transformer
├── low_noise_model.safetensors # (Wan2.2) Low-noise transformer
└── high_noise_model.safetensors # (Wan2.2) High-noise transformer
```
### Step 3: Generate Video
#### Wan2.1 T2V 1.3B
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.1-T2V-1.3B-MLX \
--prompt "A cat playing piano in a cozy living room, cinematic lighting" \
--width 832 --height 480 --num-frames 81 \
--steps 50 --guide-scale 5.0 \
--seed 42 \
--output-path wan21_1b.mp4
```
#### Wan2.1 T2V 14B
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.1-T2V-14B-MLX \
--prompt "A woman walks through a misty forest at dawn, slow motion, cinematic" \
--width 1280 --height 704 --num-frames 81 \
--steps 50 --guide-scale 5.0 \
--seed 42 \
--output-path wan21_14b.mp4
```
> **Tip**: If the first few frames look washed out or have color artifacts, add `--trim-first-frames 1` to generate 4 extra frames at the start and discard them. With the `unipc` scheduler (default), **10 steps** often gives satisfying results — useful for quick iteration.
#### Wan2.2 T2V 14B
Wan2.2 uses a dual-model pipeline (separate high-noise and low-noise transformers) and takes guidance as a `high,low` pair:
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-T2V-A14B-MLX \
--prompt "Two astronauts playing chess on the surface of the moon, dramatic lighting, 8K" \
--negative-prompt "low quality, blurry, distorted" \
--width 1280 --height 704 --num-frames 81 \
--steps 40 --guide-scale "3.0,4.0" \
--seed 42 \
--output-path wan22_t2v.mp4
```
> **Tip**: With the `unipc` scheduler (default), **10 steps** often produces satisfying results for 14B models — a significant speed-up with minimal quality loss. Try `--steps 10` for quick iterations.
#### Wan2.2 I2V 14B
Image-to-video: animates a starting image guided by a text prompt. Pass the image with `--image`:
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-I2V-A14B-MLX \
--image ./my_photo.png \
--prompt "The person slowly turns their head and smiles, cinematic, natural lighting" \
--negative-prompt "low quality, blurry, distorted" \
--width 1280 --height 704 --num-frames 81 \
--steps 40 --guide-scale "3.5,3.5" \
--seed 42 \
--output-path wan22_i2v.mp4
```
> **Tip**: As with T2V, `--steps 10` with the `unipc` scheduler is often sufficient for fast prototyping.
#### Wan2.2 TI2V 5B
Text+image-to-video: a single-model variant with a larger VAE (`z_dim=48`). Resolution must be divisible by **32** (not 16 as with other models):
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-TI2V-5B-MLX \
--image ./my_photo.png \
--prompt "The subject waves hello, warm sunlight, film grain" \
--width 1280 --height 704 --num-frames 41 \
--steps 40 --guide-scale 5.0 \
--seed 42 \
--output-path wan22_ti2v.mp4
```
> **Note**: The 5B model is fast — 40 steps run quickly and are recommended for best quality.
> **Frame count**: `--num-frames` must satisfy `4n+1` for all models (e.g. 5, 9, 13, 21, 41, 81, 101 …).
> **Resolution**: Always use the model's native resolution. While generation will succeed at other sizes, mismatched resolutions or aspect ratios are likely to produce visual artifacts. Preferred resolutions are:
> - **480P** — 832×480 (landscape) or 480×832 (portrait) — for Wan2.1 1.3B
> - **720P** — 1280×704 (landscape) or 704×1280 (portrait) — for Wan2.1 14B, Wan2.2 T2V/I2V/TI2V
#### Generation Options
| Option | Default | Description |
|--------|---------|-------------|
| `--model-dir` | (required) | Path to converted MLX model directory |
| `--prompt` | (required) | Text prompt |
| `--image` | — | Input image path (I2V and TI2V modes) |
| `--negative-prompt` | config default | Negative guidance prompt |
| `--width` | `1280` | Output width in pixels |
| `--height` | `704` | Output height in pixels |
| `--num-frames` | `81` | Number of frames (must be `4n+1`) |
| `--steps` | config default | Diffusion steps |
| `--guide-scale` | config default | Guidance scale; use `"high,low"` pair for Wan2.2 dual models |
| `--shift` | config default | Noise schedule shift |
| `--seed` | `-1` (random) | Random seed for reproducibility |
| `--output-path` | `output.mp4` | Output video file path |
| `--scheduler` | `unipc` | Solver: `euler`, `dpm++`, or `unipc` |
| `--trim-first-frames` | `0` | Drop N leading frames (fixes first-frame artifacts on 14B models) |
| `--tiling` | `auto` | VAE tiling: `auto`, `none`, `spatial`, `temporal` |
### Quantization (Reduced Memory)
Quantize the transformer weights to reduce memory usage by ~3.4×. Quantization is supported for all model variants and is especially important for running 14B models on devices with limited unified memory:
```bash
# Convert with 4-bit quantization (works for any variant)
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-1.3B \
--output-dir ./Wan2.1-T2V-1.3B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.1-T2V-14B \
--output-dir ./Wan2.1-T2V-14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-T2V-A14B \
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-I2V-A14B \
--output-dir ./Wan2.2-I2V-A14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-TI2V-5B \
--output-dir ./Wan2.2-TI2V-5B-MLX-Q4 \
--quantize --bits 4 --group-size 64
```
You can also quantize an already-converted MLX model without re-converting from PyTorch:
```bash
python -m mlx_video.convert_wan \
--checkpoint-dir ./Wan2.2-T2V-A14B-MLX \
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--quantize-only --bits 4
```
Quantized models are used exactly the same way — the quantization is auto-detected from `config.json`:
```bash
python -m mlx_video.generate_wan \
--model-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--prompt "A cat playing piano"
```
**What gets quantized**: Self-attention (Q/K/V/O), cross-attention (Q/K/V/O), and FFN (fc1/fc2) — 10 layers × N blocks = ~95% of model weights. Embeddings, norms, and the output head remain in bfloat16 for precision.
| Model | BF16 Size | 4-bit Size | Notes |
|-------|-----------|------------|-------|
| 1.3B | 2.7 GB | 799 MB | ~3.4x smaller |
| 14B | ~28 GB | ~8 GB | Enables running on 16GB devices |
> **Note**: On Apple Silicon, the 1.3B model fits comfortably in unified memory at bf16. Quantization reduces memory but may not speed up inference for small models. For the 14B model, quantization is essential to fit in memory and will also improve speed.
### Wan Model Specifications
**Transformer (14B)**
- 40 layers, 40 attention heads, dim 5120, head dim 128
- 3-way factorized RoPE (temporal + spatial)
- 14.29B parameters
**Transformer (1.3B, Wan2.1 only)**
- 30 layers, 12 attention heads, dim 1536, head dim 128
- Same architecture, smaller scale
**Text Encoder** — UMT5-XXL (5.68B parameters)
- 24 layers, 64 heads, dim 4096, vocab 256K
**VAE** — 3D causal convolution decoder (72.6M parameters)
- Latent channels: 16
- Compression: 4× temporal, 8× spatial
---
## LoRA Support
LoRA's can be used with the `--lora-high` and `--lora-low` command line switches.
For example, for using the the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, use the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
```bash
python -m mlx_video.generate_wan \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \
--height 704 \
--num-frames 41 \
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
--steps 4 \
--guide-scale 1 \
--trim-first-frames 1 \
--seed 2391784614 \
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
```
## Enjoy
![Poodles](../../../examples/poodles-wan.gif)

View File

@@ -0,0 +1,2 @@
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel

View File

@@ -0,0 +1,221 @@
import mlx.core as mx
import mlx.nn as nn
from .rope import rope_apply
def _linear_dtype(layer) -> mx.Dtype:
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
# Unwrap LoRA wrapper to get the underlying linear layer
inner = getattr(layer, "linear", layer)
if isinstance(inner, nn.QuantizedLinear):
return inner.scales.dtype
return inner.weight.dtype
class WanRMSNorm(nn.Module):
"""RMS normalization with learnable scale."""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class WanLayerNorm(nn.Module):
"""LayerNorm computed in float32, with optional affine."""
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if elementwise_affine:
self.weight = mx.ones((dim,))
self.bias = mx.zeros((dim,))
def __call__(self, x: mx.array) -> mx.array:
if self.elementwise_affine:
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
else:
return mx.fast.layer_norm(x, None, None, self.eps)
class WanSelfAttention(nn.Module):
"""Self-attention with QK normalization and 3-way factorized RoPE."""
def __init__(
self,
dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def __call__(
self,
x: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array:
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
x_w = x.astype(w_dtype)
q = self.q(x_w)
k = self.k(x_w)
if self.norm_q is not None:
q = self.norm_q(q)
if self.norm_k is not None:
k = self.norm_k(k)
q = q.reshape(b, s, n, d)
k = k.reshape(b, s, n, d)
v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Use precomputed mask or build from seq_lens
mask = attn_mask
if mask is None and any(sl < s for sl in seq_lens):
mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
for i, sl in enumerate(seq_lens):
mask[i, :, :, sl:] = -1e9
# Use memory-efficient scaled dot-product attention
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]
if mask is not None:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out)
class WanCrossAttention(nn.Module):
"""Cross-attention: Q from hidden states, K/V from text context."""
def __init__(
self,
dim: int,
num_heads: int,
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def prepare_kv(self, context: mx.array) -> tuple:
"""Pre-compute K and V projections for caching.
Args:
context: [B, L_ctx, dim]
Returns:
(k, v) each [B, N, L_ctx, D] ready for attention
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to compute dtype for efficient matmul
w_dtype = _linear_dtype(self.k)
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
return k, v
def __call__(
self,
x: mx.array,
context: mx.array,
context_lens: list | None = None,
kv_cache: tuple | None = None,
) -> mx.array:
b = x.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
q = self.q(x.astype(w_dtype))
if self.norm_q is not None:
q = self.norm_q(q)
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
if kv_cache is not None:
k, v = kv_cache
else:
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
# Optional context masking
mask = None
if context_lens is not None:
ctx_len = k.shape[2]
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
for i, cl in enumerate(context_lens):
mask[i, :, :, cl:] = -1e9
if mask is not None:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out)

View File

@@ -0,0 +1,129 @@
from dataclasses import dataclass
from typing import Tuple, Union
from mlx_video.models.ltx.config import BaseModelConfig
@dataclass
class WanModelConfig(BaseModelConfig):
"""Configuration for Wan T2V models (supports both 2.1 and 2.2)."""
model_type: str = "t2v"
model_version: str = "2.2"
patch_size: Tuple[int, int, int] = (1, 2, 2)
text_len: int = 512
in_dim: int = 16
dim: int = 5120
ffn_dim: int = 13824
freq_dim: int = 256
text_dim: int = 4096
out_dim: int = 16
num_heads: int = 40
num_layers: int = 40
window_size: Tuple[int, int] = (-1, -1)
qk_norm: bool = True
cross_attn_norm: bool = True
eps: float = 1e-6
# VAE
vae_stride: Tuple[int, int, int] = (4, 8, 8)
vae_z_dim: int = 16
# Inference
dual_model: bool = True
boundary: float = 0.875
sample_shift: float = 12.0
sample_steps: int = 40
sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
num_train_timesteps: int = 1000
sample_fps: int = 16
frame_num: int = 81
sample_neg_prompt: str = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
"最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部"
"画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
"杂乱的背景,三条腿,背景人很多,倒着走"
)
# Resolution constraints
max_area: int = 0 # 0 = no limit; e.g. 704*1280 for TI2V-5B
t5_vocab_size: int = 256384
t5_dim: int = 4096
t5_dim_attn: int = 4096
t5_dim_ffn: int = 10240
t5_num_heads: int = 64
t5_num_layers: int = 24
t5_num_buckets: int = 32
@property
def head_dim(self) -> int:
return self.dim // self.num_heads
@classmethod
def wan21_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
return cls(
model_version="2.1",
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
)
@classmethod
def wan21_t2v_1_3b(cls) -> "WanModelConfig":
"""Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
return cls(
model_version="2.1",
dim=1536,
ffn_dim=8960,
num_heads=12,
num_layers=30,
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
)
@classmethod
def wan22_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
return cls()
@classmethod
def wan22_i2v_14b(cls) -> "WanModelConfig":
"""Wan2.2 I2V 14B: dual model, image-to-video, 40 layers, dim=5120."""
return cls(
model_type="i2v",
in_dim=36,
out_dim=16,
dual_model=True,
boundary=0.900,
sample_shift=5.0,
sample_guide_scale=(3.5, 3.5),
max_area=704 * 1280,
)
@classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig":
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
return cls(
model_type="ti2v",
dim=3072,
ffn_dim=14336,
in_dim=48,
out_dim=48,
num_heads=24,
num_layers=30,
vae_z_dim=48,
vae_stride=(4, 16, 16),
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=40,
sample_guide_scale=5.0,
sample_fps=24,
max_area=704 * 1280,
)

View File

@@ -0,0 +1,394 @@
# Wan2.2 I2V-14B Diagnostic Report
This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix.
## Table of Contents
- [Overview](#overview)
- [Architecture Summary](#architecture-summary)
- [Diagnostic Methodology](#diagnostic-methodology)
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
- [Bug 3: RoPE Frequency Computation (original)](#bug-3-rope-frequency-computation-original)
- [Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)](#bug-6-rope-frequency-distribution-bug-3-fix-was-wrong)
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
- [Verified Correct Components](#verified-correct-components)
- [Performance Optimizations](#performance-optimizations)
- [Resolved: CFG Effectiveness](#resolved-cfg-effectiveness-was-open-investigation)
- [Reference Implementation](#reference-implementation)
- [Useful Diagnostic Commands](#useful-diagnostic-commands)
---
## Overview
The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output — first frame showed the image, subsequent frames degraded to noise, checkerboard artifacts, or flat grey.
Through a systematic component-by-component comparison against the reference PyTorch implementation, **five bugs** were found and fixed. The approach was to verify each component in isolation numerically, then narrow down failures to the subsystem level.
### Timeline of Symptoms
| Stage | Symptom | Root Cause |
|-------|---------|------------|
| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) |
| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) |
| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) |
| After VAE fix | Image in frames 0-3, grey frames 4+ | CFG effectiveness issue (open investigation) |
---
## Architecture Summary
```
I2V-14B Pipeline:
Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat]
Mask Construction → [4, T_lat, H_lat, W_lat]
y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat]
Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat]
Dual DiT (40 layers, 5120 dim) × 40 denoising steps
Denoised Latent [16, T_lat, H_lat, W_lat]
VAE Decoder → Video [3, F, H, W]
```
**Key parameters:**
- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16`
- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900)
- 40 steps, shift=5.0, guide_scale=(3.5, 3.5)
- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8)
---
## Diagnostic Methodology
### 1. Component-Level Numerical Verification
Each component was tested in isolation against the reference PyTorch implementation:
1. **Load identical inputs** (same random seed, same image, same prompt)
2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy`
3. **Run through MLX** with the same inputs
4. **Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics
Components tested this way:
- RoPE frequency parameters and rotation output
- Time embedding (sinusoidal → MLP → projection)
- Patchify (reshape+Linear vs Conv3d)
- Unpatchify (transpose-based vs einsum)
- Scheduler (UniPC) timesteps and step formulas
- VAE encoder output (frame-by-frame comparison)
- Text embeddings (per-model MLP output)
- Cross-attention K/V cache shapes
- Mask construction values
### 2. Artifact Analysis
When visual artifacts appeared, quantitative metrics were used to characterize them:
- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard.
- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts.
- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation.
- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content.
### 3. Isolation Testing
- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source.
- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness.
- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering.
- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence.
- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs.
### 4. Weight Comparison
Direct comparison of converted MLX weights against original PyTorch weights:
```python
# Load both weight sets
pt_weights = torch.load("model.safetensors")
mlx_weights = mx.load("model.safetensors")
# Compare each key
for key in pt_weights:
diff = np.abs(np.array(pt_weights[key]) - np.array(mlx_weights[key])).max()
```
Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys.
---
## Bug 1: Text Embedding Cross-Contamination
**Symptom:** Model ignores text prompt, generated frames lack semantic content.
**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights — 42% relative mean difference in output.
**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices.
**Fix:** Compute separate text embeddings per model:
```python
# Before (broken):
context_emb = low_noise_model.embed_text([context, context_null])
cross_kv = low_noise_model.prepare_cross_kv(context_emb) # used for BOTH models
# After (correct):
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high)
```
**File:** `mlx_video/generate_wan.py` (lines 333349)
**Commit:** `a85b1c21`
---
## Bug 2: VAE Encoder Weights Excluded from Conversion
**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion).
**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, missing encoder weights were silently ignored, resulting in random initialization.
**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning).
**Fix:**
```python
# Before:
include_encoder = config.model_type == "ti2v"
# After:
include_encoder = config.model_type in ("ti2v", "i2v")
```
**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions.
**File:** `mlx_video/convert_wan.py` (line 424)
---
## Bug 3: RoPE Frequency Computation (original)
**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.
**Root Cause (original):** Our original code called `rope_params` three times but applied them incorrectly (per-axis in the model init, then rope_apply did NOT split). This was initially "fixed" by switching to a single `rope_params(1024, head_dim=128)` call, which reduced checkerboard but introduced Bug 6 (see below).
**File:** `mlx_video/models/wan/model.py`
**Commit:** `3da4a637`
---
## Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)
**Symptom:** I2V generates input image in frames 03, colorful checkerboard on frame 4, then grey frames. CFG cond/uncond predictions nearly identical. Model cannot produce coherent motion.
**Root Cause:** The Bug 3 "fix" replaced three separate `rope_params` calls with a single `rope_params(1024, 128)`. But the reference (`wan/modules/model.py` lines 400405) actually uses **three separate calls with different dimension normalizations**, concatenated:
```python
# Reference (CORRECT):
d = dim // num_heads # 128
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)), # rope_params(1024, 44)
rope_params(1024, 2 * (d // 6)), # rope_params(1024, 42)
rope_params(1024, 2 * (d // 6)) # rope_params(1024, 42)
], dim=1)
```
Each axis gets its own full frequency range [θ^0, θ^(-~0.95)]. The single-call approach gave:
- Temporal: low frequencies only [1.0 → 0.049]
- Height: medium frequencies only [0.042 → 0.002] (should start at 1.0!)
- Width: high frequencies only [0.002 → 0.0001] (should start at 1.0!)
The height/width position encoding was essentially destroyed — nearby spatial positions were indistinguishable (max diff 0.958 for height, 0.998 for width vs reference).
**How Found:** Direct line-by-line comparison of `WanModel.__init__` freq construction between reference `wan/modules/model.py` and our `models/wan/model.py`. Numerical verification confirmed the three-call approach gives each axis a full [0, ~1) exponent range, while the single-call monotonically assigns low→high across axes.
**Fix:**
```python
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
```
**Verification:** Max diff vs reference cos/sin: 0.00000000 (exact float32 match).
**Impact:** Affects ALL Wan models (T2V, I2V, TI2V). Resolves the "Open Investigation: CFG Effectiveness" issue — the model could not produce meaningful cond/uncond differences because it couldn't encode spatial positions.
**File:** `mlx_video/models/wan/model.py` (line 155)
---
## Bug 4: VAE Encoder Temporal Downsample Order
**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 14 showed decreasing std (0.37→1.19) while reference showed stable std (0.95→1.34).
**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters:
```
Reference: [False, True, True] → stage 0: 2d, stage 1: 3d, stage 2: 3d
Ours: [True, True, False] → stage 0: 3d, stage 1: 3d, stage 2: 2d ← WRONG
```
This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation.
**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output:
- Frame 0 matched exactly (diff=0.0000) — spatial-only processing was correct
- Frames 14 had massive differences — proved temporal processing was broken
Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]`, while our `Encoder3d` class used the wrong default `[True, True, False]`.
**Fix:**
```python
# In Encoder3d.__init__, change default:
temporal_downsample = [False, True, True] # was [True, True, False]
```
**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 6080 to 0.17.7.
**File:** `mlx_video/models/wan/vae.py` (line 370)
**Commit:** `3da4a637`
---
## Bug 5: Non-Chunked VAE Encoding
**Symptom:** First 45 frames grey, then blurred version of image appears.
**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`):
1. Encode first frame alone (1 frame)
2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks
3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input
Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because:
- Chunked: real features from previous chunks serve as causal context
- Non-chunked: zeros serve as causal context for the start
**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values.
**Fix:** Implemented full chunked encoding with temporal caching:
- Added `cache_x` parameter to `CausalConv3d.__call__`
- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d`
- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks)
- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head)
**File:** `mlx_video/models/wan/vae.py` (multiple methods)
**Commit:** `b6a94c4c`
---
## Verified Correct Components
These components were numerically verified against the reference and are **not** sources of bugs:
| Component | Method | Max Diff | Notes |
|-----------|--------|----------|-------|
| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only |
| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent |
| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative |
| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative |
| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation |
| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match |
| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 |
| CFG formula | Code audit | — | `uncond + gs * (cond - uncond)` correct order |
| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output |
| Cross-attention K/V | Shape and value audit | — | Batch dimension preserved correctly |
---
## Performance Optimizations
Applied alongside bug fixes to improve inference speed:
### Pre-Computation (Before Diffusion Loop)
- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once
- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat
- **Attention mask precomputation**: Build padding mask once, pass via kwargs
- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing
- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync
### Per-Step Optimizations
- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection
- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
---
## Resolved: CFG Effectiveness (was Open Investigation)
**Symptom:** Generated video shows the input image in frames 03 (latent frame 0), then grey/flat frames for the rest. Cond and uncond predictions were nearly identical.
**Resolution:** This was caused by Bug 6 (incorrect RoPE frequency distribution). The single `rope_params(1024, 128)` call gave height frequencies starting at 0.042 and width at 0.002 (instead of 1.0 for both), making the model unable to encode spatial positions. This caused the transformer to produce nearly identical outputs regardless of text conditioning, explaining the tiny cond/uncond differences.
---
## Reference Implementation
The reference PyTorch implementation is at `/Users/daniel/Projects/Wan2.2/`:
| File | Contents |
|------|----------|
| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) |
| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) |
| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding |
| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler |
| `wan/configs/wan_i2v_A14B.py` | Model configuration |
Key structural differences between reference and our implementation:
- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2
- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype
- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear`
- Reference casts timesteps to `int64`; we keep as float (diff < 1.0)
---
## Useful Diagnostic Commands
### Run I2V-14B generation
```bash
python -m mlx_video.generate_wan \
--prompt "A woman smiles at camera" \
--image start.png \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \
--num-frames 17 --steps 40 \
--height 384 --width 640 \
--output output_i2v.mp4
```
### Check VAE encoder output
```python
import mlx.core as mx, numpy as np
from mlx_video.models.wan.vae import WanVAE
# Load VAE and encode an image
latents = vae.encode(video_tensor) # [1, 16, T_lat, H_lat, W_lat]
for t in range(latents.shape[2]):
frame = np.array(latents[0, :, t])
print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}")
```
### Analyze video frame quality
```python
import cv2, numpy as np
cap = cv2.VideoCapture("output.mp4")
while True:
ret, frame = cap.read()
if not ret: break
# Checkerboard metric: high values indicate patch-boundary artifacts
checker = np.abs(frame[::2, ::2].astype(float) - frame[1::2, 1::2].astype(float)).mean()
print(f"std={frame.std():.1f} checker={checker:.1f}")
```
### Compare weights between PyTorch and MLX
```python
import torch, mlx.core as mx, numpy as np
pt = torch.load("model.pt", map_location="cpu")
mlx_w = mx.load("model.safetensors")
for key in sorted(pt.keys()):
if key in mlx_w:
diff = np.abs(pt[key].float().numpy() - np.array(mlx_w[key])).max()
if diff > 0.01:
print(f"LARGE DIFF {key}: {diff:.6f}")
```

View File

@@ -0,0 +1,285 @@
# Wan2.2 MLX Implementation Notes
> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.
## Architecture Overview
Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early reports, the T2V/TI2V models do **not** use Mixture-of-Experts — they are dense DiT models with a dual-model architecture for the 14B variant (separate high-noise and low-noise denoisers with a boundary timestep).
### Key Parameters
| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
|-------|-----|-------|--------|----------|-----------|------------|--------|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |
### Codebase Structure (~3900 lines of Wan2.2 code)
```
mlx_video/
├── generate_wan.py # 483L - Generation pipeline (T2V + I2V)
├── convert_wan.py # 564L - Weight conversion from HuggingFace
└── models/wan/
├── config.py # 113L - Model configs (dataclass presets)
├── model.py # 320L - DiT model (time embed, patchify, unpatchify)
├── transformer.py # 91L - Attention block + FFN
├── attention.py # 211L - Self-attention + cross-attention
├── rope.py # 100L - 3D Rotary Position Embeddings
├── text_encoder.py # 240L - T5 encoder (UMT5-XXL)
├── scheduler.py # 428L - Euler, DPM++ 2M, UniPC schedulers
├── vae.py # 315L - Wan2.1 VAE decoder (4×8×8)
├── vae22.py # 836L - Wan2.2 VAE encoder + decoder (4×16×16)
├── loading.py # 154L - Model loading utilities
└── i2v_utils.py # 58L - I2V mask/preprocessing
```
---
## Critical Bugs & Fixes
### 1. MLX Underscore Attribute Gotcha
**Problem**: MLX's `nn.Module` silently ignores underscore-prefixed attributes (`_layer_0`, `_layer_1`, etc.) in `parameters()` and `load_weights()`. The Wan2.2 VAE had layers named `_layer_N`, causing **87 out of 110 weights to be silently dropped** during loading.
**Fix**: Rename all `_layer_N` attributes to `layer_N`. MLX treats underscore-prefixed attributes as "private" and excludes them from the parameter tree.
**Lesson**: Never use underscore-prefixed names for `nn.Module` sub-modules in MLX.
### 2. Patchify Channel Ordering
**Problem**: The patchify/unpatchify operations transposed channels incorrectly — producing `[C fastest]` layout instead of `[C slowest]`, causing completely garbled video output.
**Fix**: Changed reshape to produce correct `[B, T', H', W', pt*ph*pw*C]` ordering matching PyTorch's contiguous memory layout.
**Lesson**: When porting PyTorch reshape/view operations to MLX, pay close attention to memory layout — PyTorch is row-major by default, and reshape semantics differ when dimensions are reordered.
### 3. VAE AttentionBlock Reshape
**Problem**: Attention block merged batch (B) with channels (C) instead of batch with temporal (T), producing a green checker pattern in output.
**Fix**: Correct reshape from `[B*C, T, H, W]` to `[B*T, C, H, W]` for spatial attention.
### 4. RMS Norm vs L2 Norm
**Problem**: The Wan2.2 VAE uses a class named `RMS_norm` in PyTorch, but it actually computes **L2 normalization** (divide by L2 norm), not RMS normalization (divide by RMS). Using actual RMS norm caused exponential value explosion.
**Fix**: Implement as `x / ||x||₂` instead of `x / sqrt(mean(x²))`.
**Lesson**: Don't trust class names in reference code — read the actual computation.
### 5. Video Codec Green Output
**Problem**: OpenCV's `mp4v` codec on macOS produces green-tinted video.
**Fix**: Switch to `imageio` with `libx264` codec. Fallback chain: imageio → cv2 (avc1) → PNG frames.
---
## Precision & Dtype Flow
### The bfloat16 Autocast Pattern
The official PyTorch implementation uses `torch.autocast("cuda", dtype=torch.bfloat16)` which automatically casts matmul inputs. In MLX, we replicate this manually:
| Operation | Official (PyTorch) | MLX Implementation |
|---|---|---|
| Modulation/gates | float32 (explicit `autocast(enabled=False)`) | `x.astype(mx.float32)` before modulation |
| QKV projections | bfloat16 (outer autocast) | Cast input to `self.q.weight.dtype` |
| RoPE computation | float64 → float32 | float32 (MLX lacks float64 on GPU) |
| Q/K after RoPE | bfloat16 (`q.to(v.dtype)`) | Cast back to weight dtype after RoPE |
| FFN matmuls | bfloat16 (outer autocast) | Cast input to `self.fc1.weight.dtype` |
| Residual stream | float32 | float32 (no cast) |
**Result**: ~16% speedup (47s vs 56s for 20 steps at 480p) with no quality regression.
**Key insight**: Modulation parameters (scale, shift, gate) must stay in float32 — they are small values (~0.010.1) that lose significant precision in bfloat16. The official code explicitly disables autocast for these computations.
### T5 Encoder Precision
The T5 text encoder must run in float32. Bfloat16 weights cause the attention softmax to produce degenerate distributions, which corrupts text conditioning and manifests as blurry patches in generated video. Since T5 only runs once per generation, the performance cost is negligible.
### VAE Decoder Precision
VAE weights must be float32. Bfloat16 VAE decode introduces visible quality loss in the decoded video frames.
---
## Scheduler Implementation Details
### Three Schedulers: Euler, DPM++ 2M, UniPC
All operate in the flow-matching formulation where `sigma` represents the noise level (1.0 = pure noise, 0.0 = clean).
**Euler**: Simple first-order ODE solver. Most stable, recommended for debugging.
**DPM++ 2M**: Second-order multistep solver. Uses previous step's model output for higher-order correction. Requires special handling at boundaries (return `±inf` from `_lambda()` when sigma is 0 or 1).
**UniPC** (default, matches official): Second-order predictor-corrector. The "C" (corrector) part is critical — it refines each step using the already-computed model output at **zero additional model evaluation cost**.
### UniPC Corrector: Must Be Enabled
**Discovery**: Our implementation had `use_corrector=False` by default, but the official Wan2.2 code **always** enables it (there's no flag — the corrector runs whenever `step_index > 0`).
**Impact**: Without the corrector, UniPC degrades to a simple predictor, losing its second-order accuracy advantage.
### UniPC Corrector Coefficients
The corrector coefficients (`rhos_c`) must be computed by solving a linear system, not hardcoded. For order ≥ 2, hardcoding `rhos_c[-1] = 0.5` introduces ~613% error in the correction term across 47+ steps. The fix uses `np.linalg.solve()` to compute exact coefficients.
### Sigma Schedule
```python
# Flow-matching sigma schedule with shift
sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
```
Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
---
## Image-to-Video (I2V) Pipelines
Wan2.2 supports two distinct I2V approaches:
### TI2V-5B: Per-Token Timestep Masking
I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:
```python
# mask_tokens: [1, L] — 0 for first-frame patches, 1 for rest
t_tokens = mask_tokens * current_timestep # first-frame → t=0
```
The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
#### Mask Re-application
After each scheduler step, the first-frame latent is re-injected to prevent drift:
```python
latents = (1.0 - mask) * z_img + mask * latents
```
#### VAE Encoder Temporal Downsample Order
The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
- Stage 0: Spatial-only downsampling
- Stages 12: Spatial + temporal downsampling
This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
### I2V-14B: Channel Concatenation
The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor:
1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])``[20, T_lat, H_lat, W_lat]`
4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
Key differences from TI2V-5B:
- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
- Requires the **VAE encoder** (for encoding the reference image)
- Uses **scalar timesteps** (same as T2V) — no per-token masking
- **Dual model** pipeline with boundary=0.900
- Both conditional and unconditional predictions receive the same `y` tensor
---
## Dimension Constraints
### Patchify Alignment
Video dimensions must be divisible by `patch_size × vae_stride`:
- **TI2V-5B**: patch=(1,2,2), stride=(4,16,16) → alignment = **32** pixels
- **T2V-14B**: patch=(1,2,2), stride=(4,8,8) → alignment = **16** pixels
Example: 720p (1280×720) → 720 % 32 ≠ 0, auto-aligns to **704**.
### Frame Count
Frames must satisfy `num_frames = 4n + 1` (e.g., 5, 9, 13, ..., 81) due to temporal VAE stride of 4.
---
## Performance Optimizations
### Batched CFG
Instead of two separate forward passes for conditional and unconditional predictions, batch them into a single B=2 forward pass:
```python
preds = model([latents, latents], t=t_batch, context=context_cfg, ...)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
```
**Result**: ~40% speedup by amortizing attention overhead.
### Precomputed Text Embeddings & Cross-Attention KV Cache
Text embeddings and cross-attention K/V projections are constant across all diffusion steps. Computing them once and passing as caches eliminates redundant computation.
### Memory Management in Diffusion Loop
```python
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
mx.eval(latents)
```
MLX's lazy evaluation means `mx.eval()` triggers the full computation graph. Deleting intermediate arrays before eval allows MLX to reuse their memory during execution.
---
## Weight Conversion
### Key Mapping Patterns
The PyTorch → MLX conversion (`convert_wan.py`) handles several systematic transforms:
1. **Conv3d weight transposition**: PyTorch `(out, in, D, H, W)` → MLX `(out, D, H, W, in)`
2. **Linear weight transposition**: PyTorch `(out, in)` → MLX `(out, in)` (same convention for `nn.Linear`)
3. **Nested module paths**: `blocks.0.self_attn.q.weight` → same paths, MLX loads by dotted key
### Dual-Model Splitting
The T2V-14B uses dual models (high-noise and low-noise). The conversion script splits a single checkpoint into separate files or handles pre-split checkpoints from HuggingFace.
---
## Testing Strategy
332 tests across 10 files, all running in ~5 seconds:
| File | Focus |
|------|-------|
| test_wan_config.py | Config presets, field validation |
| test_wan_attention.py | Self/cross attention, RMSNorm, bf16 autocast |
| test_wan_transformer.py | FFN, attention block, float32 modulation |
| test_wan_model.py | Full DiT forward pass, per-token timesteps |
| test_wan_t5.py | T5 encoder layers and full encoding |
| test_wan_vae.py | VAE 2.1 decoder, VAE 2.2 encoder + decoder |
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
| test_wan_convert.py | Weight sanitization and conversion |
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |
Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.
---
## Known Issues
### I2V Quality Degradation
Frames 213 gradually degrade, and frame 14 often has a "flash" artifact. All implementation details have been verified against the official PyTorch code with no discrepancies found. Possible causes:
- Subtle numerical differences from float32 vs float64 RoPE (MLX lacks float64 on GPU)
- MLX-specific attention precision behavior
- Better prompts and 720p resolution (the model's native resolution) help reduce artifacts
### Chinese Negative Prompt
The official Wan2.2 uses a Chinese negative prompt that prevents oversaturation and comic-style artifacts. Correct tokenization requires `ftfy.fix_text()` to normalize fullwidth characters and double HTML unescaping. Without proper text cleaning, the negative prompt tokens don't match the training distribution, causing blurry patches.

View File

@@ -0,0 +1,58 @@
"""Image-to-Video utility functions for Wan2.2."""
import mlx.core as mx
import numpy as np
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
"""Load, resize, center-crop, and normalize an image for I2V.
Args:
image_path: Path to input image
width: Target width
height: Target height
Returns:
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
"""
from PIL import Image
img = Image.open(image_path).convert("RGB")
# Resize so that the image covers the target size (LANCZOS)
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
# Center crop
x1 = (img.width - width) // 2
y1 = (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
# To tensor: [H, W, 3] float32 in [-1, 1]
arr = np.array(img, dtype=np.float32) / 255.0
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
def build_i2v_mask(z_shape, patch_size):
"""Build temporal mask for I2V: first frame = 0, rest = 1.
Args:
z_shape: Latent shape (C, T, H, W) in channels-first
patch_size: (pt, ph, pw) patch size
Returns:
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
"""
C, T, H, W = z_shape
mask = mx.ones(z_shape)
# Zero out the first temporal position
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
# Token-level mask for per-token timesteps: subsample to patch grid
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
pt, ph, pw = patch_size
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
return mask, mask_tokens

View File

@@ -0,0 +1,183 @@
"""Wan model loading utilities."""
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
"""Load and initialize WanModel, with optional quantization and LoRA support.
Args:
model_path: Path to model safetensors file
config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading.
loras: Optional list of (lora_path, strength) tuples to apply.
"""
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
if quantization:
from mlx_video.convert_wan import _quantize_predicate
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
weights = mx.load(str(model_path))
# Apply LoRAs: dequantize+merge for quantized models, weight merge for bf16
if loras:
if quantization:
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
from mlx_video.convert_wan import _load_lora_configs
from mlx_video.lora import apply_loras_to_model
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
module_to_loras = _load_lora_configs(loras)
apply_loras_to_model(model, module_to_loras)
mx.eval(model.parameters())
return model
else:
# Weight merging: fold LoRA into bf16 weights before loading
from mlx_video.convert_wan import load_and_apply_loras
weights = load_and_apply_loras(dict(weights), loras)
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
return model
def load_t5_encoder(model_path: Path, config):
"""Load T5 text encoder.
Weights are upcast to float32 for maximum precision — the T5 encoder
only runs once per generation, so performance impact is negligible.
This matches the official which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=config.t5_vocab_size,
dim=config.t5_dim,
dim_attn=config.t5_dim_attn,
dim_ffn=config.t5_dim_ffn,
num_heads=config.t5_num_heads,
num_layers=config.t5_num_layers,
num_buckets=config.t5_num_buckets,
shared_pos=False,
)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()))
mx.eval(encoder.parameters())
return encoder
def load_vae_decoder(model_path: Path, config=None):
"""Load VAE decoder (skips encoder weights with strict=False).
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
For Wan2.1 (vae_z_dim=16), uses WanVAE.
"""
is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters())
return vae
def load_vae_encoder(model_path: Path, config=None):
"""Load VAE encoder for I2V image encoding.
For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder.
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
"""
if config is not None and config.vae_z_dim == 16:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
else:
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters())
return vae
def _clean_text(text: str) -> str:
"""Clean text matching official Wan2.2 tokenizer preprocessing.
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
double HTML unescape, and whitespace normalization. Critical for
correct tokenization of the Chinese negative prompt.
"""
import html
import re
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text).strip()
return text
def encode_text(
encoder,
tokenizer,
prompt: str,
text_len: int = 512,
) -> mx.array:
"""Encode text prompt using T5 encoder.
Args:
encoder: T5Encoder model
tokenizer: HuggingFace tokenizer
prompt: Text prompt
text_len: Maximum text length
Returns:
Text embeddings [L, dim]
"""
prompt = _clean_text(prompt)
tokens = tokenizer(
prompt,
max_length=text_len,
padding="max_length",
truncation=True,
return_tensors="np",
)
ids = mx.array(tokens["input_ids"])
mask = mx.array(tokens["attention_mask"])
embeddings = encoder(ids, mask=mask)
# Return only non-padding tokens
seq_len = int(mask.sum().item())
return embeddings[0, :seq_len]

View File

@@ -0,0 +1,377 @@
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .attention import WanLayerNorm, _linear_dtype
from .config import WanModelConfig
from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
"""Compute sinusoidal positional embeddings.
Args:
dim: Embedding dimension (must be even).
position: Tensor of positions — 1D [L] or 2D [B, L].
Returns:
Embeddings of shape [L, dim] or [B, L, dim].
"""
assert dim % 2 == 0
half = dim // 2
pos = position.astype(mx.float32)
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
sinusoid = pos[..., None] * inv_freq # [..., half]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
class Head(nn.Module):
"""Output projection head with learned modulation."""
def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
super().__init__()
self.out_dim = out_dim
self.patch_size = patch_size
proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
"""
Args:
x: [B, L, dim]
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
# Compute modulation in float32 (matching reference's autocast(float32))
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0
return self.head(x_mod)
class WanModel(nn.Module):
"""Wan2.2 diffusion backbone for text-to-video generation."""
def __init__(self, config: WanModelConfig):
super().__init__()
self.config = config
dim = config.dim
self.dim = dim
self.num_heads = config.num_heads
self.out_dim = config.out_dim
self.patch_size = config.patch_size
self.text_len = config.text_len
self.freq_dim = config.freq_dim
# Patch embedding: Conv3d implemented as a reshaped linear
# For kernel (1,2,2) and stride (1,2,2): reshape input then linear
patch_dim = config.in_dim * math.prod(config.patch_size)
self.patch_embedding_proj = nn.Linear(patch_dim, dim)
self._patch_size = config.patch_size
# Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="tanh")
self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP
self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
self.time_embedding_act = nn.SiLU()
self.time_embedding_1 = nn.Linear(dim, dim)
# Time projection for modulation (6x dim)
self.time_projection_act = nn.SiLU()
self.time_projection = nn.Linear(dim, dim * 6)
# Transformer blocks
self.blocks = [
WanAttentionBlock(
dim=dim,
ffn_dim=config.ffn_dim,
num_heads=config.num_heads,
window_size=config.window_size,
qk_norm=config.qk_norm,
cross_attn_norm=config.cross_attn_norm,
eps=config.eps,
)
for _ in range(config.num_layers)
]
# Output head
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
# Precompute RoPE frequencies — three separate tables concatenated.
# Reference computes three rope_params with different dim normalizations
# so each axis (temporal/height/width) gets its own full frequency range.
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
# Precompute sinusoidal inv_freq for time embedding.
half = config.freq_dim // 2
self._inv_freq = mx.array(
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
).astype(np.float32)
)
def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings.
Args:
x: Video latent [C, F, H, W]
Returns:
(patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
"""
c, f, h, w = x.shape
pt, ph, pw = self._patch_size
f_out = f // pt
h_out = h // ph
w_out = w // pw
# Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
# Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw]
x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw]
# Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out)
def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
"""Reconstruct video from patch embeddings.
Args:
x: [B, L, out_dim * prod(patch_size)]
grid_sizes: List of (F', H', W') per batch element
Returns:
List of tensors [C, F, H, W]
"""
c = self.out_dim
pt, ph, pw = self.patch_size
out = []
for i, (f, h, w) in enumerate(grid_sizes):
seq_len = f * h * w
u = x[i, :seq_len] # [L, out_dim * pt * ph * pw]
u = u.reshape(f, h, w, pt, ph, pw, c)
# Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw]
u = u.reshape(c, f * pt, h * ph, w * pw)
out.append(u)
return out
def embed_text(self, context: list) -> mx.array:
"""Precompute text embeddings (call once, reuse across steps).
Args:
context: List of text embeddings [L_text, text_dim]
Returns:
Embedded context [B, text_len, dim] in model dtype
"""
model_dtype = _linear_dtype(self.patch_embedding_proj)
context_padded = []
for ctx in context:
pad_len = self.text_len - ctx.shape[0]
if pad_len > 0:
ctx = mx.concatenate(
[ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
axis=0,
)
context_padded.append(ctx)
context_batch = mx.stack(context_padded) # [B, text_len, text_dim]
context_batch = self.text_embedding_1(
self.text_embedding_act(self.text_embedding_0(context_batch))
)
return context_batch.astype(model_dtype)
def prepare_cross_kv(self, context: mx.array) -> list:
"""Pre-compute cross-attention K/V for all blocks.
Call once before the diffusion loop to cache K/V projections,
eliminating redundant computation at each denoising step.
Args:
context: Pre-embedded text [B, text_len, dim]
Returns:
List of (k, v) tuples, one per block
"""
kv_caches = []
for block in self.blocks:
kv_caches.append(block.cross_attn.prepare_kv(context))
return kv_caches
def prepare_rope(self, grid_sizes: list) -> tuple:
"""Pre-compute RoPE cos/sin for constant grid sizes.
Call once before the diffusion loop when grid sizes don't change
across steps. Eliminates per-step broadcast/concat overhead.
Args:
grid_sizes: List of (F, H, W) tuples per batch element
Returns:
(cos_f, sin_f) precomputed frequency tensors
"""
w_dtype = _linear_dtype(self.patch_embedding_proj)
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__(
self,
x_list: list,
t: mx.array,
context: list | mx.array,
seq_len: int,
cross_kv_caches: list | None = None,
y: list | None = None,
rope_cos_sin: tuple | None = None,
) -> list:
"""Forward pass.
Args:
x_list: List of video latent tensors [C, F, H, W]
t: Timestep tensor [B]
context: List of raw text embeddings, OR pre-embedded tensor
from embed_text() [B, text_len, dim]
seq_len: Maximum sequence length for padding
cross_kv_caches: Optional list of (k, v) tuples from
prepare_cross_kv(), one per block.
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
Channel-concatenated with x before patchify.
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
Returns:
List of denoised tensors [C, F, H, W]
"""
# Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
# Check BEFORE I2V concat since concat creates new array objects.
batch_size = len(x_list)
all_same = batch_size > 1 and all(
x_list[i] is x_list[0] for i in range(1, batch_size)
)
if all_same and y is not None:
all_same = all(y[i] is y[0] for i in range(1, len(y)))
# I2V: channel-concatenate conditioning y with noise x
if y is not None:
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
if all_same:
# Patchify once and broadcast — saves a Linear projection per step
p, gs = self._patchify(x_list[0]) # [1, L, dim]
grid_sizes = [gs] * batch_size
seq_lens_list = [p.shape[1]] * batch_size
# Pad and broadcast
if p.shape[1] < seq_len:
p = mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
else:
patches = []
grid_sizes = []
seq_lens_list = []
for vid in x_list:
p, gs = self._patchify(vid) # [1, L, dim]
patches.append(p)
grid_sizes.append(gs)
seq_lens_list.append(p.shape[1])
x = mx.concatenate(
[
mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
if p.shape[1] < seq_len
else p
for p in patches
],
axis=0,
) # [B, seq_len, dim]
# Time embedding: sinusoidal from precomputed inv_freq.
# inv_freq was computed in float64 for precision, stored as float32.
# With integer timesteps (matching reference), float32 sin/cos is fine.
if t.ndim == 0:
t = t[None]
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate(
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim)
else:
# I2V: per-token timesteps [B, L]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
e0 = e0.reshape(batch_size, -1, 6, self.dim)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):
# Pre-embedded: expand to batch size if needed
context_batch = context
if context_batch.shape[0] == 1 and batch_size > 1:
context_batch = mx.broadcast_to(
context_batch, (batch_size,) + context_batch.shape[1:]
)
else:
context_batch = self.embed_text(context)
# Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None
w_dtype = _linear_dtype(self.patch_embedding_proj)
if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list):
attn_mask[i, :, :, sl:] = -1e9
kwargs = dict(
e=e0,
seq_lens=seq_lens_list,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context_batch,
context_lens=None,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
)
# Run transformer blocks
for i, block in enumerate(self.blocks):
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
x = block(x, cross_kv_cache=kv, **kwargs)
# Output head
x = self.head(x, e)
# Unpatchify
outputs = self.unpatchify(x, grid_sizes)
return [u.astype(mx.float32) for u in outputs]

View File

@@ -0,0 +1,35 @@
import numpy as np
from pathlib import Path
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""Save video frames to MP4.
Args:
frames: Video frames [T, H, W, 3] uint8
output_path: Output file path
fps: Frames per second
"""
try:
import imageio
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
for frame in frames:
writer.append_data(frame)
writer.close()
except ImportError:
try:
import cv2
h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"avc1")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
for frame in frames:
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
writer.release()
except (ImportError, Exception):
# Last resort: save as individual PNGs
from PIL import Image
out_dir = Path(output_path).parent / Path(output_path).stem
out_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames):
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")

View File

@@ -0,0 +1,178 @@
import math
import mlx.core as mx
import numpy as np
def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
"""Precompute RoPE frequency parameters as complex numbers.
Returns:
Complex frequency tensor of shape [max_seq_len, dim // 2].
"""
assert dim % 2 == 0
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
1.0
/ np.power(
theta,
np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
cos_freqs = np.cos(freqs).astype(np.float32)
sin_freqs = np.sin(freqs).astype(np.float32)
return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))
def rope_apply(
x: mx.array,
grid_sizes: list,
freqs: mx.array,
precomputed_cos_sin: tuple | None = None,
) -> mx.array:
"""Apply 3-way factorized RoPE to Q or K tensor.
Args:
x: Shape [B, L, num_heads, head_dim]
grid_sizes: List of (F, H, W) tuples per batch element
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
"""
b, s, n, d = x.shape
half_d = d // 2
if precomputed_cos_sin is not None:
cos_f, sin_f = precomputed_cos_sin
# Check if all batch elements have the same grid (common for CFG B=2)
f0, h0, w0 = grid_sizes[0]
seq_len = f0 * h0 * w0
all_same_grid = all(
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
) if b > 1 else True
if all_same_grid:
# Vectorized path: apply RoPE to all batch elements at once
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
x_real = x_seq[..., 0]
x_imag = x_seq[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
return x_rotated
else:
# Per-element path for mixed grid sizes
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
sl = f * h * w
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
if sl < s:
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)
# Cast freqs to input dtype to prevent float32 promotion cascade
if freqs.dtype != x.dtype:
freqs = freqs.astype(x.dtype)
# Split frequency dimensions: temporal gets more capacity
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
# Split freqs along dim axis
freqs_t = freqs[:, :d_t] # [1024, d_t, 2]
freqs_h = freqs[:, d_t : d_t + d_h] # [1024, d_h, 2]
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] # [1024, d_w, 2]
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
seq_len = f * h * w
# Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)
# Build per-position frequencies by expanding along grid dims
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
ft = mx.broadcast_to(
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
)
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
fh = mx.broadcast_to(
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
)
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
fw = mx.broadcast_to(
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
)
# Concatenate: [f*h*w, half_d, 2]
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
# Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
cos_f = freqs_i[..., 0] # [seq_len, 1, half_d]
sin_f = freqs_i[..., 1] # [seq_len, 1, half_d]
x_real = x_i[..., 0] # [seq_len, n, half_d]
x_imag = x_i[..., 1] # [seq_len, n, half_d]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
# Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)
# Handle padding: keep non-rotated tokens after seq_len
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)
def rope_precompute_cos_sin(
grid_sizes: list, freqs: mx.array, dtype: type = mx.float32
) -> tuple:
"""Precompute cos/sin frequency tensors for constant grid sizes.
Call once before the diffusion loop. Pass result as precomputed_cos_sin
to rope_apply to skip per-step broadcast/concat.
Args:
grid_sizes: List of (F, H, W) tuples (must be same for all batch elements)
freqs: Precomputed frequencies [1024, d//2, 2]
dtype: Target dtype for the output tensors
Returns:
(cos_f, sin_f) each [seq_len, 1, half_d]
"""
if freqs.dtype != dtype:
freqs = freqs.astype(dtype)
f, h, w = grid_sizes[0]
seq_len = f * h * w
half_d = freqs.shape[1]
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
freqs_t = freqs[:, :d_t]
freqs_h = freqs[:, d_t : d_t + d_h]
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
return freqs_i[..., 0], freqs_i[..., 1]

View File

@@ -0,0 +1,452 @@
"""Flow matching schedulers for Wan2.2 inference.
Provides Euler, DPM++2M, and UniPC solvers for flow matching diffusion.
Higher-order solvers (DPM++, UniPC) converge faster, needing fewer steps
for the same quality as Euler.
"""
import math
import numpy as np
import mlx.core as mx
def _compute_sigmas(
num_steps: int, shift: float = 1.0, num_train_timesteps: int = 1000
) -> np.ndarray:
"""Compute shifted sigma schedule matching official Wan2.2 scheduler.
The reference creates FlowUniPCMultistepScheduler with shift=1 (identity)
in the constructor, deriving sigma_max/sigma_min from the unshifted
training schedule. Then set_timesteps() builds a linspace between those
unshifted bounds and applies the actual shift once.
Returns num_steps+1 values (the last being 0.0 for the terminal state).
"""
# sigma bounds from unshifted training schedule (constructor uses shift=1)
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
::-1
]
sigmas_unshifted = 1.0 - alphas
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
sigma_min = float(sigmas_unshifted[-1]) # 0.0
# Interpolate, then apply shift once (matching set_timesteps)
sigmas = np.linspace(sigma_max, sigma_min, num_steps + 1)[:-1]
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
return np.append(sigmas, 0.0).astype(np.float32)
class FlowMatchEulerScheduler:
"""1st-order Euler scheduler for flow matching diffusion."""
def __init__(self, num_train_timesteps: int = 1000):
self.num_train_timesteps = num_train_timesteps
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
# Integer timesteps to match reference (model trained with int timesteps)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
# Store as Python floats to avoid .item() sync in step()
self._sigmas_float = sigmas.tolist()
self._step_index = 0
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
x_next = sample + dt * model_output
self._step_index += 1
return x_next
def reset(self):
self._step_index = 0
class FlowDPMPP2MScheduler:
"""DPM-Solver++(2M) for flow matching diffusion.
2nd-order multistep solver that reuses the previous step's model output
for a correction term. Falls back to 1st order on the first and
(optionally) last step. Reference: Wan2.2 fm_solvers.py.
"""
def __init__(
self,
num_train_timesteps: int = 1000,
lower_order_final: bool = True,
):
self.num_train_timesteps = num_train_timesteps
self.lower_order_final = lower_order_final
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
# Store sigmas as Python floats for scalar math
self._sigmas_float = sigmas.tolist()
self._step_index = 0
self._num_steps = num_steps
self._prev_x0 = None # previous x0 prediction for 2nd-order correction
@staticmethod
def _lambda(sigma: float) -> float:
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
matching torch.log behavior in the official code.
"""
if sigma >= 1.0:
return -math.inf
if sigma <= 0.0:
return math.inf
return math.log((1.0 - sigma) / sigma)
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""DPM++(2M) step for flow matching.
Converts velocity prediction to x0, then applies 1st or 2nd order
update depending on available history.
"""
i = self._step_index
s = self._sigmas_float
sigma_cur = s[i]
sigma_next = s[i + 1]
# Convert velocity -> x0 prediction: x0 = sample - sigma * v
x0 = sample - sigma_cur * model_output
# Decide order: 1st for first step, last step (if lower_order_final
# and few steps), otherwise 2nd
use_first_order = (
self._prev_x0 is None
or (
self.lower_order_final
and i == self._num_steps - 1
and self._num_steps < 15
)
)
if use_first_order or sigma_next == 0.0:
# 1st order DPM++ (equivalent to DDIM):
# x_next = (σ_next/σ_cur)*x - (α_next*(exp(-h)-1))*x0
if sigma_next == 0.0:
x_next = x0
else:
lambda_cur = self._lambda(sigma_cur)
lambda_next = self._lambda(sigma_next)
h = lambda_next - lambda_cur
alpha_next = 1.0 - sigma_next
coeff_x = sigma_next / sigma_cur
coeff_x0 = alpha_next * math.expm1(-h)
x_next = coeff_x * sample - coeff_x0 * x0
else:
# 2nd order DPM++(2M) with midpoint correction
sigma_prev = s[i - 1]
lambda_prev = self._lambda(sigma_prev)
lambda_cur = self._lambda(sigma_cur)
lambda_next = self._lambda(sigma_next)
h = lambda_next - lambda_cur
h_0 = lambda_cur - lambda_prev
r0 = h_0 / h
# D0 = current x0, D1 = correction from previous x0
D0 = x0
D1 = (1.0 / r0) * (x0 - self._prev_x0)
alpha_next = 1.0 - sigma_next
exp_neg_h_m1 = math.expm1(-h) # exp(-h) - 1
x_next = (
(sigma_next / sigma_cur) * sample
- (alpha_next * exp_neg_h_m1) * D0
- 0.5 * (alpha_next * exp_neg_h_m1) * D1
)
self._prev_x0 = x0
self._step_index += 1
return x_next
def reset(self):
self._step_index = 0
self._prev_x0 = None
class FlowUniPCScheduler:
"""UniPC (Unified Predictor-Corrector) for flow matching diffusion.
Multi-step predictor-corrector solver with configurable order.
The corrector refines each step using the model output that was already
computed, costing no extra model evaluations. Official Wan2.2 default.
Reference: Wan2.2 fm_solvers_unipc.py.
"""
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
lower_order_final: bool = True,
disable_corrector: list | None = None,
use_corrector: bool = True,
):
self.num_train_timesteps = num_train_timesteps
self.solver_order = solver_order
self.lower_order_final = lower_order_final
self._use_corrector = use_corrector
self.disable_corrector = set(disable_corrector or [])
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
sigmas = _compute_sigmas(num_steps, shift, self.num_train_timesteps)
self.sigmas = mx.array(sigmas)
self.timesteps = mx.array(
(sigmas[:-1] * self.num_train_timesteps).astype(np.int64).astype(np.float32)
)
self._sigmas_float = sigmas.tolist()
self._step_index = 0
self._num_steps = num_steps
self._lower_order_nums = 0
# Model output (x0) history for multi-step, stored newest-last
self._model_outputs = [None] * self.solver_order
self._last_sample = None # sample before prediction (for corrector)
self._this_order = 1
@staticmethod
def _lambda(sigma: float) -> float:
"""log-SNR: lambda(sigma) = log((1-sigma)/sigma).
Returns -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean),
matching torch.log behavior in the official code.
"""
if sigma >= 1.0:
return -math.inf
if sigma <= 0.0:
return math.inf
return math.log((1.0 - sigma) / sigma)
def _convert_output(self, velocity: mx.array, sample: mx.array) -> mx.array:
"""Convert velocity prediction to x0: x0 = sample - sigma * v."""
sigma = self._sigmas_float[self._step_index]
return sample - sigma * velocity
def _uni_p_bh2(self, x0: mx.array, sample: mx.array, order: int) -> mx.array:
"""UniP predictor with B(h)=expm1(-h) basis (bh2 variant).
Matches official multistep_uni_p_bh_update: computes rhos_p via
linalg.solve for order >= 3; order <= 2 uses analytic rhos_p=[0.5].
"""
i = self._step_index
s = self._sigmas_float
sigma_s0 = s[i]
sigma_t = s[i + 1]
if sigma_t == 0.0:
return x0
lambda_s0 = self._lambda(sigma_s0)
lambda_t = self._lambda(sigma_t)
h = lambda_t - lambda_s0
hh = -h # negated for predict_x0
alpha_t = 1.0 - sigma_t
h_phi_1 = math.expm1(hh)
B_h = h_phi_1
m0 = self._model_outputs[-1]
# Base prediction
x_t = (sigma_t / sigma_s0) * sample - (alpha_t * h_phi_1) * m0
if order >= 2 and m0 is not None:
rks = []
D1s = []
for k in range(1, order):
si_idx = i - k
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
break
mk = self._model_outputs[-(k + 1)]
sigma_sk = s[si_idx]
lambda_sk = self._lambda(sigma_sk)
rk = (lambda_sk - lambda_s0) / h
if math.isinf(rk):
break
rks.append(rk)
D1s.append((mk - m0) / rk)
if D1s:
effective_order = len(D1s) + 1
if effective_order <= 2:
# Analytic solution for order 2
rhos_p = [0.5]
else:
rks_arr = np.array(rks, dtype=np.float64)
h_phi_k = h_phi_1 / hh - 1.0
factorial_i = 1
R_rows = []
b_vals = []
for j in range(1, effective_order):
R_rows.append(rks_arr ** (j - 1))
b_vals.append(float(h_phi_k * factorial_i / B_h))
factorial_i *= j + 1
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
R = np.stack(R_rows)
b = np.array(b_vals)
rhos_p = np.linalg.solve(R, b).tolist()
pred_res = sum(r * d for r, d in zip(rhos_p, D1s))
x_t = x_t - (alpha_t * B_h) * pred_res
return x_t
def _uni_c_bh2(
self,
model_x0: mx.array,
last_sample: mx.array,
this_sample: mx.array,
order: int,
) -> mx.array:
"""UniC corrector with B(h)=expm1(-h) basis (bh2 variant).
Matches official multistep_uni_c_bh_update: computes rhos_c via
linalg.solve for order >= 2 (not hardcoded 0.5).
"""
i = self._step_index
s = self._sigmas_float
sigma_s0 = s[i - 1]
sigma_t = s[i]
if sigma_t == 0.0:
return this_sample
lambda_s0 = self._lambda(sigma_s0)
lambda_t = self._lambda(sigma_t)
h = lambda_t - lambda_s0
hh = -h # negated for predict_x0
alpha_t = 1.0 - sigma_t
h_phi_1 = math.expm1(hh)
B_h = h_phi_1
m0 = self._model_outputs[-1]
# Re-derive base from last_sample
x_t_ = (sigma_t / sigma_s0) * last_sample - (alpha_t * h_phi_1) * m0
D1_t = model_x0 - m0
# Gather rks and D1s from history
rks = []
D1s = []
for k in range(1, order):
si_idx = i - (k + 1)
if si_idx < 0 or self._model_outputs[-(k + 1)] is None:
break
mk = self._model_outputs[-(k + 1)]
sigma_sk = s[si_idx]
lambda_sk = self._lambda(sigma_sk)
rk = (lambda_sk - lambda_s0) / h
if math.isinf(rk):
break # History references sigma=1.0 boundary; reduce order
rks.append(rk)
D1s.append((mk - m0) / rk)
rks.append(1.0)
effective_order = len(rks) # = len(D1s) + 1
# Compute rhos_c coefficients
if effective_order == 1:
rhos_c = [0.5]
else:
rks_arr = np.array(rks, dtype=np.float64)
h_phi_k = h_phi_1 / hh - 1.0
factorial_i = 1
R_rows = []
b_vals = []
for j in range(1, effective_order + 1):
R_rows.append(rks_arr ** (j - 1))
b_vals.append(float(h_phi_k * factorial_i / B_h))
factorial_i *= j + 1
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
R = np.stack(R_rows)
b = np.array(b_vals)
rhos_c = np.linalg.solve(R, b).tolist()
# Apply correction
corr_res = mx.zeros_like(D1_t)
for k_idx, d1 in enumerate(D1s):
corr_res = corr_res + rhos_c[k_idx] * d1
x_t = x_t_ - (alpha_t * B_h) * (corr_res + rhos_c[-1] * D1_t)
return x_t
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""UniPC step: correct current, then predict next."""
i = self._step_index
# Convert velocity -> x0
x0 = self._convert_output(model_output, sample)
# 1. Corrector: refine current sample if we have history
use_corrector = (
self._use_corrector
and i > 0
and (i - 1) not in self.disable_corrector
and self._last_sample is not None
)
if use_corrector:
sample = self._uni_c_bh2(x0, self._last_sample, sample, self._this_order)
# 2. Shift model output history
for k in range(self.solver_order - 1):
self._model_outputs[k] = self._model_outputs[k + 1]
self._model_outputs[-1] = x0
# 3. Determine prediction order
if self.lower_order_final:
this_order = min(self.solver_order, self._num_steps - i)
else:
this_order = self.solver_order
self._this_order = min(this_order, self._lower_order_nums + 1)
# 4. Predict next sample
self._last_sample = sample
x_next = self._uni_p_bh2(x0, sample, self._this_order)
if self._lower_order_nums < self.solver_order:
self._lower_order_nums += 1
self._step_index += 1
return x_next
def reset(self):
self._step_index = 0
self._lower_order_nums = 0
self._model_outputs = [None] * self.solver_order
self._last_sample = None
self._this_order = 1

View File

@@ -0,0 +1,240 @@
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
import math
import mlx.core as mx
import mlx.nn as nn
class T5LayerNorm(nn.Module):
"""RMS-based layer normalization (T5 style)."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class T5RelativeEmbedding(nn.Module):
"""T5-style relative position bias with bucketing."""
def __init__(
self,
num_buckets: int,
num_heads: int,
bidirectional: bool = True,
max_dist: int = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
self.embedding = nn.Embedding(num_buckets, num_heads)
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array:
positions_k = mx.arange(lk)[None, :] # [1, lk]
positions_q = mx.arange(lq)[:, None] # [lq, 1]
rel_pos = positions_k - positions_q # [lq, lk]
buckets = self._relative_position_bucket(rel_pos)
embeds = self.embedding(buckets) # [lq, lk, num_heads]
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
return embeds
class T5Attention(nn.Module):
"""T5-style multi-head attention (no scaling)."""
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert dim_attn % num_heads == 0
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
def __call__(
self,
x: mx.array,
context: mx.array | None = None,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
context = x if context is None else context
b, n, c = x.shape[0], self.num_heads, self.head_dim
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
v = self.v(context).reshape(b, -1, n, c)
# T5 uses no scaling — compute attention manually with float32 softmax
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
# Using SDPA with bfloat16 inputs causes precision loss in softmax
# since unscaled logits can be very large (no 1/sqrt(d) division).
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# QK^T (no scaling) — compute in float32 for precision
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
# Add position bias
if pos_bias is not None:
attn = attn + pos_bias.astype(mx.float32)
# Apply attention mask (use dtype min like official, not -1e9)
if mask is not None:
if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
elif mask.ndim == 3:
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
attn = attn + additive_mask
# Softmax in float32 (matches official), then cast back
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
# Attention @ V
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
return self.o(out)
class T5FeedForward(nn.Module):
"""Gated feed-forward: gate(x) * fc1(x) -> fc2."""
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = nn.GELU(approx="tanh")
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
class T5SelfAttentionBlock(nn.Module):
"""T5 encoder block: self-attention + FFN."""
def __init__(
self,
dim: int,
dim_attn: int,
dim_ffn: int,
num_heads: int,
num_buckets: int,
shared_pos: bool = True,
):
super().__init__()
self.shared_pos = shared_pos
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = x + self.ffn(self.norm2(x))
return x
class T5Encoder(nn.Module):
"""T5 Encoder (UMT5-XXL configuration)."""
def __init__(
self,
vocab_size: int = 256384,
dim: int = 4096,
dim_attn: int = 4096,
dim_ffn: int = 10240,
num_heads: int = 64,
num_layers: int = 24,
num_buckets: int = 32,
shared_pos: bool = False,
):
super().__init__()
self.dim = dim
self.token_embedding = nn.Embedding(vocab_size, dim)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
if shared_pos
else None
)
self.blocks = [
T5SelfAttentionBlock(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
)
for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
"""
Args:
ids: Token IDs [B, L]
mask: Attention mask [B, L]
Returns:
Hidden states [B, L, dim]
"""
x = self.token_embedding(ids)
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
for block in self.blocks:
x = block(x, mask=mask, pos_bias=e)
x = self.norm(x)
return x

View File

@@ -0,0 +1,281 @@
"""Wan-specific tiled VAE decoding.
Re-exports all tiling utilities from the LTX VAE tiling module and provides
a Wan-specific ``decode_with_tiling`` that adds ``causal_temporal`` support
for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale
output frames rather than LTX's 1+(T-1)*scale mapping).
# TODO: This function can be refactored to consolidate with
# mlx_video.models.ltx.video_vae.tiling.decode_with_tiling once the
# causal_temporal generalisation is accepted upstream.
"""
from typing import Callable, Optional
import mlx.core as mx
from mlx_video.models.ltx.video_vae.tiling import (
SpatialTilingConfig,
TemporalTilingConfig,
TilingConfig,
map_spatial_slice,
map_temporal_slice,
split_in_spatial,
split_in_temporal,
)
__all__ = [
"SpatialTilingConfig",
"TemporalTilingConfig",
"TilingConfig",
"decode_with_tiling",
"map_spatial_slice",
"map_temporal_slice",
"split_in_spatial",
"split_in_temporal",
]
def decode_with_tiling(
decoder_fn,
latents: mx.array,
tiling_config: TilingConfig,
spatial_scale: int = 32,
temporal_scale: int = 8,
causal: bool = False,
causal_temporal: bool = True,
timestep: Optional[mx.array] = None,
chunked_conv: bool = False,
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
"""Decode latents using tiling to reduce memory usage.
Args:
decoder_fn: Decoder function to call for each tile.
latents: Input latents of shape (B, C, F, H, W).
tiling_config: Tiling configuration.
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
temporal_scale: Temporal scale factor (8 for LTX VAE).
causal: Whether to use causal convolutions.
causal_temporal: Whether the decoder uses causal temporal mapping where
T input frames produce 1+(T-1)*scale output frames. When False, uses
simple scaling where T frames produce T*scale output frames.
Default True (LTX behavior). Set False for non-causal decoders (e.g. Wan2.1).
timestep: Optional timestep for conditioning.
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
start_idx: Starting frame index in the full video.
Returns:
Decoded video.
"""
import gc
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
# Get tile size and overlap in latent space
if tiling_config.spatial_config is not None:
s_cfg = tiling_config.spatial_config
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
else:
spatial_tile_size = max(h_latent, w_latent)
spatial_overlap = 0
if tiling_config.temporal_config is not None:
t_cfg = tiling_config.temporal_config
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
else:
temporal_tile_size = f_latent
temporal_overlap = 0
# Compute intervals for each dimension
if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
num_t_tiles = len(temporal_intervals.starts)
num_h_tiles = len(height_intervals.starts)
num_w_tiles = len(width_intervals.starts)
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles # noqa: F841
# Initialize output and weight accumulator
# Use float32 for accumulation to avoid precision issues
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
mx.eval(output, weights)
tile_idx = 0
for t_idx in range(num_t_tiles):
t_start = temporal_intervals.starts[t_idx]
t_end = temporal_intervals.ends[t_idx]
t_left = temporal_intervals.left_ramps[t_idx]
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
h_end = height_intervals.ends[h_idx]
h_left = height_intervals.left_ramps[h_idx]
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
w_end = width_intervals.ends[w_idx]
w_left = width_intervals.left_ramps[w_idx]
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
mx.eval(tile_output)
# Clear tile_latents reference
del tile_latents
# Get actual decoded dimensions
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
expected_t = out_t_slice.stop - out_t_slice.start
expected_h = out_h_slice.stop - out_h_slice.start
expected_w = out_w_slice.stop - out_w_slice.start
# Handle potential size mismatches (use minimum)
actual_t = min(decoded_t, expected_t)
actual_h = min(decoded_h, expected_h)
actual_w = min(decoded_w, expected_w)
# Build blend mask
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
# Clear full tile_output
del tile_output
# Compute output coordinates
t_out_start = out_t_slice.start
t_out_end = t_out_start + actual_t
h_out_start = out_h_slice.start
h_out_end = h_out_start + actual_h
w_out_start = out_w_slice.start
w_out_end = w_out_start + actual_w
# Weighted accumulation
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
)
# Force evaluation to free memory
mx.eval(output, weights)
# Clean up tile-specific arrays
del tile_output_slice, weighted_tile, blend_mask
del t_mask_slice, h_mask_slice, w_mask_slice
tile_idx += 1
# Periodic garbage collection and cache clearing
if tile_idx % 4 == 0:
gc.collect()
try:
mx.clear_cache()
except Exception:
pass # May not be available on all platforms
# After completing all spatial tiles for this temporal tile,
# check if any frames are now finalized (no future tiles will contribute)
if on_frames_ready is not None and num_t_tiles > 1:
# Determine the finalized frame boundary
# Frames before the start of the next tile's output region are finalized
if t_idx < num_t_tiles - 1:
# Next tile starts at temporal_intervals.starts[t_idx + 1]
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
# Map to output frame index (first frame of next tile's contribution)
if next_tile_start_latent == 0:
next_tile_start_out = 0
elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
else:
next_tile_start_out = next_tile_start_latent * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
if next_tile_start_out > emitted:
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
on_frames_ready(finalized_output, emitted)
decode_with_tiling._emitted_frames = next_tile_start_out
del finalized_output, finalized_weights
gc.collect()
# Normalize by weights
weights = mx.maximum(weights, 1e-8)
output = output / weights
mx.eval(output)
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
on_frames_ready(remaining_output, emitted)
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
del decode_with_tiling._emitted_frames
# Clean up weights
del weights
gc.collect()
# Convert back to original dtype if needed
return output.astype(latents.dtype)

View File

@@ -0,0 +1,97 @@
import mlx.core as mx
import mlx.nn as nn
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention, _linear_dtype
class WanAttentionBlock(nn.Module):
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
cross_attn_norm: bool = False,
eps: float = 1e-6,
):
super().__init__()
# Self-attention
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
# Cross-attention (with optional norm on context)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else None
)
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
# Feed-forward
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
def __call__(
self,
x: mx.array,
e: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
context: mx.array,
context_lens: list | None = None,
cross_kv_cache: tuple | None = None,
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array:
# Modulation: compute in float32 for precision, matching the reference
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
# By keeping modulation in float32, type promotion ensures the residual
# stream stays float32 throughout all 30 layers (gate * output + x → float32).
mod = self.modulation + e # float32
e0, e1, e2, e3, e4, e5 = (
mod[:, :, 0, :], # shift for self-attn
mod[:, :, 1, :], # scale for self-attn
mod[:, :, 2, :], # gate for self-attn
mod[:, :, 3, :], # shift for ffn
mod[:, :, 4, :], # scale for ffn
mod[:, :, 5, :], # gate for ffn
)
# Self-attention with modulation (hidden state stays in w_dtype)
x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
x = x + y * e2
# Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation
x_mod = self.norm2(x) * (1 + e4) + e3
y = self.ffn(x_mod)
x = x + y * e5
return x
class WanFFN(nn.Module):
"""Gated feed-forward network with GELU(tanh) activation."""
def __init__(self, dim: int, ffn_dim: int):
super().__init__()
self.fc1 = nn.Linear(dim, ffn_dim)
self.act = nn.GELU(approx="tanh")
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(_linear_dtype(self.fc1))
return self.fc2(self.act(self.fc1(x_w)))

589
mlx_video/models/wan/vae.py Normal file
View File

@@ -0,0 +1,589 @@
"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8).
Module structure mirrors original PyTorch checkpoint key hierarchy
so weights load directly without key sanitization.
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization statistics for z_dim=16
VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
]
VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
]
class CausalConv3d(nn.Module):
"""3D convolution with causal temporal padding."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple,
stride: int | tuple = 1,
padding: int | tuple = 0,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
self.kernel_size = kernel_size
self.stride = stride
# Causal padding: match reference formula dilation*(k-1) + (1-stride)
# With dilation=1: k-stride (pads left only, no future context)
self._causal_pad_t = kernel_size[0] - stride[0]
self._pad_h = padding[1]
self._pad_w = padding[2]
# MLX Conv3d: weight shape [O, D, H, W, I]
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
"""x: [B, C, T, H, W] (channel-first)"""
b, c, t, h, w = x.shape
causal_pad = self._causal_pad_t
if cache_x is not None and causal_pad > 0:
x = mx.concatenate([cache_x, x], axis=2)
causal_pad = max(0, causal_pad - cache_x.shape[2])
if causal_pad > 0:
pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype)
x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
out = self._conv3d(x)
return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W']
def _conv3d(self, x: mx.array) -> mx.array:
"""3D conv via sliding window + 2D conv per time step.
x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
"""
b, t, h, w, c_in = x.shape
kt, kh, kw = self.kernel_size
st, sh, sw = self.stride
t_out = (t - kt) // st + 1
# Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
self.weight.shape[0], kh, kw, kt * c_in
)
outputs = []
for t_i in range(t_out):
t_start = t_i * st
window = x[:, t_start : t_start + kt]
window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
outputs.append(out_2d)
return mx.stack(outputs, axis=1)
class RMS_norm(nn.Module):
"""Channel-first L2 normalization matching original Wan VAE.
Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
super().__init__()
self.channel_first = channel_first
self.scale = dim**0.5
if channel_first:
broadcastable = (1, 1) if images else (1, 1, 1)
self.gamma = mx.ones((dim, *broadcastable))
else:
self.gamma = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
norm_dim = 1 if self.channel_first else -1
# L2 normalize along channel dim (matches F.normalize)
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
return (x / norm) * self.scale * self.gamma
class ResidualBlock(nn.Module):
"""Residual block with causal 3D convolutions.
Uses `residual` list with None gaps to match original PyTorch
nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
[4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
"""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.residual = [
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
h = x if self.shortcut is None else self.shortcut(x)
if feat_cache is not None:
# First conv: norm -> silu -> [cache] -> conv
x = nn.silu(self.residual[0](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.residual[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
# Second conv: norm -> silu -> [cache] -> conv
x = nn.silu(self.residual[3](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.residual[6](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = nn.silu(self.residual[0](x))
x = self.residual[2](x)
x = nn.silu(self.residual[3](x))
x = self.residual[6](x)
return x + h
class AttentionBlock(nn.Module):
"""Single-head spatial self-attention."""
def __init__(self, dim: int):
super().__init__()
self.norm = RMS_norm(dim, images=True)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W]"""
identity = x
b, c, t, h, w = x.shape
# [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.norm(x)
x = x.transpose(0, 2, 3, 1) # [BT, H, W, C]
qkv = self.to_qkv(x) # [BT, H, W, 3C]
qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q[:, None, :, :] # [BT, 1, HW, C]
k = k[:, None, :, :]
v = v[:, None, :, :]
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C]
out = self.proj(out) # [BT, H, W, C]
out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W]
return out + identity
class Resample(nn.Module):
"""Resample block matching original Wan VAE structure.
Supports both upsampling (decoder) and downsampling (encoder).
Uses list-based param storage to match original nn.Sequential key hierarchy.
"""
def __init__(self, dim: int, mode: str):
super().__init__()
assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d")
self.mode = mode
self.dim = dim
if mode.startswith("upsample"):
# resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
else:
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
if mode == "downsample3d":
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, C, T, H, W]"""
b, c, t, h, w = x.shape
if self.mode == "upsample3d":
# Temporal upsample via learned conv
x_t = self.time_conv(x) # [B, 2C, T, H, W]
x_t = x_t.reshape(b, 2, c, t, h, w)
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
t = t * 2
if self.mode.startswith("upsample"):
# Per-frame spatial upsample: nearest 2x + Conv2d
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.repeat(x, 2, axis=1)
x = mx.repeat(x, 2, axis=2)
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
c_out = x.shape[-1]
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
else:
# Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2)
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1)
x = self.resample[1](x) # Conv2d stride=2
c_out = x.shape[-1]
h_out, w_out = x.shape[1], x.shape[2]
x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
# First chunk: save x, skip time_conv
feat_cache[idx] = x
feat_idx[0] += 1
else:
# Subsequent chunks: use cached frame as temporal context
cache_x = x[:, :, -1:]
x = self.time_conv(
x, cache_x=feat_cache[idx][:, :, -1:])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.time_conv(x)
return x
class Decoder3d(nn.Module):
"""3D VAE Decoder matching Wan2.1 architecture.
Uses flat `middle` and `upsamples` lists to match original
PyTorch nn.Sequential weight key hierarchy.
"""
def __init__(
self,
dim: int = 96,
z_dim: int = 16,
dim_mult: list = None,
num_res_blocks: int = 2,
temporal_upsample: list = None,
):
super().__init__()
if dim_mult is None:
dim_mult = [1, 2, 4, 4]
if temporal_upsample is None:
temporal_upsample = [True, True, False]
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# Middle: [ResBlock, AttentionBlock, ResBlock]
self.middle = [
ResidualBlock(dims[0], dims[0]),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0]),
]
# Flat upsample list matching original nn.Sequential indexing
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
if i in (1, 2, 3):
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = upsamples
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
]
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
x = self.conv1(x)
for layer in self.middle:
x = layer(x)
for layer in self.upsamples:
x = layer(x)
x = nn.silu(self.head[0](x))
x = self.head[2](x)
return x
class Encoder3d(nn.Module):
"""3D VAE Encoder matching Wan2.1 architecture.
Mirror of Decoder3d with downsampling instead of upsampling.
Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy.
"""
def __init__(
self,
dim: int = 96,
z_dim: int = 16,
dim_mult: list = None,
num_res_blocks: int = 2,
temporal_downsample: list = None,
):
super().__init__()
if dim_mult is None:
dim_mult = [1, 2, 4, 4]
if temporal_downsample is None:
temporal_downsample = [False, True, True]
dims = [dim * u for u in [1] + dim_mult]
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# Flat downsample list matching original nn.Sequential indexing
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = downsamples
# Middle: [ResBlock, AttentionBlock, ResBlock]
self.middle = [
ResidualBlock(dims[-1], dims[-1]),
AttentionBlock(dims[-1]),
ResidualBlock(dims[-1], dims[-1]),
]
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False),
None, # SiLU
CausalConv3d(dims[-1], z_dim, 3, padding=1),
]
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]"""
if feat_cache is not None:
# conv1 with caching
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.downsamples:
if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
for layer in self.middle:
if feat_cache is not None and isinstance(layer, ResidualBlock):
x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
if feat_cache is not None:
# Head: norm -> silu -> [cache] -> conv
x = nn.silu(self.head[0](x))
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.head[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = nn.silu(self.head[0](x))
x = self.head[2](x)
return x
class WanVAE(nn.Module):
"""Wan2.1 VAE wrapper with per-channel normalization.
Supports both encode (for I2V) and decode (for all models).
"""
def __init__(self, z_dim: int = 16, encoder: bool = False):
super().__init__()
self.z_dim = z_dim
self.mean = mx.array(VAE_MEAN)
self.std = mx.array(VAE_STD)
self.inv_std = 1.0 / self.std
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
if encoder:
self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
def encode(self, x: mx.array) -> mx.array:
"""Encode video to normalized latent using chunked encoding.
Uses chunked encoding with temporal caching to match reference behavior.
First frame encoded alone, then 4-frame chunks with cached context.
Args:
x: Video [B, 3, T, H, W] in [-1, 1]
Returns:
Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
"""
# Count cacheable CausalConv3d slots in encoder
num_slots = self._count_encoder_cache_slots()
feat_cache = [None] * num_slots
t = x.shape[2]
num_chunks = 1 + (t - 1) // 4
out = None
for i in range(num_chunks):
feat_idx = [0]
if i == 0:
chunk = x[:, :, :1]
else:
chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
else:
out = mx.concatenate([out, chunk_out], axis=2)
mu, _ = mx.split(self.conv1(out), 2, axis=1)
# Normalize: (mu - mean) * inv_std
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
return (mu - mean) * inv_std
def _count_encoder_cache_slots(self) -> int:
"""Count CausalConv3d that participate in chunked encoding cache."""
count = 1 # encoder.conv1
for layer in self.encoder.downsamples:
if isinstance(layer, ResidualBlock):
count += 2 # two convs in residual path
elif isinstance(layer, Resample) and layer.mode == "downsample3d":
count += 1 # time_conv
for layer in self.encoder.middle:
if isinstance(layer, ResidualBlock):
count += 2
count += 1 # encoder.head CausalConv3d
return count
def decode(self, z: mx.array) -> mx.array:
"""Decode latent to video.
Args:
z: Normalized latent [B, z_dim, T, H, W]
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
z = z / inv_std + mean
x = self.conv2(z)
out = self.decoder(x)
return mx.clip(out, -1, 1)
def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array:
"""Decode latent to video using tiling to reduce memory usage.
Splits the latent tensor into overlapping spatial/temporal tiles,
decodes each tile independently, and blends them with trapezoidal
masks. Reuses the LTX-2 tiling infrastructure.
Args:
z: Normalized latent [B, z_dim, T, H, W]
tiling_config: Optional TilingConfig. If None, uses default.
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
# Check if tiling is actually needed
_, _, f, h, w = z.shape
needs_tiling = False
if tiling_config.spatial_config is not None:
s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8
if h > s_tile or w > s_tile:
needs_tiling = True
if tiling_config.temporal_config is not None:
t_tile = tiling_config.temporal_config.tile_size_in_frames // 4
if f > t_tile:
needs_tiling = True
if not needs_tiling:
return self.decode(z)
# Denormalize once (small tensor), then tile the denormalized latents
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
z_denorm = z / inv_std + mean
def tile_decode(tile_latents, **kwargs):
x = self.conv2(tile_latents)
out = self.decoder(x)
return mx.clip(out, -1, 1)
return decode_with_tiling(
decoder_fn=tile_decode,
latents=z_denorm,
tiling_config=tiling_config,
spatial_scale=8, # 3× spatial 2× upsamples = 8×
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
)

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,10 @@ dependencies = [
"tqdm", "tqdm",
"opencv-python>=4.12.0.88", "opencv-python>=4.12.0.88",
"Pillow>=10.3.0", "Pillow>=10.3.0",
"mlx-vlm" "mlx-vlm",
"imageio>=2.37.2",
"imageio-ffmpeg>=0.6.0",
"ftfy",
] ]
license = {text="MIT"} license = {text="MIT"}
authors = [ authors = [
@@ -42,6 +45,7 @@ Issues = "https://github.com/Blaizzy/mlx-video/issues"
[project.scripts] [project.scripts]
"mlx_video.generate" = "mlx_video.generate:main" "mlx_video.generate" = "mlx_video.generate:main"
"mlx_video.generate_wan" = "mlx_video.generate_wan:main"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
include = ["mlx_video*"] include = ["mlx_video*"]

View File

@@ -0,0 +1,306 @@
#!/usr/bin/env python3
"""Compare two videos frame-by-frame with quality metrics.
Useful for validating MLX ports against reference PyTorch implementations.
Reports PSNR, SSIM, per-frame differences, temporal coherence, and color
fidelity. Optionally saves a side-by-side diff video.
Usage:
# Basic comparison
python scripts/video/compare_videos.py reference.mp4 test.mp4
# Save side-by-side diff visualization
python scripts/video/compare_videos.py ref.mp4 test.mp4 --diff-video diff.mp4
# Compare only first 64 frames
python scripts/video/compare_videos.py ref.mp4 test.mp4 --max-frames 64
# Adjust SSIM window size (default: 7)
python scripts/video/compare_videos.py ref.mp4 test.mp4 --ssim-win 11
"""
import argparse
import sys
import cv2
import numpy as np
def load_video(path, max_frames=None):
"""Load video frames as float32 numpy arrays (0-255 range)."""
cap = cv2.VideoCapture(path)
if not cap.isOpened():
print(f"Error: cannot open {path}")
sys.exit(1)
fps = cap.get(cv2.CAP_PROP_FPS)
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(frame.astype(np.float32))
if max_frames and len(frames) >= max_frames:
break
cap.release()
return frames, fps
def compute_psnr(a, b):
"""Peak Signal-to-Noise Ratio between two frames."""
mse = np.mean((a - b) ** 2)
if mse == 0:
return float("inf")
return 10 * np.log10(255.0**2 / mse)
def compute_ssim(a, b, win_size=7):
"""Structural Similarity Index (per-channel, averaged).
Uses the standard SSIM formula with default constants.
"""
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
kernel = cv2.getGaussianKernel(win_size, 1.5)
window = kernel @ kernel.T
ssim_channels = []
for c in range(a.shape[2]):
ac, bc = a[:, :, c], b[:, :, c]
mu_a = cv2.filter2D(ac, -1, window)
mu_b = cv2.filter2D(bc, -1, window)
mu_a_sq = mu_a**2
mu_b_sq = mu_b**2
mu_ab = mu_a * mu_b
sigma_a_sq = cv2.filter2D(ac**2, -1, window) - mu_a_sq
sigma_b_sq = cv2.filter2D(bc**2, -1, window) - mu_b_sq
sigma_ab = cv2.filter2D(ac * bc, -1, window) - mu_ab
num = (2 * mu_ab + C1) * (2 * sigma_ab + C2)
den = (mu_a_sq + mu_b_sq + C1) * (sigma_a_sq + sigma_b_sq + C2)
ssim_map = num / den
ssim_channels.append(np.mean(ssim_map))
return np.mean(ssim_channels)
def temporal_coherence(frames):
"""Mean frame-to-frame difference (lower = smoother)."""
if len(frames) < 2:
return 0.0
diffs = []
for i in range(1, len(frames)):
diffs.append(np.mean(np.abs(frames[i] - frames[i - 1])))
return np.mean(diffs)
def color_histogram_distance(a, b, bins=64):
"""Chi-squared distance between color histograms."""
dist = 0.0
for c in range(3):
ha, _ = np.histogram(a[:, :, c], bins=bins, range=(0, 256))
hb, _ = np.histogram(b[:, :, c], bins=bins, range=(0, 256))
ha = ha.astype(np.float64) / (ha.sum() + 1e-10)
hb = hb.astype(np.float64) / (hb.sum() + 1e-10)
dist += np.sum((ha - hb) ** 2 / (ha + hb + 1e-10)) / 2
return dist / 3
def make_diff_frame(a, b, scale=5.0):
"""Create a heatmap visualization of the absolute difference."""
diff = np.mean(np.abs(a - b), axis=2)
diff_scaled = np.clip(diff * scale, 0, 255).astype(np.uint8)
heatmap = cv2.applyColorMap(diff_scaled, cv2.COLORMAP_JET)
return heatmap
def analyze(ref_frames, test_frames, ssim_win=7):
"""Compute per-frame and aggregate metrics."""
n = min(len(ref_frames), len(test_frames))
psnrs = []
ssims = []
mean_diffs = []
max_diffs = []
color_dists = []
for i in range(n):
r, t = ref_frames[i], test_frames[i]
psnrs.append(compute_psnr(r, t))
ssims.append(compute_ssim(r, t, ssim_win))
absdiff = np.abs(r - t)
mean_diffs.append(np.mean(absdiff))
max_diffs.append(np.max(absdiff))
color_dists.append(color_histogram_distance(r, t))
ref_tc = temporal_coherence(ref_frames[:n])
test_tc = temporal_coherence(test_frames[:n])
return {
"num_frames": n,
"psnr": np.array(psnrs),
"ssim": np.array(ssims),
"mean_diff": np.array(mean_diffs),
"max_diff": np.array(max_diffs),
"color_dist": np.array(color_dists),
"ref_temporal_coherence": ref_tc,
"test_temporal_coherence": test_tc,
}
def print_report(results, ref_path, test_path):
"""Print a formatted comparison report."""
n = results["num_frames"]
psnr = results["psnr"]
ssim = results["ssim"]
md = results["mean_diff"]
mx = results["max_diff"]
cd = results["color_dist"]
print("=" * 72)
print("VIDEO COMPARISON REPORT")
print("=" * 72)
print(f" Reference: {ref_path}")
print(f" Test: {test_path}")
print(f" Frames compared: {n}")
print()
print("AGGREGATE METRICS")
print("-" * 40)
print(f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}")
print(f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}")
print(f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}")
print(f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}")
print(f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}")
print()
print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)")
print("-" * 40)
print(f" Reference: {results['ref_temporal_coherence']:.2f}")
print(f" Test: {results['test_temporal_coherence']:.2f}")
ratio = results["test_temporal_coherence"] / (results["ref_temporal_coherence"] + 1e-10)
print(f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}")
print()
# Identify worst frames
print("WORST FRAMES (by PSNR)")
print("-" * 40)
worst_idx = np.argsort(psnr)[:5]
for i in worst_idx:
print(f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}")
print()
# Quality assessment
mean_psnr = np.mean(psnr)
mean_ssim = np.mean(ssim)
print("QUALITY ASSESSMENT")
print("-" * 40)
if mean_psnr > 40:
grade = "Excellent"
elif mean_psnr > 35:
grade = "Good"
elif mean_psnr > 30:
grade = "Fair"
elif mean_psnr > 25:
grade = "Poor"
else:
grade = "Very different"
print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})")
if mean_psnr < 30:
print(" ⚠ Videos differ significantly — likely a bug or different generation seed")
print("=" * 72)
def save_diff_video(ref_frames, test_frames, output_path, fps, scale=5.0):
"""Save a side-by-side video: reference | test | diff heatmap."""
n = min(len(ref_frames), len(test_frames))
h, w = ref_frames[0].shape[:2]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(output_path, fourcc, fps, (w * 3, h))
for i in range(n):
r = ref_frames[i].astype(np.uint8)
t = test_frames[i].astype(np.uint8)
d = make_diff_frame(ref_frames[i], test_frames[i], scale)
combined = np.hstack([r, t, d])
out.write(combined)
out.release()
print(f"Diff video saved to {output_path}")
def main():
parser = argparse.ArgumentParser(
description="Compare two videos frame-by-frame with quality metrics"
)
parser.add_argument("reference", help="Path to reference video")
parser.add_argument("test", help="Path to test video")
parser.add_argument(
"--diff-video", help="Save side-by-side diff visualization to this path"
)
parser.add_argument(
"--max-frames", type=int, help="Compare only first N frames"
)
parser.add_argument(
"--ssim-win", type=int, default=7, help="SSIM window size (default: 7)"
)
parser.add_argument(
"--diff-scale",
type=float,
default=5.0,
help="Diff heatmap amplification (default: 5.0)",
)
parser.add_argument(
"--csv", help="Export per-frame metrics to CSV file"
)
args = parser.parse_args()
print(f"Loading reference: {args.reference}")
ref_frames, ref_fps = load_video(args.reference, args.max_frames)
print(f"{len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}")
print(f"Loading test: {args.test}")
test_frames, test_fps = load_video(args.test, args.max_frames)
print(f"{len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}")
if ref_frames[0].shape != test_frames[0].shape:
print(f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}")
print("Resizing test frames to match reference...")
h, w = ref_frames[0].shape[:2]
test_frames = [
cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4)
for f in test_frames
]
print("Computing metrics...")
results = analyze(ref_frames, test_frames, args.ssim_win)
print()
print_report(results, args.reference, args.test)
if args.diff_video:
save_diff_video(ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale)
if args.csv:
import csv
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"])
for i in range(results["num_frames"]):
writer.writerow([
i,
f"{results['psnr'][i]:.4f}",
f"{results['ssim'][i]:.6f}",
f"{results['mean_diff'][i]:.4f}",
f"{results['max_diff'][i]:.1f}",
f"{results['color_dist'][i]:.6f}",
])
print(f"Per-frame metrics saved to {args.csv}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
"""Analyze quality of a single generated video.
Reports sharpness, temporal stability, color distribution, motion smoothness,
chunk boundary artifacts, and common generation defects. Useful for quick
quality checks during model porting and debugging.
Usage:
# Basic analysis
python scripts/video/video_quality.py output.mp4
# With chunk boundary analysis (e.g., 32 frames/chunk)
python scripts/video/video_quality.py output.mp4 --chunk-size 32
# Detailed per-frame CSV export
python scripts/video/video_quality.py output.mp4 --csv metrics.csv
# Analyze specific frame range
python scripts/video/video_quality.py output.mp4 --start 0 --end 64
"""
import argparse
import sys
import cv2
import numpy as np
def load_video(path, start=0, end=None):
"""Load video frames as float32 numpy arrays (0-255 range)."""
cap = cv2.VideoCapture(path)
if not cap.isOpened():
print(f"Error: cannot open {path}")
sys.exit(1)
fps = cap.get(cv2.CAP_PROP_FPS)
total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if start > 0:
cap.set(cv2.CAP_PROP_POS_FRAMES, start)
frames = []
idx = start
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(frame.astype(np.float32))
idx += 1
if end and idx >= end:
break
cap.release()
return frames, fps, total
def sharpness_laplacian(frame):
"""Laplacian variance — higher means sharper."""
gray = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_BGR2GRAY)
return cv2.Laplacian(gray, cv2.CV_64F).var()
def sharpness_gradient(frame):
"""Mean gradient magnitude — higher means more edges/detail."""
gray = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float32)
gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
return np.mean(np.sqrt(gx**2 + gy**2))
def color_stats(frame):
"""Per-channel mean and std in BGR order."""
means = [np.mean(frame[:, :, c]) for c in range(3)]
stds = [np.std(frame[:, :, c]) for c in range(3)]
return means, stds
def detect_uniform_color(frame, std_threshold=15.0):
"""Detect if frame is near-uniform (common failure mode)."""
return np.std(frame) < std_threshold
def detect_noise(frame, threshold=200.0):
"""High Laplacian variance with low gradient can indicate noise."""
lap = sharpness_laplacian(frame)
grad = sharpness_gradient(frame)
# Noise has high variance but less coherent edges
return lap > threshold and grad < 5.0
def frame_difference(a, b):
"""Mean absolute pixel difference between frames."""
return np.mean(np.abs(a - b))
def optical_flow_magnitude(prev, curr):
"""Mean optical flow magnitude (Farneback method)."""
prev_gray = cv2.cvtColor(prev.astype(np.uint8), cv2.COLOR_BGR2GRAY)
curr_gray = cv2.cvtColor(curr.astype(np.uint8), cv2.COLOR_BGR2GRAY)
flow = cv2.calcOpticalFlowFarneback(
prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0
)
mag = np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2)
return np.mean(mag), np.max(mag)
def analyze_video(frames, chunk_size=None, compute_flow=False):
"""Compute per-frame and aggregate quality metrics."""
n = len(frames)
metrics = {
"sharpness_lap": [],
"sharpness_grad": [],
"brightness": [],
"contrast": [],
"color_mean_b": [],
"color_mean_g": [],
"color_mean_r": [],
"frame_diff": [],
"is_uniform": [],
"is_noisy": [],
}
if compute_flow:
metrics["flow_mean"] = []
metrics["flow_max"] = []
for i in range(n):
f = frames[i]
metrics["sharpness_lap"].append(sharpness_laplacian(f))
metrics["sharpness_grad"].append(sharpness_gradient(f))
metrics["brightness"].append(np.mean(f))
metrics["contrast"].append(np.std(f))
means, _ = color_stats(f)
metrics["color_mean_b"].append(means[0])
metrics["color_mean_g"].append(means[1])
metrics["color_mean_r"].append(means[2])
metrics["is_uniform"].append(detect_uniform_color(f))
metrics["is_noisy"].append(detect_noise(f))
if i > 0:
metrics["frame_diff"].append(frame_difference(frames[i - 1], f))
if compute_flow:
fm, fmx = optical_flow_magnitude(frames[i - 1], f)
metrics["flow_mean"].append(fm)
metrics["flow_max"].append(fmx)
else:
metrics["frame_diff"].append(0.0)
if compute_flow:
metrics["flow_mean"].append(0.0)
metrics["flow_max"].append(0.0)
# Convert to arrays
for k in metrics:
metrics[k] = np.array(metrics[k])
# Chunk boundary analysis
if chunk_size and n > chunk_size:
boundaries = list(range(chunk_size, n, chunk_size))
boundary_metrics = []
for b in boundaries:
if b < n and b > 0:
pre = metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1]
at = metrics["frame_diff"][b]
ratio = at / (pre + 1e-10)
brightness_jump = metrics["brightness"][b] - metrics["brightness"][b - 1]
contrast_jump = (
(metrics["contrast"][b] - metrics["contrast"][b - 1])
/ (metrics["contrast"][b - 1] + 1e-10)
* 100
)
sharpness_jump = (
(metrics["sharpness_lap"][b] - metrics["sharpness_lap"][b - 1])
/ (metrics["sharpness_lap"][b - 1] + 1e-10)
* 100
)
boundary_metrics.append(
{
"frame": b,
"diff_ratio": ratio,
"brightness_jump": brightness_jump,
"contrast_jump_pct": contrast_jump,
"sharpness_jump_pct": sharpness_jump,
}
)
metrics["boundaries"] = boundary_metrics
return metrics
def print_report(metrics, path, fps, total_frames, frames_analyzed):
"""Print a formatted quality report."""
sl = metrics["sharpness_lap"]
sg = metrics["sharpness_grad"]
br = metrics["brightness"]
ct = metrics["contrast"]
fd = metrics["frame_diff"]
print("=" * 72)
print("VIDEO QUALITY REPORT")
print("=" * 72)
print(f" File: {path}")
print(f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}")
duration = total_frames / fps if fps > 0 else 0
print(f" Duration: {duration:.1f}s")
print()
# Defect detection
n_uniform = int(np.sum(metrics["is_uniform"]))
n_noisy = int(np.sum(metrics["is_noisy"]))
if n_uniform > 0 or n_noisy > 0:
print("⚠ DEFECTS DETECTED")
print("-" * 40)
if n_uniform:
frames_list = np.where(metrics["is_uniform"])[0][:10]
print(f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}")
if n_noisy:
frames_list = np.where(metrics["is_noisy"])[0][:10]
print(f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}")
print()
print("SHARPNESS")
print("-" * 40)
print(f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}")
print(f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}")
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
print(" ⚠ High sharpness variation — possible blur artifacts")
print()
print("BRIGHTNESS & CONTRAST")
print("-" * 40)
print(f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}")
print(f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}")
if np.std(br) > 3.0:
print(" ⚠ Brightness instability — may indicate chunk boundary artifacts")
print()
print("COLOR DISTRIBUTION (BGR)")
print("-" * 40)
print(f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}")
print(f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}")
print(f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}")
print()
print("TEMPORAL STABILITY")
print("-" * 40)
fd_nz = fd[1:] # skip first frame (always 0)
if len(fd_nz) > 0:
print(f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}")
if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5:
print(" ⚠ High diff variance — jitter or discontinuities")
if "flow_mean" in metrics:
fm = metrics["flow_mean"][1:]
print(f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}")
print()
# Chunk boundaries
if "boundaries" in metrics and metrics["boundaries"]:
print("CHUNK BOUNDARIES")
print("-" * 40)
print(f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}")
for bm in metrics["boundaries"]:
print(
f" {bm['frame']:6d}"
f" {bm['diff_ratio']:10.2f}x"
f" {bm['brightness_jump']:+10.1f}"
f" {bm['contrast_jump_pct']:+10.1f}%"
f" {bm['sharpness_jump_pct']:+11.1f}%"
)
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
if avg_ratio > 2.0:
print(f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions")
print()
# Overall grade
print("OVERALL ASSESSMENT")
print("-" * 40)
issues = []
if n_uniform > 0:
issues.append("uniform/blank frames")
if n_noisy > 0:
issues.append("noisy frames")
if np.std(br) > 3.0:
issues.append("brightness flicker")
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
issues.append("sharpness variation")
if "boundaries" in metrics and metrics["boundaries"]:
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
if avg_ratio > 2.0:
issues.append("chunk boundary artifacts")
if issues:
print(f" Issues found: {', '.join(issues)}")
else:
print(" ✓ No significant quality issues detected")
print("=" * 72)
def main():
parser = argparse.ArgumentParser(
description="Analyze quality of a single generated video"
)
parser.add_argument("video", help="Path to video file")
parser.add_argument(
"--chunk-size",
type=int,
help="Frames per chunk for boundary analysis (e.g., 32)",
)
parser.add_argument(
"--start", type=int, default=0, help="Start frame (default: 0)"
)
parser.add_argument("--end", type=int, help="End frame (default: all)")
parser.add_argument(
"--flow",
action="store_true",
help="Compute optical flow (slower but more detailed)",
)
parser.add_argument("--csv", help="Export per-frame metrics to CSV")
args = parser.parse_args()
print(f"Loading: {args.video}")
frames, fps, total = load_video(args.video, args.start, args.end)
h, w = frames[0].shape[:2]
print(f"{len(frames)} frames, {fps:.1f} fps, {w}x{h}")
print("Analyzing...")
metrics = analyze_video(frames, args.chunk_size, args.flow)
print()
print_report(metrics, args.video, fps, total, len(frames))
if args.csv:
import csv
keys = [
"sharpness_lap", "sharpness_grad", "brightness", "contrast",
"color_mean_b", "color_mean_g", "color_mean_r", "frame_diff",
]
if args.flow:
keys += ["flow_mean", "flow_max"]
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["frame"] + keys)
for i in range(len(frames)):
row = [i] + [f"{metrics[k][i]:.4f}" for k in keys]
writer.writerow(row)
print(f"Per-frame metrics saved to {args.csv}")
if __name__ == "__main__":
main()

4
tests/conftest.py Normal file
View File

@@ -0,0 +1,4 @@
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))

372
tests/test_wan_attention.py Normal file
View File

@@ -0,0 +1,372 @@
"""Tests for Wan attention components and RoPE."""
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# RoPE Tests
# ---------------------------------------------------------------------------
class TestRoPE:
"""Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self):
from mlx_video.models.wan.rope import rope_params
freqs = rope_params(1024, 64)
mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self):
from mlx_video.models.wan.rope import rope_params
for dim in [32, 64, 128]:
freqs = rope_params(512, dim)
mx.eval(freqs)
assert freqs.shape == (512, dim // 2, 2)
def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan.rope import rope_params
freqs = rope_params(256, 64)
mx.eval(freqs)
cos_vals = np.array(freqs[:, :, 0])
sin_vals = np.array(freqs[:, :, 1])
assert np.all(cos_vals >= -1.0) and np.all(cos_vals <= 1.0)
assert np.all(sin_vals >= -1.0) and np.all(sin_vals <= 1.0)
def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan.rope import rope_params
freqs = rope_params(10, 64)
mx.eval(freqs)
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
def test_rope_apply_output_shape(self):
from mlx_video.models.wan.rope import rope_params, rope_apply
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D))
freqs = rope_params(1024, D)
grid_sizes = [(2, 3, 4)] # F*H*W = 24 = L
out = rope_apply(x, grid_sizes, freqs)
mx.eval(out)
assert out.shape == (B, L, N, D)
def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms."""
from mlx_video.models.wan.rope import rope_params, rope_apply
B, N, D = 1, 2, 16
F, H, W = 2, 3, 4
L = F * H * W
x = mx.random.normal((B, L, N, D))
freqs = rope_params(1024, D)
out = rope_apply(x, [(F, H, W)], freqs)
mx.eval(x, out)
x_np = np.array(x[0])
out_np = np.array(out[0])
for i in range(L):
for h in range(N):
norm_in = np.linalg.norm(x_np[i, h])
norm_out = np.linalg.norm(out_np[i, h])
np.testing.assert_allclose(norm_in, norm_out, rtol=1e-4)
def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan.rope import rope_params, rope_apply
B, N, D = 1, 2, 16
F, H, W = 2, 2, 2
seq_len = F * H * W # 8
pad = 4
L = seq_len + pad
x = mx.random.normal((B, L, N, D))
freqs = rope_params(1024, D)
out = rope_apply(x, [(F, H, W)], freqs)
mx.eval(x, out)
# Padded tokens should be unchanged
np.testing.assert_allclose(
np.array(out[0, seq_len:]),
np.array(x[0, seq_len:]),
atol=1e-6,
)
def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan.rope import rope_params, rope_apply
B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)]
L = 2 * 3 * 4
x = mx.random.normal((B, L, N, D))
freqs = rope_params(1024, D)
out = rope_apply(x, grids, freqs)
mx.eval(out)
assert out.shape == (B, L, N, D)
def test_rope_frequency_split(self):
"""Verify the 3-way frequency dimension split matches Wan2.2 convention."""
D = 128 # head_dim for 14B model
half_d = D // 2
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
assert d_t + d_h + d_w == half_d
# Temporal gets more capacity
assert d_t >= d_h
assert d_t >= d_w
# ---------------------------------------------------------------------------
# Attention Tests
# ---------------------------------------------------------------------------
class TestWanRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
mx.eval(out)
assert out.shape == (2, 10, 64)
def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0
out = norm(x)
mx.eval(out)
out_np = np.array(out[0])
for i in range(5):
rms = np.sqrt(np.mean(out_np[i] ** 2))
# After RMS norm with weight=1, RMS should be ~1
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan.attention import WanRMSNorm
norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
out = norm(x)
mx.eval(out)
# Weight is float32, so multiplication promotes result to float32
assert out.dtype == mx.float32
class TestWanLayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
mx.eval(out)
assert out.shape == (2, 10, 64)
def test_without_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64))
out = norm(x)
mx.eval(out)
# Mean should be ~0, variance should be ~1
out_np = np.array(out[0])
for i in range(4):
np.testing.assert_allclose(np.mean(out_np[i]), 0.0, atol=0.05)
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
def test_with_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight")
assert hasattr(norm, "bias")
x = mx.random.normal((1, 4, 32))
out = norm(x)
mx.eval(out)
assert out.shape == (1, 4, 32)
class TestWanSelfAttention:
def setup_method(self):
mx.random.seed(42)
self.dim = 64
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24
F, H, W = 2, 3, 4
x = mx.random.normal((B, L, self.dim))
freqs = rope_params(1024, self.dim // self.num_heads)
out = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
mx.eval(out)
assert out.shape == (B, L, self.dim)
def test_with_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None
assert attn.norm_k is not None
def test_without_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None
assert attn.norm_k is None
def test_masking(self):
"""Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24
F, H, W = 2, 3, 4
x = mx.random.normal((B, L, self.dim))
freqs = rope_params(1024, self.dim // self.num_heads)
# Full sequence
out_full = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
# Shorter sequence (mask last 4 tokens)
out_masked = attn(x, seq_lens=[L - 4], grid_sizes=[(F, H, W)], freqs=freqs)
mx.eval(out_full, out_masked)
# Outputs should differ when masking is applied
assert not np.allclose(np.array(out_full), np.array(out_masked), atol=1e-5)
class TestWanCrossAttention:
def setup_method(self):
mx.random.seed(42)
self.dim = 64
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16
x = mx.random.normal((B, L_q, self.dim))
context = mx.random.normal((B, L_kv, self.dim))
out = attn(x, context)
mx.eval(out)
assert out.shape == (B, L_q, self.dim)
def test_with_context_mask(self):
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16
x = mx.random.normal((B, L_q, self.dim))
context = mx.random.normal((B, L_kv, self.dim))
out = attn(x, context, context_lens=[10])
mx.eval(out)
assert out.shape == (B, L_q, self.dim)
# ---------------------------------------------------------------------------
# bfloat16 Autocast Tests
# ---------------------------------------------------------------------------
class TestBFloat16Autocast:
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
for efficient matmul, matching official PyTorch autocast behavior."""
def setup_method(self):
mx.random.seed(42)
self.dim = 64
self.num_heads = 4
@staticmethod
def _to_bf16(params):
"""Recursively cast all arrays in params to bfloat16."""
if isinstance(params, dict):
return {k: TestBFloat16Autocast._to_bf16(v) for k, v in params.items()}
elif isinstance(params, list):
return [TestBFloat16Autocast._to_bf16(v) for v in params]
elif isinstance(params, mx.array):
return params.astype(mx.bfloat16)
return params
def test_self_attn_casts_to_weight_dtype(self):
"""Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
x = mx.random.normal((1, 8, self.dim))
freqs = rope_params(1024, self.dim // self.num_heads)
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
mx.eval(out)
assert out.shape == (1, 8, self.dim)
assert np.isfinite(np.array(out.astype(mx.float32))).all()
def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
x = mx.random.normal((1, 8, self.dim))
ctx = mx.random.normal((1, 4, self.dim))
out = attn(x, ctx)
mx.eval(out)
assert out.shape == (1, 8, self.dim)
assert np.isfinite(np.array(out.astype(mx.float32))).all()
def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
ctx = mx.random.normal((1, 4, self.dim))
k, v = attn.prepare_kv(ctx)
mx.eval(k, v)
assert k.dtype == mx.bfloat16
assert v.dtype == mx.bfloat16
def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters()))
x = mx.random.normal((1, 8, self.dim))
out = ffn(x)
mx.eval(out)
assert out.shape == (1, 8, self.dim)
assert np.isfinite(np.array(out.astype(mx.float32))).all()
def test_self_attn_rope_in_float32(self):
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
x = mx.random.normal((1, 8, self.dim))
freqs = rope_params(1024, self.dim // self.num_heads)
assert freqs.dtype == mx.float32
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
mx.eval(out)
assert np.isfinite(np.array(out.astype(mx.float32))).all()
def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters()))
B, L = 1, 8
x = mx.random.normal((B, L, self.dim))
e = mx.random.normal((B, L, 6, self.dim))
ctx = mx.random.normal((B, 4, self.dim))
freqs = rope_params(1024, self.dim // self.num_heads)
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
mx.eval(out)
assert out.dtype == mx.float32
assert np.isfinite(np.array(out)).all()

125
tests/test_wan_config.py Normal file
View File

@@ -0,0 +1,125 @@
"""Tests for Wan model configuration."""
import pytest
# ---------------------------------------------------------------------------
# Config Tests
# ---------------------------------------------------------------------------
class TestWanModelConfig:
"""Tests for WanModelConfig dataclass."""
def test_default_values(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.dim == 5120
assert config.ffn_dim == 13824
assert config.num_heads == 40
assert config.num_layers == 40
assert config.in_dim == 16
assert config.out_dim == 16
assert config.patch_size == (1, 2, 2)
assert config.vae_stride == (4, 8, 8)
assert config.vae_z_dim == 16
assert config.boundary == 0.875
assert config.sample_shift == 12.0
assert config.sample_steps == 40
assert config.sample_guide_scale == (3.0, 4.0)
assert config.num_train_timesteps == 1000
assert config.qk_norm is True
assert config.cross_attn_norm is True
assert config.text_len == 512
def test_head_dim_property(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
d = config.to_dict()
assert isinstance(d, dict)
assert d["dim"] == 5120
assert d["patch_size"] == (1, 2, 2)
assert d["boundary"] == 0.875
def test_t5_config_values(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.t5_vocab_size == 256384
assert config.t5_dim == 4096
assert config.t5_dim_attn == 4096
assert config.t5_dim_ffn == 10240
assert config.t5_num_heads == 64
assert config.t5_num_layers == 24
assert config.t5_num_buckets == 32
# ---------------------------------------------------------------------------
# Wan2.1 Config Tests
# ---------------------------------------------------------------------------
class TestWan21Config:
"""Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1"
assert config.dual_model is False
assert config.dim == 5120
assert config.ffn_dim == 13824
assert config.num_heads == 40
assert config.num_layers == 40
assert config.head_dim == 128
assert config.sample_guide_scale == 5.0
assert config.sample_shift == 5.0
assert config.sample_steps == 50
assert config.boundary == 0.0
def test_wan21_1_3b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1"
assert config.dual_model is False
assert config.dim == 1536
assert config.ffn_dim == 8960
assert config.num_heads == 12
assert config.num_layers == 30
assert config.head_dim == 128 # 1536 // 12
assert config.sample_guide_scale == 5.0
def test_wan22_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2"
assert config.dual_model is True
assert config.dim == 5120
assert config.sample_guide_scale == (3.0, 4.0)
assert config.sample_shift == 12.0
assert config.sample_steps == 40
assert config.boundary == 0.875
def test_wan21_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
assert d["model_version"] == "2.1"
assert d["dual_model"] is False
assert d["sample_guide_scale"] == 5.0
def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict()
assert d["dim"] == 1536
assert d["num_layers"] == 30
def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
assert config.model_version == "2.2"
assert config.dual_model is True

307
tests/test_wan_convert.py Normal file
View File

@@ -0,0 +1,307 @@
"""Tests for Wan weight conversion utilities."""
import logging
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Transformer Weight Conversion Tests
# ---------------------------------------------------------------------------
class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
}
out = sanitize_wan_transformer_weights(weights)
assert "patch_embedding_proj.weight" in out
assert "patch_embedding_proj.bias" in out
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
def test_text_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"text_embedding.0.weight": mx.zeros((64, 32)),
"text_embedding.0.bias": mx.zeros((64,)),
"text_embedding.2.weight": mx.zeros((64, 64)),
"text_embedding.2.bias": mx.zeros((64,)),
}
out = sanitize_wan_transformer_weights(weights)
assert "text_embedding_0.weight" in out
assert "text_embedding_0.bias" in out
assert "text_embedding_1.weight" in out
assert "text_embedding_1.bias" in out
def test_time_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"time_embedding.0.weight": mx.zeros((64, 32)),
"time_embedding.2.weight": mx.zeros((64, 64)),
}
out = sanitize_wan_transformer_weights(weights)
assert "time_embedding_0.weight" in out
assert "time_embedding_1.weight" in out
def test_time_projection_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"time_projection.1.weight": mx.zeros((384, 64)),
"time_projection.1.bias": mx.zeros((384,)),
}
out = sanitize_wan_transformer_weights(weights)
assert "time_projection.weight" in out
assert "time_projection.bias" in out
def test_ffn_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.0.bias": mx.zeros((128,)),
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
"blocks.0.ffn.2.bias": mx.zeros((64,)),
}
out = sanitize_wan_transformer_weights(weights)
assert "blocks.0.ffn.fc1.weight" in out
assert "blocks.0.ffn.fc1.bias" in out
assert "blocks.0.ffn.fc2.weight" in out
assert "blocks.0.ffn.fc2.bias" in out
def test_freqs_skipped(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"freqs": mx.zeros((1024, 64, 2)),
"blocks.0.norm1.weight": mx.zeros((64,)),
}
out = sanitize_wan_transformer_weights(weights)
assert "freqs" not in out
assert "blocks.0.norm1.weight" in out
def test_passthrough_keys(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
"blocks.0.self_attn.v.weight": mx.zeros((64, 64)),
"blocks.0.self_attn.o.weight": mx.zeros((64, 64)),
"blocks.0.modulation": mx.zeros((1, 6, 64)),
"head.head.weight": mx.zeros((64, 64)),
"head.modulation": mx.zeros((1, 2, 64)),
}
out = sanitize_wan_transformer_weights(weights)
for key in weights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
"text_embedding.0.weight": mx.zeros((64, 32)),
"text_embedding.2.weight": mx.zeros((64, 64)),
"time_embedding.0.weight": mx.zeros((64, 32)),
"time_embedding.2.weight": mx.zeros((64, 64)),
"time_projection.1.weight": mx.zeros((384, 64)),
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
"blocks.0.modulation": mx.zeros((1, 6, 64)),
"head.head.weight": mx.zeros((64, 64)),
"freqs": mx.zeros((1024, 64, 2)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
sanitize_wan_transformer_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeT5Weights:
def test_gate_rename(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
}
out = sanitize_wan_t5_weights(weights)
assert "blocks.0.ffn.gate_proj.weight" in out
assert "blocks.0.ffn.fc1.weight" in out
assert "blocks.0.ffn.fc2.weight" in out
def test_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
"norm.weight": mx.zeros((64,)),
}
out = sanitize_wan_t5_weights(weights)
for key in weights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
"norm.weight": mx.zeros((64,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
sanitize_wan_t5_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeVAEWeights:
def test_conv3d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
}
out = sanitize_wan_vae_weights(weights)
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
def test_conv2d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
}
out = sanitize_wan_vae_weights(weights)
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
def test_non_conv_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
"decoder.bias": mx.zeros((16,)),
}
out = sanitize_wan_vae_weights(weights)
assert out["decoder.norm.weight"].shape == (64,)
assert out["decoder.bias"].shape == (16,)
def test_mixed_weights(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
"linear.weight": mx.zeros((8, 4)), # 2D
"norm.weight": mx.zeros((8,)), # 1D
}
out = sanitize_wan_vae_weights(weights)
assert out["conv3d.weight"].shape == (8, 3, 3, 3, 4)
assert out["conv2d.weight"].shape == (8, 3, 3, 4)
assert out["linear.weight"].shape == (8, 4)
assert out["norm.weight"].shape == (8,)
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
"decoder.norm.weight": mx.zeros((64,)),
"decoder.bias": mx.zeros((16,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
sanitize_wan_vae_weights(weights)
assert "Unconsumed" not in caplog.text
# ---------------------------------------------------------------------------
# Wan2.1 Conversion Tests
# ---------------------------------------------------------------------------
class TestWan21Convert:
"""Tests for Wan2.1 conversion support."""
def test_auto_detect_wan21(self, tmp_path):
"""Auto-detect single-model directory as Wan2.1."""
# Create a Wan2.1-style directory (no low_noise_model subdir)
(tmp_path / "dummy.safetensors").touch()
# The auto-detect logic: no low_noise_model dir → 2.1
from pathlib import Path
low = tmp_path / "low_noise_model"
assert not low.exists()
# Simulates auto detection
version = "2.2" if low.exists() else "2.1"
assert version == "2.1"
def test_auto_detect_wan22(self, tmp_path):
"""Auto-detect dual-model directory as Wan2.2."""
(tmp_path / "low_noise_model").mkdir()
(tmp_path / "high_noise_model").mkdir()
from pathlib import Path
low = tmp_path / "low_noise_model"
assert low.exists()
version = "2.2" if low.exists() else "2.1"
assert version == "2.2"
def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
assert d["model_version"] == "2.1"
assert d["dual_model"] is False
assert d["sample_steps"] == 50
assert d["sample_shift"] == 5.0
# ---------------------------------------------------------------------------
# Encoder Weight Sanitization Tests
# ---------------------------------------------------------------------------
class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
def test_exclude_encoder_by_default(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
out = sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "conv2.weight" in out
assert not any("encoder" in k or k.startswith("conv1") for k in out)
def test_include_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
out = sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "encoder.conv1.weight" in out
assert "conv1.weight" in out
assert "conv2.weight" in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "Unconsumed" not in caplog.text
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "Unconsumed" not in caplog.text

238
tests/test_wan_generate.py Normal file
View File

@@ -0,0 +1,238 @@
"""Tests for end-to-end generation and I2V mask construction."""
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Integration: end-to-end tiny model forward pass
# ---------------------------------------------------------------------------
class TestEndToEnd:
"""End-to-end test with tiny model (no real weights needed)."""
def test_tiny_model_denoise_step(self):
"""Simulate one denoising step with tiny model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(42)
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=3.0)
latents = mx.random.normal((C, F, H, W))
context = mx.random.normal((4, config.text_dim))
# One step
t = sched.timesteps[0]
pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
latents_next = sched.step(pred[None], t, latents[None]).squeeze(0)
mx.eval(latents_next)
assert latents_next.shape == (C, F, H, W)
# Should differ from original noise
assert not np.allclose(np.array(latents_next), np.array(latents), atol=1e-5)
def test_tiny_model_full_loop(self):
"""Run a complete (tiny) diffusion loop."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(123)
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
num_steps = 3
sched.set_timesteps(num_steps, shift=3.0)
latents = mx.random.normal((C, F, H, W))
context = mx.random.normal((4, config.text_dim))
for i in range(num_steps):
t = sched.timesteps[i]
pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
mx.eval(latents)
assert latents.shape == (C, F, H, W)
assert not mx.any(mx.isnan(latents)).item(), "NaN in output"
assert not mx.any(mx.isinf(latents)).item(), "Inf in output"
# ---------------------------------------------------------------------------
# I2V Mask Tests
# ---------------------------------------------------------------------------
class TestI2VMask:
"""Tests for _build_i2v_mask."""
def test_mask_shapes(self):
from mlx_video.generate_wan import _build_i2v_mask
z_shape = (48, 5, 4, 4) # C, T, H, W
patch_size = (1, 2, 2)
mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
assert mask.shape == z_shape
# Tokens: T=5, H/2=2, W/2=2 → 5*2*2 = 20
assert mask_tokens.shape == (1, 20)
def test_first_frame_zero(self):
from mlx_video.generate_wan import _build_i2v_mask
z_shape = (48, 5, 4, 4)
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
mx.eval(mask, mask_tokens)
# First temporal position should be 0
assert float(mask[:, 0, :, :].max()) == 0.0
# Rest should be 1
assert float(mask[:, 1:, :, :].min()) == 1.0
# First-frame tokens (T=0) should be 0 in mask_tokens
# With T=5, H'=2, W'=2: first 4 tokens are frame 0
assert float(mask_tokens[0, :4].max()) == 0.0
assert float(mask_tokens[0, 4:].min()) == 1.0
class TestI2VMaskAlignment:
"""Tests that I2V mask works correctly with various aligned dimensions."""
def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions."""
from mlx_video.generate_wan import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
z_shape = (48, 21, 44, 80)
patch_size = (1, 2, 2)
mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
mx.eval(mask, mask_tokens)
assert mask.shape == z_shape
assert float(mask[:, 0].max()) == 0.0
assert float(mask[:, 1:].min()) == 1.0
expected_tokens = 21 * 22 * 40 # T * (H/ph) * (W/pw)
assert mask_tokens.shape == (1, expected_tokens)
first_frame_tokens = 1 * 22 * 40 # pt=1
assert float(mask_tokens[0, :first_frame_tokens].max()) == 0.0
assert float(mask_tokens[0, first_frame_tokens:].min()) == 1.0
def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.generate_wan import _build_i2v_mask
z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2)
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
mx.eval(mask_tokens)
timestep_val = 0.8
t_tokens = mask_tokens * timestep_val
mx.eval(t_tokens)
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
# ---------------------------------------------------------------------------
# Dimension Alignment Tests
# ---------------------------------------------------------------------------
class TestDimensionAlignment:
"""Tests for automatic dimension alignment in generate_wan."""
def test_already_aligned(self):
"""Dimensions already divisible by alignment factor should be unchanged."""
# patch_size=(1,2,2), vae_stride=(4,16,16) → align = 32
align_h = 2 * 16 # 32
align_w = 2 * 16 # 32
h, w = 704, 1280
assert h % align_h == 0
assert w % align_w == 0
h_aligned = (h // align_h) * align_h
w_aligned = (w // align_w) * align_w
assert h_aligned == h
assert w_aligned == w
def test_720p_rounds_down(self):
"""720p (1280x720) should round height to 704."""
align_h = 32
align_w = 32
h, w = 720, 1280
assert h % align_h != 0 # 720 not divisible by 32
h_aligned = (h // align_h) * align_h
w_aligned = (w // align_w) * align_w
assert h_aligned == 704
assert w_aligned == 1280
def test_1080p_rounds_down(self):
"""1080p (1920x1080) should round height to 1056."""
align = 32
h, w = 1080, 1920
assert h % align != 0
assert (h // align) * align == 1056
assert (w // align) * align == 1920
def test_odd_sizes(self):
"""Odd sizes should be safely rounded down."""
align = 32
for size in [100, 255, 513, 1023]:
aligned = (size // align) * align
assert aligned % align == 0
assert aligned <= size
assert aligned + align > size # closest lower multiple
def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
# Simulate 720p-like scenario with tiny config
vae_stride = config.vae_stride # (4, 8, 8)
patch_size = config.patch_size # (1, 2, 2)
align_h = patch_size[1] * vae_stride[1]
align_w = patch_size[2] * vae_stride[2]
# Pick a height not divisible by alignment
raw_h = align_h * 3 + 5 # e.g. 53 for align=16
raw_w = align_w * 4
h = (raw_h // align_h) * align_h # rounds down
w = (raw_w // align_w) * align_w
C = config.in_dim
t_latent = 1
h_latent = h // vae_stride[1]
w_latent = w // vae_stride[2]
vid = mx.random.normal((C, t_latent, h_latent, w_latent))
patches, grid_size = model._patchify(vid)
mx.eval(patches)
assert patches.ndim == 3 # [1, L, dim]
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1]
align_w = config.patch_size[2] * config.vae_stride[2]
assert align_h == 32
assert align_w == 32
# 720 not divisible
assert 720 % align_h != 0
# 704 is
assert 704 % align_h == 0

570
tests/test_wan_i2v.py Normal file
View File

@@ -0,0 +1,570 @@
"""Tests for Wan2.2 I2V-14B support."""
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
def _make_tiny_i2v_config():
"""Create a tiny I2V-14B config for testing."""
config = _make_tiny_config()
config.model_type = "i2v"
config.in_dim = 9 # 4 noise + 4 image + 1 mask (scaled down from 16+16+4=36)
config.out_dim = 4
config.vae_z_dim = 4
config.vae_stride = (4, 8, 8)
config.dual_model = True
config.boundary = 0.900
config.sample_shift = 5.0
config.sample_guide_scale = (3.5, 3.5)
return config
class TestI2VConfig:
"""Test I2V-14B config preset."""
def test_wan22_i2v_14b_preset(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
assert config.model_type == "i2v"
assert config.in_dim == 36
assert config.out_dim == 16
assert config.dim == 5120
assert config.num_layers == 40
assert config.dual_model is True
assert config.boundary == 0.900
assert config.sample_shift == 5.0
assert config.sample_guide_scale == (3.5, 3.5)
assert config.vae_stride == (4, 8, 8)
assert config.vae_z_dim == 16
def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan.config import WanModelConfig
i2v = WanModelConfig.wan22_i2v_14b()
t2v = WanModelConfig.wan22_t2v_14b()
assert i2v.model_type == "i2v"
assert t2v.model_type == "t2v"
assert i2v.in_dim == 36 and t2v.in_dim == 16
assert i2v.boundary == 0.900 and t2v.boundary == 0.875
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
def test_i2v_serialization_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
d = config.to_dict()
restored = WanModelConfig.from_dict(d)
assert restored.model_type == "i2v"
assert restored.in_dim == 36
assert restored.boundary == 0.900
class TestModelYParameter:
"""Test y parameter channel concatenation in WanModel."""
def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x_list = [mx.random.normal((C, F, H, W))]
t = mx.array([500.0])
context = [mx.random.normal((6, config.text_dim))]
out = model(x_list, t, context, seq_len)
mx.eval(out[0])
assert out[0].shape == (C, F, H, W)
def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
C_noise = 4 # noise channels
C_y = 5 # mask (1) + image (4)
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x_list = [mx.random.normal((C_noise, F, H, W))]
y_list = [mx.random.normal((C_y, F, H, W))]
t = mx.array([500.0])
context = [mx.random.normal((6, config.text_dim))]
out = model(x_list, t, context, seq_len, y=y_list)
mx.eval(out[0])
# Output should match noise channels (out_dim), not concatenated in_dim
assert out[0].shape == (config.out_dim, F, H, W)
def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
mx.random.seed(42)
x = mx.random.normal((C, F, H, W))
t = mx.array([500.0])
ctx = [mx.random.normal((6, config.text_dim))]
out1 = model([x], t, ctx, seq_len)[0]
out2 = model([x], t, ctx, seq_len, y=None)[0]
mx.eval(out1, out2)
assert mx.allclose(out1, out2, atol=1e-5).item()
def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
C_noise, C_y = 4, 5
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
latents = mx.random.normal((C_noise, F, H, W))
y = mx.random.normal((C_y, F, H, W))
t = mx.array([500.0, 500.0])
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
mx.eval(out[0], out[1])
assert len(out) == 2
assert out[0].shape == (config.out_dim, F, H, W)
assert out[1].shape == (config.out_dim, F, H, W)
class TestVAEEncoder:
"""Test Wan2.1 VAE encoder."""
def test_encoder3d_instantiation(self):
from mlx_video.models.wan.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
assert enc.conv1 is not None
assert len(enc.downsamples) > 0
assert len(enc.middle) == 3
def test_encoder3d_output_shape(self):
"""Encoder should downsample spatially by 8x and temporally by 4x."""
from mlx_video.models.wan.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8)
# Random input: [B=1, 3, T=5, H=32, W=32]
x = mx.random.normal((1, 3, 5, 32, 32))
out = enc(x)
mx.eval(out)
# With default dim_mult=[1,2,4,4] and temporal_downsample=[True,True,False]:
# Spatial: 32 -> 16 -> 8 -> 4 (3 spatial downsamples)
# Temporal: 5 -> 3 -> 2 (2 temporal downsamples: downsample3d stride 2)
assert out.shape[0] == 1
assert out.shape[1] == 8 # z_dim
assert out.shape[3] == 32 // 8 # spatial /8
assert out.shape[4] == 32 // 8
def test_wan_vae_encode(self):
"""WanVAE with encoder=True should produce normalized latents."""
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
# Input: [B=1, 3, T=5, H=32, W=32]
x = mx.random.normal((1, 3, 5, 32, 32))
z = vae.encode(x)
mx.eval(z)
assert z.shape[0] == 1
assert z.shape[1] == 16 # z_dim
def test_wan_vae_encoder_flag(self):
"""WanVAE without encoder flag should not have encoder attribute."""
from mlx_video.models.wan.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, 'encoder')
vae_enc = WanVAE(z_dim=4, encoder=True)
assert hasattr(vae_enc, 'encoder')
class TestResampleDownsample:
"""Test downsample modes in Resample."""
def test_downsample2d(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="downsample2d")
x = mx.random.normal((1, 16, 2, 8, 8))
out = r(x)
mx.eval(out)
# Spatial /2, temporal unchanged, channels same
assert out.shape == (1, 16, 2, 4, 4)
def test_downsample3d(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="downsample3d")
x = mx.random.normal((1, 16, 4, 8, 8))
out = r(x)
mx.eval(out)
# Spatial /2, temporal /2, channels same
assert out.shape == (1, 16, 2, 4, 4)
def test_upsample2d_still_works(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="upsample2d")
x = mx.random.normal((1, 16, 2, 4, 4))
out = r(x)
mx.eval(out)
assert out.shape == (1, 8, 2, 8, 8)
def test_upsample3d_still_works(self):
from mlx_video.models.wan.vae import Resample
r = Resample(dim=16, mode="upsample3d")
x = mx.random.normal((1, 16, 2, 4, 4))
out = r(x)
mx.eval(out)
assert out.shape == (1, 8, 4, 8, 8)
class TestI2VMaskConstruction:
"""Test mask construction for I2V-14B."""
def test_mask_shape(self):
"""I2V-14B mask should have 4 channels with correct temporal structure."""
num_frames = 81
h_latent, w_latent = 10, 18 # example latent dims
t_latent = (num_frames - 1) // 4 + 1 # = 21
# Build mask following reference logic
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
assert msk.shape == (4, t_latent, h_latent, w_latent)
def test_mask_values(self):
"""First temporal position should be 1, rest 0."""
num_frames = 9
h_latent, w_latent = 4, 4
t_latent = (num_frames - 1) // 4 + 1 # = 3
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0]
mx.eval(msk)
# First temporal position: all 4 channels should be 1
assert mx.all(msk[:, 0] == 1.0).item()
# Rest: all should be 0
assert mx.all(msk[:, 1:] == 0.0).item()
def test_y_tensor_shape(self):
"""y = concat([mask_4ch, encoded_video_16ch]) should be 20 channels."""
mask = mx.zeros((4, 5, 10, 18))
encoded = mx.zeros((16, 5, 10, 18))
y = mx.concatenate([mask, encoded], axis=0)
assert y.shape == (20, 5, 10, 18)
# ---------------------------------------------------------------------------
# Integration: I2V end-to-end pipeline
# ---------------------------------------------------------------------------
class TestI2VEndToEndPipeline:
"""Full I2V pipeline: image → preprocess → VAE encode → y tensor → denoise → VAE decode."""
def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan.vae import WanVAE
mx.random.seed(0)
# --- Tiny I2V model config (z_dim=16 to match VAE normalization stats) ---
config = _make_tiny_i2v_config()
config.vae_z_dim = 16
config.out_dim = 16 # must match VAE z_dim for decode
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
model = WanModel(config)
# --- Tiny VAE (with encoder) ---
vae = WanVAE(z_dim=config.vae_z_dim, encoder=True)
# --- Synthetic image: [B=1, 3, T=1, H=32, W=32] in [-1, 1] ---
height, width = 32, 32
num_frames = 5 # small temporal extent
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
video = mx.concatenate([
img,
mx.zeros((1, 3, num_frames - 1, height, width)),
], axis=2)
# --- VAE encode ---
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
mx.eval(z_video)
assert z_video.ndim == 5
assert z_video.shape[1] == config.vae_z_dim
z_video = z_video[0] # [z_dim, T_lat, H_lat, W_lat]
t_latent = z_video.shape[1]
h_latent = z_video.shape[2]
w_latent = z_video.shape[3]
# --- Build I2V mask (4 channels) ---
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
# --- Build y tensor: [mask(4ch) + encoded(z_dim ch)] ---
y_i2v = mx.concatenate([msk, z_video], axis=0)
mx.eval(y_i2v)
assert y_i2v.shape[0] == 4 + config.vae_z_dim
# --- Denoising loop (2 steps) ---
C_noise = config.out_dim # noise channels
pt, ph, pw = config.patch_size
seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw)
sched = FlowMatchEulerScheduler()
num_steps = 2
sched.set_timesteps(num_steps, shift=config.sample_shift)
latents = mx.random.normal((C_noise, t_latent, h_latent, w_latent))
context = mx.random.normal((4, config.text_dim))
for i in range(num_steps):
t_val = sched.timesteps[i].item()
pred = model(
[latents],
mx.array([t_val]),
[context],
seq_len,
y=[y_i2v],
)[0]
latents = sched.step(pred[None], t_val, latents[None]).squeeze(0)
mx.eval(latents)
assert latents.shape == (C_noise, t_latent, h_latent, w_latent)
assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents"
assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents"
# --- VAE decode ---
decoded = vae.decode(latents[None]) # [1, 3, T_out, H_out, W_out]
mx.eval(decoded)
assert decoded.ndim == 5
assert decoded.shape[0] == 1
assert decoded.shape[1] == 3 # RGB output
assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video"
assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video"
# VAE decode clips to [-1, 1]
assert float(decoded.max()) <= 1.0
assert float(decoded.min()) >= -1.0
class TestDualModelSwitching:
"""Test dual-model selection logic: high_noise vs low_noise based on boundary."""
def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(1)
config = _make_tiny_i2v_config()
assert config.dual_model is True
high_noise_model = WanModel(config)
low_noise_model = WanModel(config)
boundary = config.boundary * config.num_train_timesteps # 0.9 * 1000 = 900
C_noise = config.out_dim # 4
C_y = config.in_dim - config.out_dim # 9 - 4 = 5
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
num_steps = 5
sched.set_timesteps(num_steps, shift=config.sample_shift)
guide_scale = config.sample_guide_scale # (3.5, 3.5)
assert isinstance(guide_scale, tuple) and len(guide_scale) == 2
latents = mx.random.normal((C_noise, F, H, W))
y_i2v = mx.random.normal((C_y, F, H, W))
context = mx.random.normal((4, config.text_dim))
high_used_steps = []
low_used_steps = []
timestep_list = sched.timesteps.tolist()
for i in range(num_steps):
timestep_val = timestep_list[i]
if timestep_val >= boundary:
model = high_noise_model
gs = guide_scale[1]
high_used_steps.append(i)
else:
model = low_noise_model
gs = guide_scale[0]
low_used_steps.append(i)
# CFG pass: cond + uncond
preds = model(
[latents, latents],
mx.array([timestep_val, timestep_val]),
[context, context],
seq_len,
y=[y_i2v, y_i2v],
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
mx.eval(latents)
# With shift=5.0, early timesteps should be high (>=900), later ones low
assert len(high_used_steps) > 0, "High-noise model was never selected"
assert len(low_used_steps) > 0, "Low-noise model was never selected"
# High-noise steps should come before low-noise steps (timesteps decrease)
if high_used_steps and low_used_steps:
assert max(high_used_steps) < min(low_used_steps) or \
min(high_used_steps) < max(low_used_steps), \
"Model switching should happen during the loop"
assert latents.shape == (C_noise, F, H, W)
assert not mx.any(mx.isnan(latents)).item()
def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(2)
config = _make_tiny_i2v_config()
config.sample_guide_scale = (2.0, 5.0) # distinct values
model = WanModel(config)
boundary = config.boundary * config.num_train_timesteps
C_noise = config.out_dim
F, H, W = 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=config.sample_shift)
latents = mx.random.normal((C_noise, F, H, W))
context = mx.random.normal((4, config.text_dim))
guide_scale = config.sample_guide_scale
C_y = config.in_dim - config.out_dim # y channels
y_i2v = mx.random.normal((C_y, F, H, W))
# Track which guide scale was used at each step
gs_per_step = []
timestep_list = sched.timesteps.tolist()
for i in range(5):
timestep_val = timestep_list[i]
if timestep_val >= boundary:
gs = guide_scale[1] # high_gs = 5.0
else:
gs = guide_scale[0] # low_gs = 2.0
gs_per_step.append(gs)
pred = model(
[latents, latents],
mx.array([timestep_val, timestep_val]),
[context, context],
seq_len,
y=[y_i2v, y_i2v],
)
noise_pred = pred[1] + gs * (pred[0] - pred[1])
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
mx.eval(latents)
# Verify both guide scales were used
assert 5.0 in gs_per_step, "High guide scale (5.0) was never used"
assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used"
# High gs should appear first (high timesteps come first)
first_high = gs_per_step.index(5.0)
last_low = len(gs_per_step) - 1 - gs_per_step[::-1].index(2.0)
assert first_high < last_low, "High gs steps should precede low gs steps"
def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
mx.random.seed(3)
config = _make_tiny_config()
config.dual_model = False
config.sample_guide_scale = (3.0, 5.0)
model = WanModel(config)
guide_scale = config.sample_guide_scale
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
sched = FlowMatchEulerScheduler()
sched.set_timesteps(3, shift=3.0)
latents = mx.random.normal((C, F, H, W))
context = mx.random.normal((4, config.text_dim))
# Mimic generate_wan.py single-model logic:
# gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
assert gs == 3.0, "Single model should use first element of guide_scale tuple"
for i in range(3):
t_val = sched.timesteps[i].item()
pred = model(
[latents, latents],
mx.array([t_val, t_val]),
[context, context],
seq_len,
)
noise_pred = pred[1] + gs * (pred[0] - pred[1])
latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0)
mx.eval(latents)
assert latents.shape == (C, F, H, W)
assert not mx.any(mx.isnan(latents)).item()

334
tests/test_wan_lora.py Normal file
View File

@@ -0,0 +1,334 @@
"""Tests for LoRA loading and application."""
import tempfile
from pathlib import Path
import mlx.core as mx
import numpy as np
import pytest
class TestLoRATypes:
"""Test LoRA data structures."""
def test_lora_weights_scale(self):
from mlx_video.lora.types import LoRAWeights
w = LoRAWeights(
lora_A=mx.zeros((16, 64)),
lora_B=mx.zeros((128, 16)),
rank=16,
alpha=32.0,
module_name="test",
)
assert w.scale == 2.0
def test_lora_weights_scale_default(self):
from mlx_video.lora.types import LoRAWeights
w = LoRAWeights(
lora_A=mx.zeros((16, 64)),
lora_B=mx.zeros((128, 16)),
rank=16,
alpha=16.0,
module_name="test",
)
assert w.scale == 1.0
def test_applied_lora_delta(self):
from mlx_video.lora.types import AppliedLoRA, LoRAWeights
lora_a = mx.ones((2, 4))
lora_b = mx.ones((8, 2))
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
applied = AppliedLoRA(weights=w, strength=0.5)
delta = applied.compute_delta()
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
expected = 0.5 * mx.ones((8, 4)) * 2.0
assert mx.allclose(delta, expected).item()
class TestLoRALoader:
"""Test LoRA weight loading from safetensors."""
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
"""Helper to create a mock LoRA safetensors file."""
weights = {}
for name in module_names:
if key_format == "AB":
weights[f"{name}.lora_A.weight"] = mx.random.normal((rank, in_dim))
weights[f"{name}.lora_B.weight"] = mx.random.normal((out_dim, rank))
else:
weights[f"{name}.lora_down.weight"] = mx.random.normal((rank, in_dim))
weights[f"{name}.lora_up.weight"] = mx.random.normal((out_dim, rank))
path = Path(tmp_dir) / "test_lora.safetensors"
mx.save_safetensors(str(path), weights)
return path
def test_load_lora_a_b_format(self):
from mlx_video.lora.loader import load_lora_weights
with tempfile.TemporaryDirectory() as tmp:
path = self._make_lora_file(tmp, ["blocks.0.self_attn.q"], key_format="AB")
lora_weights = load_lora_weights(path)
assert "blocks.0.self_attn.q" in lora_weights
w = lora_weights["blocks.0.self_attn.q"]
assert w.rank == 4
assert w.alpha == 4.0 # default: alpha == rank
assert w.lora_A.shape == (4, 64)
assert w.lora_B.shape == (128, 4)
def test_load_lora_down_up_format(self):
from mlx_video.lora.loader import load_lora_weights
with tempfile.TemporaryDirectory() as tmp:
path = self._make_lora_file(
tmp, ["blocks.0.self_attn.q"], key_format="down_up"
)
lora_weights = load_lora_weights(path)
assert "blocks.0.self_attn.q" in lora_weights
def test_load_multiple_modules(self):
from mlx_video.lora.loader import load_lora_weights
modules = [
"blocks.0.self_attn.q",
"blocks.0.self_attn.k",
"blocks.0.ffn.fc1",
]
with tempfile.TemporaryDirectory() as tmp:
path = self._make_lora_file(tmp, modules)
lora_weights = load_lora_weights(path)
assert len(lora_weights) == 3
for name in modules:
assert name in lora_weights
def test_load_with_alpha(self):
from mlx_video.lora.loader import load_lora_weights
with tempfile.TemporaryDirectory() as tmp:
weights = {
"test.lora_A.weight": mx.random.normal((8, 64)),
"test.lora_B.weight": mx.random.normal((128, 8)),
"test.alpha": mx.array(16.0),
}
path = Path(tmp) / "lora.safetensors"
mx.save_safetensors(str(path), weights)
lora_weights = load_lora_weights(path)
assert lora_weights["test"].alpha == 16.0
assert lora_weights["test"].rank == 8
assert lora_weights["test"].scale == 2.0
def test_file_not_found(self):
from mlx_video.lora.loader import load_lora_weights
with pytest.raises(FileNotFoundError):
load_lora_weights(Path("/nonexistent/lora.safetensors"))
class TestWanKeyNormalization:
"""Test Wan2.2 LoRA key normalization."""
def _wan_model_keys(self):
"""Simulate typical Wan2.2 MLX model weight keys."""
keys = set()
for i in range(2):
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
keys.add(f"blocks.{i}.{layer}.weight")
keys.add(f"blocks.{i}.ffn.fc1.weight")
keys.add(f"blocks.{i}.ffn.fc2.weight")
keys.add("text_embedding_0.weight")
keys.add("text_embedding_1.weight")
keys.add("time_embedding_0.weight")
keys.add("time_embedding_1.weight")
keys.add("time_projection.weight")
keys.add("patch_embedding_proj.weight")
return keys
def test_direct_match(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
def test_strip_diffusion_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("diffusion_model.blocks.0.self_attn.q", keys)
assert result == "blocks.0.self_attn.q"
def test_strip_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
assert result == "blocks.0.self_attn.k"
def test_ffn_key_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.ffn.0", keys) == "blocks.0.ffn.fc1"
assert _normalize_wan_lora_key("blocks.0.ffn.2", keys) == "blocks.0.ffn.fc2"
def test_text_embedding_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("text_embedding.0", keys) == "text_embedding_0"
assert _normalize_wan_lora_key("text_embedding.2", keys) == "text_embedding_1"
def test_time_embedding_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("time_embedding.0", keys) == "time_embedding_0"
assert _normalize_wan_lora_key("time_embedding.2", keys) == "time_embedding_1"
def test_time_projection_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("time_projection.1", keys) == "time_projection"
def test_patch_embedding_mapping(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
def test_combined_prefix_and_ffn(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("diffusion_model.blocks.1.ffn.0", keys)
assert result == "blocks.1.ffn.fc1"
class TestApplyLoRA:
"""Test LoRA delta application to weights."""
def test_preserves_bfloat16_dtype(self):
"""LoRA delta must not promote bfloat16 weights to float32."""
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.ones((8, 4), dtype=mx.bfloat16)
# LoRA weights in float32 (typical when loaded from safetensors)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
def test_preserves_float16_dtype(self):
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.ones((8, 4), dtype=mx.float16)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
def test_apply_single_lora(self):
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.ones((8, 4))
lora_a = mx.ones((2, 4)) * 0.1
lora_b = mx.ones((8, 2)) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
result = apply_lora_to_linear(original, [(w, 1.0)])
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
expected = original + 0.02 * mx.ones((8, 4))
assert mx.allclose(result, expected, atol=1e-6).item()
def test_apply_multiple_loras(self):
from mlx_video.lora.apply import apply_lora_to_linear
from mlx_video.lora.types import LoRAWeights
original = mx.zeros((8, 4))
w1 = LoRAWeights(
lora_A=mx.ones((2, 4)),
lora_B=mx.ones((8, 2)),
rank=2, alpha=2.0, module_name="a",
)
w2 = LoRAWeights(
lora_A=mx.ones((2, 4)) * 2,
lora_B=mx.ones((8, 2)) * 2,
rank=2, alpha=4.0, module_name="b",
)
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
# w2 delta: 2.0 * 0.5 * (2*ones(8,2) @ 2*ones(2,4)) = 1.0 * 8*ones(8,4) = 8
delta1 = mx.ones((8, 4)) * 2.0
delta2 = mx.ones((8, 4)) * 8.0
expected = delta1 + delta2
assert mx.allclose(result, expected, atol=1e-5).item()
def test_apply_loras_to_weights_dict(self):
from mlx_video.lora.apply import apply_loras_to_weights
from mlx_video.lora.types import LoRAWeights
model_weights = {
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
"blocks.0.ffn.fc1.weight": mx.ones((256, 64)),
}
w = LoRAWeights(
lora_A=mx.ones((4, 64)) * 0.01,
lora_B=mx.ones((128, 4)) * 0.01,
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
)
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
result = apply_loras_to_weights(model_weights, module_to_loras)
# Only q should be modified
assert not mx.array_equal(
result["blocks.0.self_attn.q.weight"],
model_weights["blocks.0.self_attn.q.weight"],
).item()
assert mx.array_equal(
result["blocks.0.self_attn.k.weight"],
model_weights["blocks.0.self_attn.k.weight"],
).item()
class TestEndToEnd:
"""End-to-end LoRA loading and application."""
def test_load_and_apply_loras(self):
from mlx_video.convert_wan import load_and_apply_loras
with tempfile.TemporaryDirectory() as tmp:
# Create mock LoRA safetensors
rank = 4
weights = {
"blocks.0.self_attn.q.lora_A.weight": mx.random.normal((rank, 64)),
"blocks.0.self_attn.q.lora_B.weight": mx.random.normal((128, rank)),
}
lora_path = Path(tmp) / "test.safetensors"
mx.save_safetensors(str(lora_path), weights)
# Create mock model weights
model_weights = {
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
}
result = load_and_apply_loras(
model_weights, [(str(lora_path), 1.0)]
)
# q weight should be modified, k unchanged
assert not mx.array_equal(
result["blocks.0.self_attn.q.weight"],
model_weights["blocks.0.self_attn.q.weight"],
).item()
assert mx.array_equal(
result["blocks.0.self_attn.k.weight"],
model_weights["blocks.0.self_attn.k.weight"],
).item()

332
tests/test_wan_model.py Normal file
View File

@@ -0,0 +1,332 @@
"""Tests for Wan model components."""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Sinusoidal Embedding Tests
# ---------------------------------------------------------------------------
class TestSinusoidalEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.arange(10).astype(mx.float32)
emb = sinusoidal_embedding_1d(256, pos)
mx.eval(emb)
assert emb.shape == (10, 256)
def test_position_zero(self):
"""Position 0 should have cos=1 for all dims and sin=0."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([0.0])
emb = sinusoidal_embedding_1d(64, pos)
mx.eval(emb)
emb_np = np.array(emb[0])
# First half is cos, should be 1 at position 0
np.testing.assert_allclose(emb_np[:32], 1.0, atol=1e-5)
# Second half is sin, should be 0 at position 0
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
def test_different_positions_differ(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 999.0])
emb = sinusoidal_embedding_1d(128, pos)
mx.eval(emb)
emb_np = np.array(emb)
assert not np.allclose(emb_np[0], emb_np[1])
assert not np.allclose(emb_np[1], emb_np[2])
# ---------------------------------------------------------------------------
# Head Tests
# ---------------------------------------------------------------------------
class TestHead:
def test_output_shape(self):
from mlx_video.models.wan.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
B, L = 1, 24
x = mx.random.normal((B, L, 64))
e = mx.random.normal((B, 64)) # time embedding: [B, dim]
out = head(x, e)
mx.eval(out)
expected_proj_dim = 16 * 1 * 2 * 2 # 64
assert out.shape == (B, L, expected_proj_dim)
def test_modulation_shape(self):
from mlx_video.models.wan.model import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
assert head.modulation.shape == (1, 2, 64)
# ---------------------------------------------------------------------------
# WanModel (Tiny) Tests
# ---------------------------------------------------------------------------
class TestWanModel:
def setup_method(self):
mx.random.seed(42)
def test_instantiation(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
assert num_params > 0
def test_patchify_shape(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
# Input: [C=4, F=1, H=4, W=4]
x = mx.random.normal((4, 1, 4, 4))
patches, grid_size = model._patchify(x)
mx.eval(patches)
# Patch size (1,2,2): F'=1, H'=2, W'=2
assert grid_size == (1, 2, 2)
assert patches.shape == (1, 1 * 2 * 2, config.dim)
def test_patchify_various_sizes(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
x = mx.random.normal((config.in_dim, f, h, w))
patches, (gf, gh, gw) = model._patchify(x)
mx.eval(patches)
pt, ph, pw = config.patch_size
assert gf == f // pt
assert gh == h // ph
assert gw == w // pw
assert patches.shape[1] == gf * gh * gw
def test_unpatchify_inverse(self):
"""Patchify then unpatchify should reconstruct original spatial dims."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 2, 4, 6
pt, ph, pw = config.patch_size
F_out, H_out, W_out = F // pt, H // ph, W // pw
L = F_out * H_out * W_out
proj_dim = config.out_dim * pt * ph * pw
# Simulated head output
x = mx.random.normal((1, L, proj_dim))
out = model.unpatchify(x, [(F_out, H_out, W_out)])
mx.eval(out[0])
assert out[0].shape == (config.out_dim, F, H, W)
def test_forward_pass(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x_list = [mx.random.normal((C, F, H, W))]
t = mx.array([500.0])
context = [mx.random.normal((6, config.text_dim))]
out = model(x_list, t, context, seq_len)
mx.eval(out[0])
assert len(out) == 1
assert out[0].shape == (C, F, H, W)
def test_forward_batch(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
t = mx.array([500.0, 200.0])
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]
out = model(x_list, t, context, seq_len)
mx.eval(out[0], out[1])
assert len(out) == 2
for o in out:
assert o.shape == (C, F, H, W)
def test_output_is_float32(self):
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
seq_len = (F // 1) * (H // 2) * (W // 2)
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
[mx.random.normal((4, config.text_dim))], seq_len)
mx.eval(out[0])
assert out[0].dtype == mx.float32
# ---------------------------------------------------------------------------
# Wan2.1 Model Tests
# ---------------------------------------------------------------------------
class TestWan21Model:
"""Test tiny Wan2.1-style model (single model mode)."""
def setup_method(self):
mx.random.seed(42)
def _make_tiny_wan21_config(self):
"""Create a tiny config mimicking Wan2.1 (single model)."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
# Override to tiny values
config.dim = 64
config.ffn_dim = 128
config.num_heads = 4
config.num_layers = 2
config.in_dim = 4
config.out_dim = 4
config.freq_dim = 32
config.text_dim = 32
config.text_len = 8
return config
def _make_tiny_wan21_1_3b_config(self):
"""Create a tiny config mimicking Wan2.1 1.3B."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
# Override to tiny values (preserve 1.3B head structure: 12 heads)
config.dim = 48
config.ffn_dim = 96
config.num_heads = 4
config.num_layers = 2
config.in_dim = 4
config.out_dim = 4
config.freq_dim = 24
config.text_dim = 24
config.text_len = 8
return config
def test_wan21_tiny_model_forward(self):
"""Forward pass with Wan2.1 tiny config."""
from mlx_video.models.wan.model import WanModel
config = self._make_tiny_wan21_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
seq_len = (F // 1) * (H // 2) * (W // 2)
latents = mx.random.normal((C, F, H, W))
context = mx.random.normal((4, config.text_dim))
t = mx.array([500.0])
out = model([latents], t, [context], seq_len)
mx.eval(out)
assert out[0].shape == (C, F, H, W)
def test_wan21_1_3b_tiny_model_forward(self):
"""Forward pass with Wan2.1 1.3B tiny config."""
from mlx_video.models.wan.model import WanModel
config = self._make_tiny_wan21_1_3b_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
seq_len = (F // 1) * (H // 2) * (W // 2)
latents = mx.random.normal((C, F, H, W))
context = mx.random.normal((4, config.text_dim))
t = mx.array([500.0])
out = model([latents], t, [context], seq_len)
mx.eval(out)
assert out[0].shape == (C, F, H, W)
def test_wan21_single_model_loop(self):
"""Full diffusion loop with single model (Wan2.1 style)."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
config = self._make_tiny_wan21_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
seq_len = (F // 1) * (H // 2) * (W // 2)
sched = FlowMatchEulerScheduler()
sched.set_timesteps(config.sample_steps, shift=config.sample_shift)
# Use only 3 steps for speed
latents = mx.random.normal((C, F, H, W))
context = mx.random.normal((4, config.text_dim))
context_null = mx.zeros((4, config.text_dim))
gs = config.sample_guide_scale # Should be float for Wan2.1
assert isinstance(gs, float), "Wan2.1 guide_scale should be float"
for i in range(3):
t = sched.timesteps[i]
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
pred = pred_uncond + gs * (pred_cond - pred_uncond)
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
mx.eval(latents)
assert latents.shape == (C, F, H, W)
assert not mx.any(mx.isnan(latents)).item()
def test_wan21_vs_wan22_config_differences(self):
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
from mlx_video.models.wan.config import WanModelConfig
c21 = WanModelConfig.wan21_t2v_14b()
c22 = WanModelConfig.wan22_t2v_14b()
# Same architecture
assert c21.dim == c22.dim
assert c21.num_heads == c22.num_heads
assert c21.num_layers == c22.num_layers
# Different pipeline settings
assert c21.dual_model is False
assert c22.dual_model is True
assert isinstance(c21.sample_guide_scale, float)
assert isinstance(c22.sample_guide_scale, tuple)
assert c21.sample_shift != c22.sample_shift
assert c21.sample_steps != c22.sample_steps
# ---------------------------------------------------------------------------
# Per-Token Timestep Tests
# ---------------------------------------------------------------------------
class TestPerTokenTimestep:
"""Tests for per-token sinusoidal embedding."""
def test_1d_unchanged(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 500.0])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (3, 256)
def test_2d_per_token(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (2, 3, 256)
def test_consistency(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
pos_1d = mx.array([0.0, 100.0])
emb_1d = sinusoidal_embedding_1d(256, pos_1d)
pos_2d = mx.array([[0.0, 100.0]])
emb_2d = sinusoidal_embedding_1d(256, pos_2d)
assert mx.array_equal(emb_1d[0], emb_2d[0, 0])
assert mx.array_equal(emb_1d[1], emb_2d[0, 1])

View File

@@ -0,0 +1,313 @@
"""Tests for Wan model quantization pipeline."""
import json
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Quantize Predicate Tests
# ---------------------------------------------------------------------------
class TestQuantizePredicate:
def test_matches_self_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.self_attn.{suffix}"
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_cross_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.cross_attn.{suffix}"
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_ffn_layers(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self):
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self):
from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self):
from mlx_video.convert_wan import _quantize_predicate
mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted."""
from mlx_video.convert_wan import _quantize_predicate
mock_linear = nn.Linear(64, 64)
patterns = [
"blocks.0.self_attn.q", "blocks.0.self_attn.k",
"blocks.0.self_attn.v", "blocks.0.self_attn.o",
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k",
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o",
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2",
]
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
assert len(matched) == 10
# ---------------------------------------------------------------------------
# Quantize Round-Trip Tests
# ---------------------------------------------------------------------------
class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
model = WanModel(config)
nn.quantize(
model,
group_size=group_size,
bits=bits,
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict)
# Write config.json
cfg = {"quantization": {"bits": bits, "group_size": group_size}}
with open(tmp_path / "config.json", "w") as f:
json.dump(cfg, f)
return model_path, weights_dict
def test_4bit_roundtrip(self, tmp_path):
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
quantization={"bits": 4, "group_size": 64},
)
# Verify quantized layers have scales
has_scales = any("scales" in k for k in saved_weights)
assert has_scales, "Quantized model should have .scales tensors"
# Verify a self-attention layer is QuantizedLinear
assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
assert isinstance(loaded.blocks[0].ffn.fc1, nn.QuantizedLinear)
def test_8bit_roundtrip(self, tmp_path):
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
quantization={"bits": 8, "group_size": 64},
)
assert isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
assert isinstance(loaded.blocks[0].cross_attn.k, nn.QuantizedLinear)
def test_non_quantized_layers_remain_linear(self, tmp_path):
config = _make_tiny_config()
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(
model_path, config,
quantization={"bits": 4, "group_size": 64},
)
# Head should NOT be quantized (it's not in the predicate patterns)
assert not isinstance(loaded.head, nn.QuantizedLinear)
def test_loading_without_quantization_flag(self, tmp_path):
"""Loading a non-quantized model should have standard Linear layers."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan.loading import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None)
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
assert not isinstance(loaded.blocks[0].self_attn.q, nn.QuantizedLinear)
# ---------------------------------------------------------------------------
# Quantized Inference Tests
# ---------------------------------------------------------------------------
class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
model = WanModel(config)
nn.quantize(
model,
group_size=64,
bits=bits,
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
mx.eval(model.parameters())
return model
def test_forward_pass_4bit(self):
config = _make_tiny_config()
model = self._make_quantized_model(config, bits=4)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x = [mx.random.normal((C, F, H, W))]
t = mx.array([500.0])
context = [mx.random.normal((4, config.text_dim))]
out = model(x, t, context, seq_len)
mx.eval(out[0])
assert len(out) == 1
assert out[0].shape == (C, F, H, W)
def test_forward_pass_8bit(self):
config = _make_tiny_config()
model = self._make_quantized_model(config, bits=8)
C, F, H, W = config.in_dim, 1, 4, 4
pt, ph, pw = config.patch_size
seq_len = (F // pt) * (H // ph) * (W // pw)
x = [mx.random.normal((C, F, H, W))]
t = mx.array([500.0])
context = [mx.random.normal((4, config.text_dim))]
out = model(x, t, context, seq_len)
mx.eval(out[0])
assert len(out) == 1
assert out[0].shape == (C, F, H, W)
def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
config = _make_tiny_config()
mx.random.seed(42)
# Get unquantized weights
model = WanModel(config)
mx.eval(model.parameters())
orig_weight = np.array(model.blocks[0].self_attn.q.weight)
# Quantize
nn.quantize(
model,
group_size=64,
bits=4,
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
mx.eval(model.parameters())
# QuantizedLinear stores weight differently (uint32 packed)
assert isinstance(model.blocks[0].self_attn.q, nn.QuantizedLinear)
assert hasattr(model.blocks[0].self_attn.q, "scales")
# ---------------------------------------------------------------------------
# Config Metadata Tests
# ---------------------------------------------------------------------------
class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
config = _make_tiny_config()
model = WanModel(config)
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
# Save unquantized model + config
model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict)
with open(tmp_path / "config.json", "w") as f:
json.dump({"dim": config.dim}, f)
# Run quantization
_quantize_saved_model(tmp_path, config, is_dual=False, bits=4, group_size=64)
# Verify metadata
with open(tmp_path / "config.json") as f:
cfg = json.load(f)
assert "quantization" in cfg
assert cfg["quantization"]["bits"] == 4
assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
config = _make_tiny_config()
model = WanModel(config)
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict)
with open(tmp_path / "config.json", "w") as f:
json.dump({}, f)
_quantize_saved_model(tmp_path, config, is_dual=False, bits=8, group_size=32)
with open(tmp_path / "config.json") as f:
cfg = json.load(f)
assert cfg["quantization"]["bits"] == 8
assert cfg["quantization"]["group_size"] == 32
def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
config = _make_tiny_config()
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
model = WanModel(config)
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
mx.save_safetensors(str(tmp_path / name), weights_dict)
with open(tmp_path / "config.json", "w") as f:
json.dump({}, f)
_quantize_saved_model(tmp_path, config, is_dual=True, bits=4, group_size=64)
# Both files should now contain quantized weights (have .scales keys)
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
weights = mx.load(str(tmp_path / name))
has_scales = any("scales" in k for k in weights)
assert has_scales, f"{name} should have quantized layers"

View File

@@ -0,0 +1,334 @@
"""Tests for Wan RoPE frequency construction (Bug 6 regression tests).
These tests verify that the RoPE frequency table is built correctly by
concatenating three separate rope_params calls with different dimension
normalizations, matching the reference implementation.
Background: The reference Wan model constructs RoPE frequencies as:
d = dim // num_heads (128 for all Wan models)
freqs = cat([
rope_params(1024, d - 4*(d//6)), # temporal (dim=44, 22 freqs)
rope_params(1024, 2*(d//6)), # height (dim=42, 21 freqs)
rope_params(1024, 2*(d//6)), # width (dim=42, 21 freqs)
])
A previous incorrect fix used a single rope_params(1024, 128) call, which
gave height/width axes only medium/high frequencies instead of full-range.
This destroyed spatial position encoding and caused grey/artifact output.
"""
import mlx.core as mx
import numpy as np
import pytest
class TestRoPEFrequencyConstruction:
"""Verify WanModel builds RoPE frequencies matching the reference."""
def _get_model_freqs(self, dim=64, num_heads=4):
"""Instantiate a tiny WanModel and return its .freqs tensor."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel
config = WanModelConfig()
config.dim = dim
config.ffn_dim = dim * 2
config.num_heads = num_heads
config.num_layers = 1
config.in_dim = 4
config.out_dim = 4
config.freq_dim = 32
config.text_dim = 32
config.text_len = 8
model = WanModel(config)
mx.eval(model.freqs)
return model.freqs, dim // num_heads
def test_freqs_shape(self):
"""Freqs should be [1024, head_dim//2, 2] regardless of construction."""
freqs, head_dim = self._get_model_freqs(dim=64, num_heads=4)
assert freqs.shape == (1024, head_dim // 2, 2)
def test_three_call_vs_single_call_differ(self):
"""Three separate rope_params calls must differ from single call."""
from mlx_video.models.wan.rope import rope_params
d = 128 # head_dim for all Wan models
# Reference: three separate calls
correct = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
# Wrong: single call
wrong = rope_params(1024, d)
mx.eval(correct, wrong)
assert correct.shape == wrong.shape
diff = np.abs(np.array(correct) - np.array(wrong)).max()
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
def test_each_axis_starts_at_frequency_one(self):
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
This verifies each axis gets its own independent frequency range
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
"""
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
mx.eval(freqs)
f = np.array(freqs)
half_d = d // 2 # 64
d_t = half_d - 2 * (half_d // 3) # 22
d_h = half_d // 3 # 21
# At position 0, cos=1 and sin=0 for ALL frequency components
np.testing.assert_allclose(f[0, :, 0], 1.0, atol=1e-6, err_msg="cos at pos 0")
np.testing.assert_allclose(f[0, :, 1], 0.0, atol=1e-6, err_msg="sin at pos 0")
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
# Temporal axis first freq
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
err_msg="temporal[0] cos at pos 1")
# Height axis first freq (starts at index d_t)
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
err_msg="height[0] cos at pos 1")
# Width axis first freq (starts at index d_t + d_h)
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
err_msg="width[0] cos at pos 1")
def test_height_width_frequencies_identical(self):
"""Height and width axes should have identical frequency tables.
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
"""
from mlx_video.models.wan.rope import rope_params
d = 128
d_h_dim = 2 * (d // 6) # 42
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim),
rope_params(1024, d_h_dim),
], axis=1)
mx.eval(freqs)
f = np.array(freqs)
half_d = d // 2
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
height_freqs = f[:, d_t:d_t + d_h]
width_freqs = f[:, d_t + d_h:]
np.testing.assert_array_equal(height_freqs, width_freqs)
def test_frequency_range_per_axis(self):
"""Each axis should span a full frequency range, not a slice of one range.
With three-call construction, the inverse frequency at index 0 of each
axis should be 1.0 (theta^0). A single-call approach would give height
starting at ~0.04 and width at ~0.002 instead of 1.0.
"""
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
mx.eval(freqs)
f = np.array(freqs)
half_d = d // 2
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
# At position 1, the first frequency component of each axis should
# have significant magnitude (cos ≈ 0.54), not near-zero
pos1_t = f[1, 0, 0] # temporal first freq
pos1_h = f[1, d_t, 0] # height first freq
pos1_w = f[1, d_t + d_h, 0] # width first freq
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
def test_model_freqs_match_manual_construction(self):
"""WanModel.freqs should match manually constructed three-call freqs."""
from mlx_video.models.wan.rope import rope_params
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16
freqs_manual = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
mx.eval(freqs_model, freqs_manual)
np.testing.assert_array_equal(
np.array(freqs_model), np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction"
)
def test_model_freqs_14b_dimensions(self):
"""Verify freq dimensions for 14B-scale head_dim=128."""
from mlx_video.models.wan.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
], axis=1)
mx.eval(freqs)
assert freqs.shape == (1024, 64, 2)
# Verify the split dimensions used by rope_apply
half_d = 64
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
assert (d_t, d_h, d_w) == (22, 21, 21)
assert d_t + d_h + d_w == half_d
class TestRoPEFrequencyMatchesReference:
"""Cross-validate MLX RoPE against PyTorch reference implementation."""
@pytest.fixture
def has_torch(self):
try:
import torch
return True
except ImportError:
pytest.skip("PyTorch not installed")
def test_freqs_match_pytorch_reference(self, has_torch):
"""Numerically compare MLX and PyTorch frequency tables."""
import torch
from mlx_video.models.wan.rope import rope_params
d = 128
# PyTorch reference (from wan/modules/model.py)
def pt_rope_params(max_seq_len, dim, theta=10000):
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
ref = torch.cat([
pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
], dim=1)
# MLX
ours = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
mx.eval(ours)
our_cos = np.array(ours[:, :, 0])
our_sin = np.array(ours[:, :, 1])
ref_cos = ref.real.float().numpy()
ref_sin = ref.imag.float().numpy()
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
err_msg="cos mismatch vs PyTorch reference")
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
err_msg="sin mismatch vs PyTorch reference")
class TestRoPEApplyWithCorrectFreqs:
"""Test that rope_apply produces correct rotations with three-call freqs."""
def test_different_spatial_positions_get_different_rotations(self):
"""Adjacent height/width positions must produce different RoPE rotations.
This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions.
"""
from mlx_video.models.wan.rope import rope_params, rope_apply
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
B, N = 1, 4
F, H, W = 1, 4, 4
L = F * H * W
# Use a constant input so differences come purely from RoPE
x = mx.ones((B, L, N, d))
out = rope_apply(x, [(F, H, W)], freqs)
mx.eval(out)
out_np = np.array(out[0])
# Position (0,0,0) vs (0,1,0) — different height
pos_00 = out_np[0 * H * W + 0 * W + 0] # (f=0, h=0, w=0)
pos_10 = out_np[0 * H * W + 1 * W + 0] # (f=0, h=1, w=0)
height_diff = np.abs(pos_00 - pos_10).max()
# Position (0,0,0) vs (0,0,1) — different width
pos_01 = out_np[0 * H * W + 0 * W + 1] # (f=0, h=0, w=1)
width_diff = np.abs(pos_00 - pos_01).max()
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
# and width was ~0.002. With correct freqs, both are ~1.3.
assert height_diff > 0.5, (
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
)
assert width_diff > 0.5, (
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
)
# Height and width should have identical frequency tables → same diffs
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
err_msg="Height and width should use identical frequency tables")
def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
from mlx_video.models.wan.rope import (
rope_apply,
rope_params,
rope_precompute_cos_sin,
)
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
B, N = 2, 4
F, H, W = 2, 3, 4
L = F * H * W
grids = [(F, H, W), (F, H, W)]
x = mx.random.normal((B, L, N, d))
# Online (no precomputed)
out_online = rope_apply(x, grids, freqs)
# Precomputed
cos_sin = rope_precompute_cos_sin(grids, freqs)
out_precomp = rope_apply(x, grids, freqs, precomputed_cos_sin=cos_sin)
mx.eval(out_online, out_precomp)
np.testing.assert_allclose(
np.array(out_online), np.array(out_precomp), atol=1e-5,
err_msg="Precomputed and online RoPE should match"
)

917
tests/test_wan_scheduler.py Normal file
View File

@@ -0,0 +1,917 @@
"""Tests for Wan scheduler components."""
import math
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Euler Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowMatchEulerScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000
assert sched.timesteps is None
assert sched.sigmas is None
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
assert sched.timesteps.shape == (40,)
assert sched.sigmas.shape == (41,) # 40 steps + terminal
def test_timesteps_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps)
ts = np.array(sched.timesteps)
# Timesteps should be monotonically decreasing
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
def test_sigmas_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0)
mx.eval(sched.sigmas)
sigmas = np.array(sched.sigmas)
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.sigmas)
np.testing.assert_allclose(np.array(sched.sigmas[-1]), 0.0, atol=1e-6)
def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler()
sched1.set_timesteps(20, shift=1.0)
sched2.set_timesteps(20, shift=12.0)
mx.eval(sched1.sigmas, sched2.sigmas)
mean1 = np.mean(np.array(sched1.sigmas[:-1]))
mean2 = np.mean(np.array(sched2.sigmas[:-1]))
assert mean2 > mean1, "Higher shift should push sigmas higher"
def test_step_euler(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0)
mx.eval(sched.sigmas)
sample = mx.ones((1, 4, 2, 2, 2))
velocity = mx.ones((1, 4, 2, 2, 2)) * 0.5
timestep = sched.timesteps[0]
sigma = float(np.array(sched.sigmas[0]))
sigma_next = float(np.array(sched.sigmas[1]))
result = sched.step(velocity, timestep, sample)
mx.eval(result)
# Euler: x_next = x + (sigma_next - sigma) * v
expected = 1.0 + (sigma_next - sigma) * 0.5
np.testing.assert_allclose(
np.array(result).flatten()[0], expected, rtol=1e-4,
)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
assert sched._step_index == 0
sample = mx.ones((1, 1, 1, 1, 1))
vel = mx.zeros((1, 1, 1, 1, 1))
sched.step(vel, sched.timesteps[0], sample)
assert sched._step_index == 1
sched.step(vel, sched.timesteps[1], sample)
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
vel = mx.zeros((1, 1, 1, 1, 1))
sched.step(vel, sched.timesteps[0], sample)
assert sched._step_index == 1
sched.reset()
assert sched._step_index == 0
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
assert sched.timesteps.shape == (steps,)
assert sched.sigmas.shape == (steps + 1,)
def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
for i in range(5):
vel = mx.zeros_like(sample)
sample = sched.step(vel, sched.timesteps[i], sample)
mx.eval(sample)
# With zero velocity, sample should remain unchanged
np.testing.assert_allclose(np.array(sample), 1.0, atol=1e-5)
# ---------------------------------------------------------------------------
# Shared Sigma Schedule Tests
# ---------------------------------------------------------------------------
class TestComputeSigmas:
"""Tests for the shared _compute_sigmas helper."""
def test_length(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0)
def test_matches_official_wan22(self):
"""Sigma schedule should match the official Wan2.2 FlowUniPCMultistepScheduler.
The reference creates the scheduler with shift=1 (identity) in the
constructor, then passes the actual shift to set_timesteps. This means
sigma_max/sigma_min come from the *unshifted* training schedule, and the
shift is applied only once (single-shift).
"""
from mlx_video.models.wan.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
# Official single-shift: unshifted bounds, then shift once
alphas = np.linspace(1.0, 1.0 / N, N)[::-1]
sigmas_unshifted = 1.0 - alphas
sigma_max = float(sigmas_unshifted[0]) # 0.999
sigma_min = float(sigmas_unshifted[-1]) # 0.0
official = np.linspace(sigma_max, sigma_min, steps + 1)[:-1]
official = shift * official / (1.0 + (shift - 1.0) * official)
official = np.append(official, 0.0).astype(np.float32)
np.testing.assert_allclose(sigmas, official, atol=1e-6)
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
# so schedule is nearly linear from ~0.999 to 0
expected = np.linspace(1, 0, 11).astype(np.float32)
np.testing.assert_allclose(sigmas, expected, atol=2e-3)
def test_all_schedulers_same_sigmas(self):
"""All three schedulers should produce identical sigma schedules."""
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
FlowUniPCScheduler(1000),
]
for s in scheds:
s.set_timesteps(20, shift=5.0)
mx.eval(*[s.sigmas for s in scheds])
ref = np.array(scheds[0].sigmas)
for s in scheds[1:]:
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
def test_all_schedulers_same_timesteps(self):
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
FlowUniPCScheduler(1000),
]
for s in scheds:
s.set_timesteps(30, shift=12.0)
mx.eval(*[s.timesteps for s in scheds])
ref = np.array(scheds[0].timesteps)
for s in scheds[1:]:
np.testing.assert_allclose(np.array(s.timesteps), ref, atol=1e-3)
# ---------------------------------------------------------------------------
# DPM++ 2M Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowDPMPP2MScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
assert sched.timesteps.shape == (20,)
assert sched.sigmas.shape == (21,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 4, 1, 2, 2))
vel = mx.zeros_like(sample)
assert sched._step_index == 0
sched.step(vel, sched.timesteps[0], sample)
assert sched._step_index == 1
sched.step(vel, sched.timesteps[1], sample)
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
sched.step(mx.zeros_like(sample), 0, sample)
sched.reset()
assert sched._step_index == 0
assert sched._prev_x0 is None
def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
for i in range(10):
vel = mx.ones_like(sample) * 0.1
sample = sched.step(vel, sched.timesteps[i], sample)
mx.eval(sample)
assert np.isfinite(np.array(sample)).all()
def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 2, 4, 4))
vel = mx.random.normal(sample.shape)
# Before first step, no prev_x0
assert sched._prev_x0 is None
result = sched.step(vel, sched.timesteps[0], sample)
mx.eval(result)
# After first step, prev_x0 should be set
assert sched._prev_x0 is not None
def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
vel = mx.random.normal(sample.shape)
# Step 1
sample = sched.step(vel, sched.timesteps[0], sample)
mx.eval(sample)
x0_after_first = sched._prev_x0
# Step 2
vel = mx.random.normal(sample.shape)
sample = sched.step(vel, sched.timesteps[1], sample)
mx.eval(sample)
# prev_x0 should have been updated
x0_after_second = sched._prev_x0
assert x0_after_second is not None
# The stored x0 should differ from the first step's
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
latents = mx.random.normal(target.shape)
for i in range(20):
sigma = float(sched.sigmas[i].item())
v = latents / max(sigma, 1e-6) # perfect velocity for target=0
latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents)
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
assert sched.timesteps.shape == (steps,)
assert sched.sigmas.shape == (steps + 1,)
def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
vel = mx.ones_like(sample) * 2.0
# Run through all steps; the last step has sigma_next=0
for i in range(5):
sample = sched.step(vel, sched.timesteps[i], sample)
mx.eval(sample)
# Final value should be finite
assert np.isfinite(np.array(sample)).all()
# ---------------------------------------------------------------------------
# UniPC Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowUniPCScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000
assert sched.solver_order == 2
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
assert sched.timesteps.shape == (30,)
assert sched.sigmas.shape == (31,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
vel = mx.zeros_like(sample)
assert sched._step_index == 0
sched.step(vel, 0, sample)
assert sched._step_index == 1
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
sched.step(mx.zeros_like(sample), 0, sample)
sched.reset()
assert sched._step_index == 0
assert sched._lower_order_nums == 0
assert sched._last_sample is None
assert all(m is None for m in sched._model_outputs)
def test_full_loop_finite(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
for i in range(10):
vel = mx.ones_like(sample) * 0.1
sample = sched.step(vel, sched.timesteps[i], sample)
mx.eval(sample)
assert np.isfinite(np.array(sample)).all()
def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history)."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
vel = mx.random.normal(sample.shape)
# Before step 0: no last_sample
assert sched._last_sample is None
sched.step(vel, sched.timesteps[0], sample)
# After step 0: last_sample should be set for corrector on step 1
assert sched._last_sample is not None
def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 4, 4))
for i in range(3):
vel = mx.random.normal(sample.shape)
sample = sched.step(vel, sched.timesteps[i], sample)
mx.eval(sample)
# lower_order_nums should have increased
assert sched._lower_order_nums >= 2
def test_denoise_to_target(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
latents = mx.random.normal(target.shape)
for i in range(20):
sigma = float(sched.sigmas[i].item())
v = latents / max(sigma, 1e-6)
latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents)
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
assert sched.timesteps.shape == (steps,)
assert sched.sigmas.shape == (steps + 1,)
def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 2, 2))
for i in range(5):
vel = mx.ones_like(sample) * 0.1
sample = sched.step(vel, sched.timesteps[i], sample)
mx.eval(sample)
assert np.isfinite(np.array(sample)).all()
def test_solver_order_3(self):
"""Order 3 should work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 2, 2))
for i in range(10):
vel = mx.random.normal(sample.shape)
sample = sched.step(vel, sched.timesteps[i], sample)
mx.eval(sample)
assert np.isfinite(np.array(sample)).all()
def test_corrector_rhos_c_not_hardcoded(self):
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
import math
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
from mlx_video.models.wan.scheduler import _compute_sigmas
sigmas = _compute_sigmas(50, shift=5.0)
def _lambda(sigma):
if sigma >= 1.0:
return -math.inf
if sigma <= 0.0:
return math.inf
return math.log(1 - sigma) - math.log(sigma)
for step_idx in [5, 10, 25, 45]:
sigma_s0 = sigmas[step_idx - 1]
sigma_t = sigmas[step_idx]
lambda_s0 = _lambda(sigma_s0)
lambda_t = _lambda(sigma_t)
h = lambda_t - lambda_s0
hh = -h
sigma_sk = sigmas[step_idx - 2]
lambda_sk = _lambda(sigma_sk)
rk = (lambda_sk - lambda_s0) / h
rks = np.array([rk, 1.0])
h_phi_1 = math.expm1(hh)
B_h = h_phi_1
h_phi_k = h_phi_1 / hh - 1.0
factorial_i = 1
R_rows, b_vals = [], []
for j in range(1, 3):
R_rows.append(rks ** (j - 1))
b_vals.append(h_phi_k * factorial_i / B_h)
factorial_i *= j + 1
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
R = np.stack(R_rows)
b = np.array(b_vals)
rhos_c = np.linalg.solve(R, b)
# History weight should be small (~0.07-0.09), not 0.5
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
# D1_t weight should be ~0.42-0.45, not 0.5
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
# ---------------------------------------------------------------------------
# Scheduler Coherence Tests
# ---------------------------------------------------------------------------
class TestSchedulerCoherence:
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
All three schedulers should agree on shared structure (sigma schedules,
first-step behavior) and converge to the same result given perfect
velocity oracles, even though they use different update rules.
"""
@staticmethod
def _make_schedulers(steps=10, shift=5.0):
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = {
"euler": FlowMatchEulerScheduler(),
"dpm++": FlowDPMPP2MScheduler(),
"unipc": FlowUniPCScheduler(),
}
for s in scheds.values():
s.set_timesteps(steps, shift=shift)
return scheds
def test_identical_sigma_schedules(self):
"""All schedulers must use the same sigma schedule."""
scheds = self._make_schedulers(20, shift=5.0)
ref = np.array(scheds["euler"].sigmas)
for name in ("dpm++", "unipc"):
np.testing.assert_allclose(
np.array(scheds[name].sigmas),
ref,
atol=1e-6,
err_msg=f"{name} sigma schedule differs from Euler",
)
def test_identical_timesteps(self):
"""All schedulers must produce the same timestep sequence."""
scheds = self._make_schedulers(20, shift=5.0)
ref = np.array(scheds["euler"].timesteps)
for name in ("dpm++", "unipc"):
np.testing.assert_allclose(
np.array(scheds[name].timesteps),
ref,
atol=1e-6,
err_msg=f"{name} timesteps differ from Euler",
)
def test_first_step_matches_euler(self):
"""Step 0 (1st-order for all solvers) should match Euler exactly."""
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
vel = mx.random.normal(shape)
scheds = self._make_schedulers(10, shift=5.0)
results = {}
for name, sched in scheds.items():
r = sched.step(vel, sched.timesteps[0], noise)
mx.eval(r)
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
err_msg="DPM++ step 0 should match Euler",
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
err_msg="UniPC step 0 should match Euler",
)
def test_first_step_matches_across_shifts(self):
"""Step 0 should match Euler for different shift values."""
mx.random.seed(99)
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
vel = mx.random.normal(shape)
for shift in (1.0, 5.0, 12.0):
scheds = self._make_schedulers(10, shift=shift)
euler_r = scheds["euler"].step(vel, scheds["euler"].timesteps[0], noise)
dpm_r = scheds["dpm++"].step(vel, scheds["dpm++"].timesteps[0], noise)
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
mx.eval(euler_r, dpm_r, unipc_r)
np.testing.assert_allclose(
np.array(dpm_r), np.array(euler_r), atol=1e-5,
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
)
np.testing.assert_allclose(
np.array(unipc_r), np.array(euler_r), atol=1e-5,
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
)
def test_oracle_all_converge_to_target(self):
"""Given a perfect velocity oracle v=x/sigma, all solvers should
denoise to approximately zero (the target)."""
mx.random.seed(7)
shape = (1, 2, 1, 4, 4)
noise = mx.random.normal(shape)
for name, sched in self._make_schedulers(20, shift=5.0).items():
latents = noise
for i in range(20):
sigma = float(sched.sigmas[i].item())
v = latents / max(sigma, 1e-8)
latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents)
np.testing.assert_allclose(
np.array(latents), 0.0, atol=1e-3,
err_msg=f"{name} did not converge to target with oracle",
)
def test_oracle_higher_order_closer_to_target(self):
"""With few steps and a perfect oracle, higher-order solvers should
be at least as accurate as Euler."""
mx.random.seed(12)
shape = (1, 2, 1, 4, 4)
noise = mx.random.normal(shape)
steps = 5
errors = {}
for name, sched in self._make_schedulers(steps, shift=5.0).items():
latents = noise
for i in range(steps):
sigma = float(sched.sigmas[i].item())
v = latents / max(sigma, 1e-8)
latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents)
errors[name] = float(mx.mean(mx.abs(latents)).item())
# Higher-order solvers should not be significantly worse than Euler
# (add small epsilon to handle near-zero errors from floating point noise)
eps = 1e-6
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
)
def test_multistep_trajectory_similar_magnitude(self):
"""Over a full denoising loop with constant velocity, all solvers
should produce outputs of similar magnitude (not diverging)."""
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
steps = 20
final_means = {}
for name, sched in self._make_schedulers(steps, shift=5.0).items():
latents = noise
for i in range(steps):
vel = latents * 0.1
latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents)
final_means[name] = float(mx.mean(mx.abs(latents)).item())
# All solvers should produce results within the same order of magnitude
vals = list(final_means.values())
ratio = max(vals) / max(min(vals), 1e-10)
assert ratio < 10.0, (
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
)
def test_intermediate_values_finite(self):
"""Every intermediate latent value must be finite for all solvers."""
mx.random.seed(0)
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
for name, sched in self._make_schedulers(15, shift=5.0).items():
latents = noise
for i in range(15):
vel = mx.random.normal(shape)
latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents)
assert np.isfinite(np.array(latents)).all(), (
f"{name} produced non-finite values at step {i}"
)
def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
assert cls._lambda(1.0) == -math.inf, (
f"{cls.__name__}._lambda(1.0) should be -inf"
)
assert cls._lambda(0.0) == math.inf, (
f"{cls.__name__}._lambda(0.0) should be +inf"
)
# Interior values should be finite
lam = cls._lambda(0.5)
assert math.isfinite(lam) and lam == 0.0, (
f"{cls.__name__}._lambda(0.5) should be 0.0"
)
def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
for i in range(len(lambdas) - 1):
assert lambdas[i] > lambdas[i + 1], (
f"_lambda not decreasing: _lambda({sigmas[i]})={lambdas[i]} "
f"vs _lambda({sigmas[i+1]})={lambdas[i+1]}"
)
def test_step0_is_ddim_formula(self):
"""At sigma=1.0, the DPM++/UniPC first step should reduce to the
DDIM formula: x_next = sigma_next * x + (1 - sigma_next) * x0."""
mx.random.seed(55)
shape = (1, 2, 1, 2, 2)
sample = mx.random.normal(shape)
vel = mx.random.normal(shape)
for steps, shift in [(10, 5.0), (20, 12.0)]:
scheds = self._make_schedulers(steps, shift=shift)
sigma_next = float(scheds["euler"].sigmas[1].item())
sigma_cur = float(scheds["euler"].sigmas[0].item())
assert abs(sigma_cur - 1.0) < 1e-3, "First sigma should be ~1.0"
x0 = sample - sigma_cur * vel
expected = sigma_next * sample + (1.0 - sigma_next) * x0
mx.eval(expected)
for name in ("dpm++", "unipc"):
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
mx.eval(result)
np.testing.assert_allclose(
np.array(result), np.array(expected), atol=5e-4,
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
)
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_coherent_across_step_counts(self, steps):
"""All solvers should agree on step 0 regardless of total step count."""
mx.random.seed(77)
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
vel = mx.random.normal(shape)
scheds = self._make_schedulers(steps, shift=5.0)
results = {}
for name, sched in scheds.items():
r = sched.step(vel, sched.timesteps[0], noise)
mx.eval(r)
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
)
def test_dpmpp_unipc_agree_on_step1(self):
"""After warmup, DPM++ and UniPC step 1 should be similar
(both use 2nd-order corrections based on the same model outputs)."""
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
scheds = self._make_schedulers(10, shift=5.0)
# Run step 0 with same velocity
vel0 = mx.random.normal(shape)
for sched in scheds.values():
sched.step(vel0, sched.timesteps[0], noise)
# Run step 1 from same sample with same velocity
sample1 = scheds["euler"].step(vel0, scheds["euler"].timesteps[0], noise)
mx.eval(sample1)
vel1 = mx.random.normal(shape)
r_dpm = scheds["dpm++"].step(vel1, scheds["dpm++"].timesteps[1], sample1)
r_unipc = scheds["unipc"].step(vel1, scheds["unipc"].timesteps[1], sample1)
mx.eval(r_dpm, r_unipc)
# They won't be identical (different correction formulas) but should
# be in the same ballpark (within 50% of each other's magnitude)
mean_dpm = float(mx.mean(mx.abs(r_dpm)).item())
mean_unipc = float(mx.mean(mx.abs(r_unipc)).item())
ratio = max(mean_dpm, mean_unipc) / max(min(mean_dpm, mean_unipc), 1e-10)
assert ratio < 2.0, (
f"DPM++ and UniPC step 1 differ too much: "
f"DPM++={mean_dpm:.4f}, UniPC={mean_unipc:.4f}"
)
def test_reset_makes_solvers_reproducible(self):
"""After reset(), running the same loop should produce identical output."""
mx.random.seed(42)
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
sched = cls()
sched.set_timesteps(5, shift=5.0)
# First run
latents = noise
for i in range(5):
vel = latents * 0.1
latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents)
result1 = np.array(latents)
# Reset and run again
sched.reset()
latents = noise
for i in range(5):
vel = latents * 0.1
latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents)
result2 = np.array(latents)
np.testing.assert_allclose(result1, result2, atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()")
# ---------------------------------------------------------------------------
# UniPC Corrector Default Tests
# ---------------------------------------------------------------------------
class TestUniPCCorrectorDefault:
"""Tests that the UniPC corrector is enabled by default,
matching official FlowUniPCMultistepScheduler behavior."""
def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched._use_corrector is True
def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
sched_corr = FlowUniPCScheduler(use_corrector=True)
sched_corr.set_timesteps(10, shift=5.0)
sched_no = FlowUniPCScheduler(use_corrector=False)
sched_no.set_timesteps(10, shift=5.0)
latent_corr = noise
latent_no = noise
for i in range(3):
vel = mx.random.normal(shape) * 0.1
latent_corr = sched_corr.step(vel, sched_corr.timesteps[i], latent_corr)
latent_no = sched_no.step(vel, sched_no.timesteps[i], latent_no)
mx.eval(latent_corr, latent_no)
diff = float(mx.abs(latent_corr - latent_no).max())
assert diff > 1e-6, f"Corrector had no effect (max diff={diff})"
def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
vel = mx.random.normal(shape)
sched_corr = FlowUniPCScheduler(use_corrector=True)
sched_corr.set_timesteps(10, shift=5.0)
sched_no = FlowUniPCScheduler(use_corrector=False)
sched_no.set_timesteps(10, shift=5.0)
r1 = sched_corr.step(vel, sched_corr.timesteps[0], noise)
r2 = sched_no.step(vel, sched_no.timesteps[0], noise)
mx.eval(r1, r2)
np.testing.assert_allclose(np.array(r1), np.array(r2), atol=1e-6)

173
tests/test_wan_t5.py Normal file
View File

@@ -0,0 +1,173 @@
"""Tests for T5 encoder components."""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# T5 Encoder Tests
# ---------------------------------------------------------------------------
class TestT5LayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
mx.eval(out)
assert out.shape == (2, 10, 64)
def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan.text_encoder import T5LayerNorm
norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0
out = norm(x)
mx.eval(out)
out_np = np.array(out[0])
for i in range(5):
rms = np.sqrt(np.mean(out_np[i] ** 2))
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
class TestT5RelativeEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10)
mx.eval(out)
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
def test_asymmetric_lengths(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12)
mx.eval(out)
assert out.shape == (1, 4, 8, 12)
def test_symmetry(self):
"""Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6)
mx.eval(out)
out_np = np.array(out[0]) # [N, lq, lk]
# Diagonal elements (position i attending to position i) should be consistent
# (same relative distance = 0 for all diagonal elements)
for h in range(2):
diag = np.diag(out_np[h])
np.testing.assert_allclose(diag, diag[0], atol=1e-5)
class TestT5Attention:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
out = attn(x)
mx.eval(out)
assert out.shape == (1, 10, 64)
def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale")
def test_with_position_bias(self):
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4)
x = mx.random.normal((1, 10, 64))
pos_bias = rel_emb(10, 10)
out = attn(x, pos_bias=pos_bias)
mx.eval(out)
assert out.shape == (1, 10, 64)
def test_with_mask(self):
from mlx_video.models.wan.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
mask = mx.ones((1, 10))
mask = mx.concatenate([mask[:, :7], mx.zeros((1, 3))], axis=1)
out = attn(x, mask=mask)
mx.eval(out)
assert out.shape == (1, 10, 64)
class TestT5FeedForward:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64))
out = ffn(x)
mx.eval(out)
assert out.shape == (1, 10, 64)
def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj")
assert hasattr(ffn, "fc1")
assert hasattr(ffn, "fc2")
class TestT5Encoder:
def setup_method(self):
mx.random.seed(42)
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
)
ids = mx.array([[1, 5, 10, 0, 0]])
mask = mx.array([[1, 1, 1, 0, 0]])
out = encoder(ids, mask=mask)
mx.eval(out)
assert out.shape == (1, 5, 64)
def test_shared_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
)
assert encoder.pos_embedding is not None
for block in encoder.blocks:
assert block.pos_embedding is None
def test_per_layer_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
)
assert encoder.pos_embedding is None
for block in encoder.blocks:
assert block.pos_embedding is not None
def test_param_count(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
assert num_params > 0
def test_without_mask(self):
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
)
ids = mx.array([[1, 5, 10]])
out = encoder(ids)
mx.eval(out)
assert out.shape == (1, 3, 64)

198
tests/test_wan_tiling.py Normal file
View File

@@ -0,0 +1,198 @@
"""Tests for Wan VAE tiled decoding."""
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
decode_with_tiling,
split_in_spatial,
split_in_temporal,
)
class TestNonCausalTemporal:
"""Tests for the causal_temporal=False path in decode_with_tiling."""
def test_split_spatial_for_temporal(self):
"""Non-causal temporal should use split_in_spatial (no causal shift)."""
intervals = split_in_spatial(8, 2, 20)
# No causal adjustment: starts should be evenly spaced
assert intervals.starts[0] == 0
for i in range(1, len(intervals.starts)):
assert intervals.starts[i] == intervals.starts[i - 1] + (8 - 2)
def test_causal_vs_noncausal_output_size(self):
"""Causal temporal gives 1+(T-1)*S frames, non-causal gives T*S."""
mx.random.seed(42)
b, c, t, h, w = 1, 4, 4, 4, 4
latents = mx.random.normal((b, c, t, h, w))
scale = 4
# Simple passthrough decoder: just repeat along dimensions
def dummy_decoder_causal(x, **kwargs):
b, c, t, h, w = x.shape
out_t = 1 + (t - 1) * scale
out_h = h * scale
out_w = w * scale
return mx.ones((b, 3, out_t, out_h, out_w))
def dummy_decoder_noncausal(x, **kwargs):
b, c, t, h, w = x.shape
out_t = t * scale
out_h = h * scale
out_w = w * scale
return mx.ones((b, 3, out_t, out_h, out_w))
config = TilingConfig.spatial_only(tile_size=128, overlap=64)
# Causal: 1 + (4-1)*4 = 13
out_causal = decode_with_tiling(
dummy_decoder_causal, latents, config,
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
)
mx.eval(out_causal)
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
# Non-causal: 4*4 = 16
out_noncausal = decode_with_tiling(
dummy_decoder_noncausal, latents, config,
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
)
mx.eval(out_noncausal)
assert out_noncausal.shape[2] == t * scale # 16
class TestWan22TiledDecoding:
"""Tests for Wan2.2 VAE tiled decoding."""
def _make_small_wan22_decoder(self):
"""Create a small Wan2.2 decoder for testing."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
# Use very small dimensions for fast testing
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
mx.eval(vae.parameters())
return vae
def test_decode_tiled_output_shape(self):
"""Tiled decode should produce same shape as non-tiled."""
mx.random.seed(42)
vae = self._make_small_wan22_decoder()
# Small input: [B=1, T=3, H=2, W=2, C=48]
z = mx.random.normal((1, 3, 2, 2, 48))
mx.eval(z)
# Non-tiled
out_regular = vae(z)
mx.eval(out_regular)
# Tiled (force tiling with very small tile sizes)
# Use spatial tile=32px (2 latent @ scale 16) and temporal=8 frames (2 latent @ scale 4)
config = TilingConfig(
spatial_config=None, # Don't tile spatially (input is tiny)
temporal_config=None, # Don't tile temporally (input is tiny)
)
# With no tiling config, decode_tiled should fall through to regular decode
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
# Both should produce the same shape
assert out_regular.shape == out_tiled.shape, (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
)
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
mx.random.seed(42)
vae = self._make_small_wan22_decoder()
# Input smaller than any tile size
z = mx.random.normal((1, 2, 2, 2, 48))
mx.eval(z)
out_regular = vae(z)
mx.eval(out_regular)
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled),
rtol=1e-4, atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
class TestWan21TiledDecoding:
"""Tests for Wan2.1 VAE tiled decoding."""
def _make_small_wan21_vae(self):
"""Create a small Wan2.1 VAE for testing."""
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
mx.eval(vae.parameters())
return vae
def test_decode_tiled_output_shape(self):
"""Tiled decode should produce correct output shape."""
mx.random.seed(42)
vae = self._make_small_wan21_vae()
# [B=1, C=16, T=3, H=4, W=4]
z = mx.random.normal((1, 16, 3, 4, 4))
mx.eval(z)
out_regular = vae.decode(z)
mx.eval(out_regular)
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
assert out_regular.shape == out_tiled.shape, (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
)
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
mx.random.seed(42)
vae = self._make_small_wan21_vae()
z = mx.random.normal((1, 16, 2, 4, 4))
mx.eval(z)
out_regular = vae.decode(z)
mx.eval(out_regular)
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled),
rtol=1e-4, atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
class TestWan21TemporalScale:
"""Verify Wan2.1 decoder temporal output is T*4 (non-causal)."""
def test_wan21_decoder_temporal_output(self):
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
from mlx_video.models.wan.vae import Decoder3d
# Small decoder for fast test
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
temporal_upsample=[True, True, False])
mx.eval(dec.parameters())
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
mx.eval(x)
out = dec(x)
mx.eval(out)
# With two temporal 2× upsamples: T=3 → 6 → 12
assert out.shape[2] == 3 * 4, f"Expected T=12, got T={out.shape[2]}"

View File

@@ -0,0 +1,160 @@
"""Tests for Wan transformer block components."""
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Transformer Block Tests
# ---------------------------------------------------------------------------
class TestWanFFN:
def test_output_shape(self):
from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(64, 256)
x = mx.random.normal((2, 10, 64))
out = ffn(x)
mx.eval(out)
assert out.shape == (2, 10, 64)
def test_gelu_activation(self):
"""FFN should use GELU activation (non-linearity)."""
from mlx_video.models.wan.transformer import WanFFN
ffn = WanFFN(32, 128)
x = mx.ones((1, 1, 32)) * 2.0
out1 = ffn(x)
x2 = mx.ones((1, 1, 32)) * 4.0
out2 = ffn(x2)
mx.eval(out1, out2)
# Non-linear: 2x input should not give 2x output
assert not np.allclose(np.array(out2), np.array(out1) * 2.0, rtol=0.1)
class TestWanAttentionBlock:
def setup_method(self):
mx.random.seed(42)
self.dim = 64
self.ffn_dim = 128
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads,
cross_attn_norm=True,
)
B, L = 1, 24
F, H, W = 2, 3, 4
x = mx.random.normal((B, L, self.dim))
e = mx.random.normal((B, L, 6, self.dim))
context = mx.random.normal((B, 16, self.dim))
freqs = rope_params(1024, self.dim // self.num_heads)
out = block(
x, e, seq_lens=[L], grid_sizes=[(F, H, W)],
freqs=freqs, context=context,
)
mx.eval(out)
assert out.shape == (B, L, self.dim)
def test_modulation_shape(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
assert block.modulation.shape == (1, 6, self.dim)
def test_with_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads,
cross_attn_norm=True,
)
assert block.norm3 is not None
def test_without_cross_attn_norm(self):
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(
self.dim, self.ffn_dim, self.num_heads,
cross_attn_norm=False,
)
assert block.norm3 is None
def test_residual_connection(self):
"""Output should differ from zero even with small random init."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
B, L = 1, 8
F, H, W = 2, 2, 2
x = mx.ones((B, L, self.dim))
e = mx.zeros((B, L, 6, self.dim))
context = mx.random.normal((B, 4, self.dim))
freqs = rope_params(1024, self.dim // self.num_heads)
out = block(x, e, [L], [(F, H, W)], freqs, context)
mx.eval(out)
# With residual connections, output should be close to input + corrections
assert not np.allclose(np.array(out), 0.0, atol=1e-3)
# ---------------------------------------------------------------------------
# Float32 Modulation Precision Tests
# ---------------------------------------------------------------------------
class TestFloat32Modulation:
"""Tests that modulation/gate operations are computed in float32,
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
def setup_method(self):
mx.random.seed(42)
self.dim = 64
def test_block_modulation_in_float32(self):
"""Modulation param starts random but should be usable as float32."""
from mlx_video.models.wan.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
assert block.modulation.dtype == mx.float32
def test_block_output_float32_with_bf16_modulation_input(self):
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
block = WanAttentionBlock(self.dim, 128, 4)
B, L = 1, 8
x = mx.random.normal((B, L, self.dim))
e = mx.random.normal((B, L, 6, self.dim)).astype(mx.bfloat16)
ctx = mx.random.normal((B, 4, self.dim))
freqs = rope_params(1024, self.dim // 4)
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
mx.eval(out)
assert out.dtype == mx.float32
assert np.isfinite(np.array(out)).all()
def test_head_modulation_float32(self):
"""Head modulation should be float32 even with bf16 e input."""
from mlx_video.models.wan.model import Head
head = Head(self.dim, 4, (1, 2, 2))
x = mx.random.normal((1, 8, self.dim))
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
out = head(x, e)
mx.eval(out)
assert np.isfinite(np.array(out.astype(mx.float32))).all()
def test_model_time_embedding_float32(self):
"""sinusoidal_embedding_1d output must be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
t = mx.array([500.0])
emb = sinusoidal_embedding_1d(256, t)
mx.eval(emb)
assert emb.dtype == mx.float32
def test_model_per_token_time_embedding_float32(self):
"""Per-token time embeddings (I2V) should also be float32."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
emb = sinusoidal_embedding_1d(256, t)
mx.eval(emb)
assert emb.dtype == mx.float32
assert emb.shape == (1, 4, 256)

951
tests/test_wan_vae.py Normal file
View File

@@ -0,0 +1,951 @@
"""Tests for Wan VAE 2.1 and 2.2 components."""
import math
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# VAE 2.1 Tests
# ---------------------------------------------------------------------------
class TestCausalConv3d:
def test_output_shape_stride1(self):
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
# Initialize weights
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
x = mx.random.normal((1, 4, 3, 8, 8)) # [B, C, T, H, W]
out = conv(x)
mx.eval(out)
# With causal padding and padding=1 on spatial, dims should be preserved
assert out.shape[0] == 1
assert out.shape[1] == 8 # out_channels
assert out.shape[2] == 3 # T preserved
assert out.shape[3] == 8 # H preserved
assert out.shape[4] == 8 # W preserved
def test_output_shape_kernel1(self):
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
x = mx.random.normal((1, 4, 2, 4, 4))
out = conv(x)
mx.eval(out)
assert out.shape == (1, 8, 2, 4, 4)
def test_causal_padding(self):
"""Causal conv should only use past/current frames, not future."""
from mlx_video.models.wan.vae import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros((2,))
# Create input where only the first frame has signal
x = mx.zeros((1, 2, 4, 4, 4))
x_np = np.zeros((1, 2, 4, 4, 4), dtype=np.float32)
x_np[:, :, 0, :, :] = 1.0
x = mx.array(x_np)
out = conv(x)
mx.eval(out)
# Due to causal padding, the output at t=0 should only depend on t=0
class TestResidualBlock:
def test_same_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
mx.eval(out)
assert out.shape == (1, 8, 2, 4, 4)
def test_different_dim(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
mx.eval(out)
assert out.shape == (1, 16, 2, 4, 4)
def test_shortcut_exists_when_dims_differ(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_when_dims_same(self):
from mlx_video.models.wan.vae import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
class TestAttentionBlock:
def test_output_shape(self):
from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 2, 4, 4))
out = block(x)
mx.eval(out)
assert out.shape == (1, 8, 2, 4, 4)
def test_residual_connection(self):
from mlx_video.models.wan.vae import AttentionBlock
block = AttentionBlock(8)
x = mx.random.normal((1, 8, 1, 3, 3))
out = block(x)
mx.eval(x, out)
# Residual: output should not be zero even with random init
assert np.abs(np.array(out)).max() > 0
class TestWanVAE:
def test_instantiation(self):
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
assert vae.z_dim == 16
assert vae.mean.shape == (16,)
assert vae.std.shape == (16,)
def test_normalization_stats(self):
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
assert len(VAE_MEAN) == 16
assert len(VAE_STD) == 16
assert all(s > 0 for s in VAE_STD)
# ---------------------------------------------------------------------------
# Wan2.2 VAE Component Tests
# ---------------------------------------------------------------------------
class TestVAE22CausalConv3d:
"""Tests for vae22.CausalConv3d (channels-last)."""
def test_output_shape_k3(self):
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
out = conv(x)
mx.eval(out)
assert out.shape == (1, 4, 8, 8, 16)
def test_output_shape_k1(self):
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(8, 16, kernel_size=1)
x = mx.random.normal((1, 2, 4, 4, 8))
out = conv(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
def test_temporal_causal(self):
"""Output at t=0 should not depend on t>0."""
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
conv.bias = mx.zeros(conv.bias.shape)
x = mx.zeros((1, 4, 4, 4, 2))
out_zero = conv(x)
mx.eval(out_zero)
t0_ref = np.array(out_zero[0, 0])
# Modify t=2..3; output at t=0 should be unchanged
x_mod = mx.concatenate([
x[:, :2],
mx.ones((1, 2, 4, 4, 2)),
], axis=1)
out_mod = conv(x_mod)
mx.eval(out_mod)
t0_mod = np.array(out_mod[0, 0])
np.testing.assert_allclose(t0_ref, t0_mod, atol=1e-5)
def test_channels_last_format(self):
"""Verify input/output are channels-last [B, T, H, W, C]."""
from mlx_video.models.wan.vae22 import CausalConv3d
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
x = mx.random.normal((2, 3, 6, 6, 4))
out = conv(x)
mx.eval(out)
assert out.shape[-1] == 8 # last dim = out_channels
class TestRMSNorm:
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(16)
x = mx.random.normal((2, 4, 4, 4, 16))
out = norm(x)
mx.eval(out)
assert out.shape == x.shape
def test_l2_normalization(self):
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
from mlx_video.models.wan.vae22 import RMS_norm
dim = 32
norm = RMS_norm(dim)
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
out = norm(x)
mx.eval(out)
# After L2 norm * scale(=sqrt(dim)) * gamma(=1): ||out|| = sqrt(dim)
out_np = np.array(out).flatten()
l2 = np.linalg.norm(out_np)
np.testing.assert_allclose(l2, math.sqrt(dim), rtol=1e-3)
def test_scale_invariant(self):
"""Scaling input by constant should not change output (L2 norm property)."""
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(8)
x = mx.random.normal((1, 1, 1, 1, 8))
out1 = norm(x)
out2 = norm(x * 10.0)
mx.eval(out1, out2)
np.testing.assert_allclose(np.array(out1), np.array(out2), atol=1e-4)
def test_gamma_effect(self):
"""Non-unit gamma should scale output."""
from mlx_video.models.wan.vae22 import RMS_norm
norm = RMS_norm(4)
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
x = mx.ones((1, 1, 1, 1, 4))
out = norm(x)
mx.eval(out)
# With gamma=2, each component is 2 * sqrt(4) * x/||x|| = 2 * 2 * 1/2 = 2
np.testing.assert_allclose(np.array(out).flatten(), 2.0, atol=1e-4)
class TestDupUp3D:
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out = up(x)
mx.eval(out)
assert out.shape == (1, 3, 8, 8, 4)
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 16))
out = up(x)
mx.eval(out)
assert out.shape == (1, 6, 8, 8, 8)
def test_first_chunk_trims(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False)
out_trimmed = up(x, first_chunk=True)
mx.eval(out_normal, out_trimmed)
# first_chunk removes factor_t-1=1 temporal frame
assert out_normal.shape[1] == 6
assert out_trimmed.shape[1] == 5
def test_no_temporal_first_chunk_noop(self):
from mlx_video.models.wan.vae22 import DupUp3D
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
x = mx.random.normal((1, 3, 4, 4, 8))
out_normal = up(x, first_chunk=False)
out_trimmed = up(x, first_chunk=True)
mx.eval(out_normal, out_trimmed)
# factor_t=1, so first_chunk removes 0 frames
assert out_normal.shape == out_trimmed.shape
class TestVAE22Resample:
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
def test_upsample2d_shape(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample2d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
out = r(x)
mx.eval(out)
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
def test_upsample3d_shape(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
out = r(x)
mx.eval(out)
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
def test_upsample3d_first_chunk(self):
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 8))
out = r(x, first_chunk=True)
mx.eval(out)
# first_chunk: 1 (bypass) + 2*(T-1) (interleaved) = 2T-1 = 3
assert out.shape == (1, 3, 8, 8, 8)
def test_upsample3d_first_chunk_single_frame(self):
"""Single-frame input with first_chunk: no temporal upsample."""
from mlx_video.models.wan.vae22 import Resample
r = Resample(8, "upsample3d")
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
x = mx.random.normal((1, 1, 4, 4, 8))
out = r(x, first_chunk=True)
mx.eval(out)
# Single frame with first_chunk: falls through to non-first path
# time_conv on 1 frame → 2 interleaved
assert out.shape == (1, 2, 8, 8, 8)
def test_upsample3d_first_frame_bypasses_time_conv(self):
"""First frame of first_chunk should NOT go through time_conv.
Official Wan2.2 skips time_conv for the very first frame entirely.
We verify this by checking that the first output frame depends only on
the first input frame (not on time_conv parameters).
"""
from mlx_video.models.wan.vae22 import Resample
C = 8
r = Resample(C, "upsample3d")
# Set time_conv weights to large values so its effect is detectable
r.time_conv.weight = mx.ones(r.time_conv.weight.shape) * 10.0
r.time_conv.bias = mx.zeros(r.time_conv.bias.shape)
# Set spatial conv to identity-like
r.resample_weight = mx.zeros(r.resample_weight.shape)
r.resample_bias = mx.zeros(r.resample_bias.shape)
x = mx.random.normal((1, 3, 2, 2, C))
out = r(x, first_chunk=True)
mx.eval(out)
# Output: 5 frames (1 bypass + 4 interleaved from 2 remaining)
assert out.shape[1] == 5
# First frame should be spatial upsample of x[:, 0:1] only.
# Run just the first frame through spatial upsample for reference
first_only = x[:, 0:1]
ref = r._upsample2x(first_only.reshape(1, 2, 2, C))
ref = mx.pad(ref, [(0, 0), (1, 1), (1, 1), (0, 0)])
ref = mx.conv_general(ref, r.resample_weight) + r.resample_bias
mx.eval(ref)
# Compare first output frame to reference
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
mx.eval(first_out)
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
"First frame should bypass time_conv and match spatial-only upsample"
class TestVAE22ResidualBlock:
"""Tests for vae22.ResidualBlock."""
def test_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 8)
def test_different_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
def test_shortcut_when_dims_differ(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 16)
assert block.shortcut is not None
def test_no_shortcut_same_dim(self):
from mlx_video.models.wan.vae22 import ResidualBlock
block = ResidualBlock(8, 8)
assert block.shortcut is None
class TestResidualBlockLayers:
"""Tests for vae22.ResidualBlockLayers naming convention."""
def test_layer_names_no_underscore_prefix(self):
"""Layer names must NOT start with underscore (MLX ignores them)."""
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 8)
params = dict(block.parameters())
# All param keys should use layer_N, not _layer_N
for key in params:
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
def test_has_expected_layers(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
assert hasattr(block, "layer_0") # first RMS_norm
assert hasattr(block, "layer_2") # first CausalConv3d
assert hasattr(block, "layer_3") # second RMS_norm
assert hasattr(block, "layer_6") # second CausalConv3d
def test_forward_shape(self):
from mlx_video.models.wan.vae22 import ResidualBlockLayers
block = ResidualBlockLayers(8, 16)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
class TestVAE22AttentionBlock:
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(16)
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
x = mx.random.normal((1, 2, 4, 4, 16))
out = block(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
def test_residual_connection(self):
from mlx_video.models.wan.vae22 import AttentionBlock
block = AttentionBlock(8)
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
block.proj_weight = mx.zeros(block.proj_weight.shape)
x = mx.ones((1, 1, 2, 2, 8))
out = block(x)
mx.eval(out)
# With zero weights, attention output is 0 → residual is identity
np.testing.assert_allclose(np.array(out), np.array(x), atol=1e-5)
class TestHead22:
"""Tests for vae22.Head22 output head."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Head22
head = Head22(16, out_channels=12)
x = mx.random.normal((1, 2, 4, 4, 16))
out = head(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 12)
def test_layer_names_no_underscore(self):
"""Head layers must not use underscore prefix."""
from mlx_video.models.wan.vae22 import Head22
head = Head22(8)
assert hasattr(head, "layer_0") # RMS_norm
assert hasattr(head, "layer_2") # CausalConv3d
params = dict(head.parameters())
for key in params:
assert not key.startswith("_"), f"Head param {key} starts with underscore"
class TestUnpatchify:
"""Tests for vae22._unpatchify."""
def test_basic_shape(self):
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
out = _unpatchify(x, patch_size=2)
mx.eval(out)
assert out.shape == (1, 2, 8, 8, 3)
def test_patch_size_1_noop(self):
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.random.normal((1, 2, 4, 4, 3))
out = _unpatchify(x, patch_size=1)
mx.eval(out)
np.testing.assert_array_equal(np.array(out), np.array(x))
def test_preserves_content(self):
"""Unpatchify should be a lossless rearrangement."""
from mlx_video.models.wan.vae22 import _unpatchify
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
out = _unpatchify(x, patch_size=2)
mx.eval(out)
# All elements should be preserved
assert np.array(out).size == 48
assert set(np.array(out).flatten().tolist()) == set(range(48))
class TestDenormalizeLatents:
"""Tests for vae22.denormalize_latents."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
out = denormalize_latents(z)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 48)
def test_custom_mean_std(self):
from mlx_video.models.wan.vae22 import denormalize_latents
z = mx.ones((1, 1, 1, 1, 4))
mean = mx.array([1.0, 2.0, 3.0, 4.0])
std = mx.array([0.5, 0.5, 0.5, 0.5])
out = denormalize_latents(z, mean=mean, std=std)
mx.eval(out)
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
def test_uses_default_constants(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
# Should not raise with default constants
z = mx.zeros((1, 1, 1, 1, 48))
out = denormalize_latents(z)
mx.eval(out)
# z=0 → result = 0 * std + mean = mean
np.testing.assert_allclose(
np.array(out).flatten(),
np.array(VAE22_MEAN).flatten(),
atol=1e-5,
)
class TestVAE22NormConstants:
"""Tests for VAE22_MEAN and VAE22_STD constants."""
def test_dimensions(self):
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
mx.eval(VAE22_MEAN, VAE22_STD)
assert VAE22_MEAN.shape == (48,)
assert VAE22_STD.shape == (48,)
def test_std_positive(self):
from mlx_video.models.wan.vae22 import VAE22_STD
mx.eval(VAE22_STD)
assert (np.array(VAE22_STD) > 0).all()
class TestWan22VAEDecoder:
"""Tests for the full Wan22VAEDecoder (tiny configuration)."""
def test_output_shape_small(self):
"""Tiny decoder should produce correct spatial/temporal output."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
# Use very small dims to keep test fast
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
# Latent: [B=1, T=3, H=2, W=2, C=4]
# Expected: temporal 3→5→9→9→9 (two temporal upsamples), spatial 2→4→8→16
z = mx.random.normal((1, 3, 2, 2, 4)) * 0.1
out = dec(z)
mx.eval(out)
# Output should have 3 RGB channels and be clipped to [-1, 1]
assert out.shape[-1] == 3
assert out.ndim == 5
assert np.array(out).min() >= -1.0
assert np.array(out).max() <= 1.0
def test_output_clipped(self):
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
out = dec(z)
mx.eval(out)
assert np.array(out).min() >= -1.0 - 1e-6
assert np.array(out).max() <= 1.0 + 1e-6
class TestSanitizeWan22VAEWeights:
"""Tests for vae22.sanitize_wan22_vae_weights."""
def test_skip_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.layer.weight": mx.zeros((4,)),
"conv1.weight": mx.zeros((4,)),
"decoder.conv1.bias": mx.zeros((4,)),
}
out = sanitize_wan22_vae_weights(weights)
assert "encoder.layer.weight" not in out
assert "conv1.weight" not in out
assert "decoder.conv1.bias" in out
def test_sequential_index_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
"decoder.head.0.gamma": mx.ones((4,)),
"decoder.head.2.bias": mx.zeros((12,)),
}
out = sanitize_wan22_vae_weights(weights)
assert "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma" in out
assert "decoder.upsamples.0.upsamples.0.residual.layer_6.bias" in out
assert "decoder.head.layer_0.gamma" in out
assert "decoder.head.layer_2.bias" in out
def test_resample_conv_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
}
out = sanitize_wan22_vae_weights(weights)
assert "decoder.upsamples.1.upsamples.3.resample_weight" in out
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
def test_attention_remapping(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = {
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
"decoder.middle.1.proj.weight": mx.zeros((8, 8, 1, 1)),
"decoder.middle.1.proj.bias": mx.zeros((8,)),
}
out = sanitize_wan22_vae_weights(weights)
assert "decoder.middle.1.to_qkv_weight" in out
assert "decoder.middle.1.to_qkv_bias" in out
assert "decoder.middle.1.proj_weight" in out
assert "decoder.middle.1.proj_bias" in out
def test_conv3d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
w = mx.zeros((16, 8, 3, 3, 3))
weights = {"decoder.conv1.weight": w}
out = sanitize_wan22_vae_weights(weights)
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
def test_conv2d_transpose(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
w = mx.zeros((8, 8, 3, 3))
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
out = sanitize_wan22_vae_weights(weights)
key = "decoder.upsamples.0.upsamples.2.resample_weight"
assert out[key].shape == (8, 3, 3, 8)
def test_gamma_squeeze(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
# gamma: (dim, 1, 1, 1) → (dim,)
w = mx.ones((16, 1, 1, 1))
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
out = sanitize_wan22_vae_weights(weights)
key = "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma"
assert out[key].shape == (16,)
class TestUpResidualBlock:
"""Tests for vae22.Up_ResidualBlock."""
def test_no_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
# No upsample: same shape
assert out.shape == (1, 2, 4, 4, 8)
def test_spatial_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
# 2x spatial upsample, no temporal
assert out.shape == (1, 2, 8, 8, 4)
def test_spatial_temporal_upsample(self):
from mlx_video.models.wan.vae22 import Up_ResidualBlock
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
x = mx.random.normal((1, 2, 4, 4, 8))
out = block(x)
mx.eval(out)
# 2x spatial + 2x temporal
assert out.shape == (1, 4, 8, 8, 4)
class TestPatchify:
"""Tests for _patchify and _unpatchify round-trip."""
def test_roundtrip(self):
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 1, 64, 64, 3))
p = _patchify(x, patch_size=2)
assert p.shape == (1, 1, 32, 32, 12)
back = _unpatchify(p, patch_size=2)
assert back.shape == x.shape
assert float(mx.abs(x - back).max()) == 0.0
def test_identity_patch_1(self):
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
x = mx.random.normal((1, 2, 8, 8, 3))
assert _patchify(x, patch_size=1).shape == x.shape
assert _unpatchify(x, patch_size=1).shape == x.shape
class TestAvgDown3D:
"""Tests for AvgDown3D downsampling."""
def test_spatial_only(self):
from mlx_video.models.wan.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
x = mx.random.normal((1, 2, 8, 8, 8))
out = down(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
def test_temporal_and_spatial(self):
from mlx_video.models.wan.vae22 import AvgDown3D
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
x = mx.random.normal((1, 4, 8, 8, 8))
out = down(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
def test_single_frame(self):
from mlx_video.models.wan.vae22 import AvgDown3D
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
x = mx.random.normal((1, 1, 8, 8, 8))
out = down(x)
mx.eval(out)
# T=1 with factor_t=2: pads to T=2 then averages → T=1
assert out.shape == (1, 1, 4, 4, 8)
class TestDownResidualBlock:
"""Tests for Down_ResidualBlock."""
def test_no_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x)
mx.eval(out)
assert out.shape == (1, 2, 8, 8, 8)
def test_spatial_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
x = mx.random.normal((1, 2, 8, 8, 8))
out = block(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
def test_spatial_temporal_downsample(self):
from mlx_video.models.wan.vae22 import Down_ResidualBlock
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
x = mx.random.normal((1, 4, 8, 8, 8))
out = block(x)
mx.eval(out)
assert out.shape == (1, 2, 4, 4, 16)
class TestEncoder3d:
"""Tests for Encoder3d."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8)
x = mx.random.normal((1, 1, 16, 16, 12))
mx.eval(enc.parameters())
out = enc(x)
mx.eval(out)
# 3 spatial downsamples ÷8: 16→2
assert out.shape == (1, 1, 2, 2, 8)
def test_multi_frame(self):
from mlx_video.models.wan.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12))
mx.eval(enc.parameters())
out = enc(x)
mx.eval(out)
# T: 5→3 (1st t_down) →2 (2nd t_down), spatial ÷8
assert out.shape[2:] == (2, 2, 8)
class TestWan22VAEEncoder:
"""Tests for Wan22VAEEncoder wrapper."""
def test_output_shape(self):
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
img = mx.random.normal((1, 1, 32, 32, 3))
mx.eval(enc.parameters())
z = enc(img)
mx.eval(z)
assert z.shape == (1, 1, 2, 2, 48)
def test_full_dim(self):
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=160)
img = mx.random.normal((1, 1, 64, 64, 3))
mx.eval(enc.parameters())
z = enc(img)
mx.eval(z)
# 64 / 16 = 4 (vae stride 16×)
assert z.shape == (1, 1, 4, 4, 48)
class TestNormalizeLatents:
"""Tests for normalize/denormalize latent roundtrip."""
def test_roundtrip(self):
from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents
z = mx.random.normal((1, 2, 4, 4, 48))
z_norm = normalize_latents(z)
z_back = denormalize_latents(z_norm)
mx.eval(z_back)
assert float(mx.abs(z - z_back).max()) < 1e-4
class TestVAEEncoderTemporalOrder:
"""Tests that VAE encoder uses (False, True, True) temporal downsample order,
matching official Wan2.2 vae2_2.py."""
def test_encoder_temporal_downsample_pattern(self):
"""Encoder3d with (False, True, True): T=5→5→3→2."""
from mlx_video.models.wan.vae22 import Encoder3d
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
x = mx.random.normal((1, 5, 16, 16, 12))
mx.eval(enc.parameters())
out = enc(x)
mx.eval(out)
assert out.shape[1] == 2
def test_wrapper_uses_correct_pattern(self):
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
enc = Wan22VAEEncoder(z_dim=48, dim=16)
down_blocks = enc.encoder.downsamples
found_modes = []
for block in down_blocks:
for layer in block.downsamples:
if isinstance(layer, Resample):
found_modes.append(layer.mode)
# First spatial-only, then two with temporal
assert found_modes[0] == "downsample2d"
assert any("3d" in m for m in found_modes)
def test_single_frame_encoder(self):
"""Single frame (T=1) should work with (False, True, True) pattern."""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
enc = Wan22VAEEncoder(z_dim=48, dim=16)
img = mx.random.normal((1, 1, 32, 32, 3))
mx.eval(enc.parameters())
z = enc(img)
mx.eval(z)
assert z.shape[1] == 1
assert z.shape[-1] == 48
def test_wrong_order_gives_different_result(self):
"""(True, True, False) vs (False, True, True) produce different outputs."""
from mlx_video.models.wan.vae22 import Encoder3d
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
x = mx.random.normal((1, 5, 16, 16, 12))
mx.eval(enc_correct.parameters())
mx.eval(enc_wrong.parameters())
out_correct = enc_correct(x)
out_wrong = enc_wrong(x)
mx.eval(out_correct, out_wrong)
# Both give T=2 but spatial processing path differs
assert out_correct.shape[1] == 2
assert out_wrong.shape[1] == 2
# ---------------------------------------------------------------------------
# VAE Encode → Decode Round-Trip Tests
# ---------------------------------------------------------------------------
class TestVAE21RoundTrip:
"""Encode→decode round-trip for Wan 2.1 VAE (channels-first)."""
def test_encode_decode_shape_and_values(self):
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
z_dim = 4
dim = 8
# No temporal up/downsampling to keep the test simple
enc = Encoder3d(
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
)
dec = Decoder3d(
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
)
mx.eval(enc.parameters(), dec.parameters())
# [B=1, C=3, T=1, H=8, W=8]
x = mx.random.normal((1, 3, 1, 8, 8)) * 0.5
z = enc(x)
mx.eval(z)
# 3 spatial downsamples (÷8): H=1, W=1
assert z.shape == (1, z_dim, 1, 1, 1)
x_hat = dec(z)
mx.eval(x_hat)
# 3 spatial upsamples (×8): should recover original shape
assert x_hat.shape == x.shape
out_np = np.array(x_hat)
assert np.all(np.isfinite(out_np))
assert np.abs(out_np).max() < 1000
class TestVAE22RoundTrip:
"""Encode→decode round-trip for Wan 2.2 VAE (channels-last)."""
def test_encode_decode_shape_and_values(self):
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
from mlx_video.models.wan.vae22 import (
Wan22VAEDecoder,
Wan22VAEEncoder,
denormalize_latents,
)
enc = Wan22VAEEncoder(z_dim=48, dim=16)
dec = Wan22VAEDecoder(z_dim=48, dec_dim=8)
mx.eval(enc.parameters(), dec.parameters())
# [B=1, T=1, H=32, W=32, C=3]
img = mx.random.normal((1, 1, 32, 32, 3)) * 0.5
z_norm = enc(img)
mx.eval(z_norm)
# patchify(÷2) + 3 spatial downsamples(÷8) = ÷16
assert z_norm.shape == (1, 1, 2, 2, 48)
z = denormalize_latents(z_norm)
out = dec(z)
mx.eval(out)
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB
out_np = np.array(out)
assert np.all(np.isfinite(out_np))
assert out_np.min() >= -1.0 - 1e-6
assert out_np.max() <= 1.0 + 1e-6

19
tests/wan_test_helpers.py Normal file
View File

@@ -0,0 +1,19 @@
"""Shared test helpers for Wan test modules."""
def _make_tiny_config():
"""Create a tiny WanModelConfig for testing."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig()
# Override to tiny values
config.dim = 64
config.ffn_dim = 128
config.num_heads = 4
config.num_layers = 2
config.in_dim = 4
config.out_dim = 4
config.patch_size = (1, 2, 2)
config.freq_dim = 32
config.text_dim = 32
config.text_len = 8
return config

34
uv.lock generated
View File

@@ -2,7 +2,8 @@ version = 1
revision = 3 revision = 3
requires-python = ">=3.11" requires-python = ">=3.11"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.12'", "python_full_version >= '3.13'",
"python_full_version == '3.12.*'",
"python_full_version < '3.12'", "python_full_version < '3.12'",
] ]
@@ -614,6 +615,33 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
] ]
[[package]]
name = "imageio"
version = "2.37.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
{ name = "pillow" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/6f/606be632e37bf8d05b253e8626c2291d74c691ddc7bcdf7d6aaf33b32f6a/imageio-2.37.2.tar.gz", hash = "sha256:0212ef2727ac9caa5ca4b2c75ae89454312f440a756fcfc8ef1993e718f50f8a", size = 389600, upload-time = "2025-11-04T14:29:39.898Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl", hash = "sha256:ad9adfb20335d718c03de457358ed69f141021a333c40a53e57273d8a5bd0b9b", size = 317646, upload-time = "2025-11-04T14:29:37.948Z" },
]
[[package]]
name = "imageio-ffmpeg"
version = "0.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/44/bd/c3343c721f2a1b0c9fc71c1aebf1966a3b7f08c2eea8ed5437a2865611d6/imageio_ffmpeg-0.6.0.tar.gz", hash = "sha256:e2556bed8e005564a9f925bb7afa4002d82770d6b08825078b7697ab88ba1755", size = 25210, upload-time = "2025-01-16T21:34:32.747Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/da/58/87ef68ac83f4c7690961bce288fd8e382bc5f1513860fc7f90a9c1c1c6bf/imageio_ffmpeg-0.6.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.whl", hash = "sha256:9d2baaf867088508d4a3458e61eeb30e945c4ad8016025545f66c4b5aaef0a61", size = 24932969, upload-time = "2025-01-16T21:34:20.464Z" },
{ url = "https://files.pythonhosted.org/packages/40/5c/f3d8a657d362cc93b81aab8feda487317da5b5d31c0e1fdfd5e986e55d17/imageio_ffmpeg-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b1ae3173414b5fc5f538a726c4e48ea97edc0d2cdc11f103afee655c463fa742", size = 21113891, upload-time = "2025-01-16T21:34:00.277Z" },
{ url = "https://files.pythonhosted.org/packages/33/e7/1925bfbc563c39c1d2e82501d8372734a5c725e53ac3b31b4c2d081e895b/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1d47bebd83d2c5fc770720d211855f208af8a596c82d17730aa51e815cdee6dc", size = 25632706, upload-time = "2025-01-16T21:33:53.475Z" },
{ url = "https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c7e46fcec401dd990405049d2e2f475e2b397779df2519b544b8aab515195282", size = 29498237, upload-time = "2025-01-16T21:34:13.726Z" },
{ url = "https://files.pythonhosted.org/packages/a0/13/59da54728351883c3c1d9fca1710ab8eee82c7beba585df8f25ca925f08f/imageio_ffmpeg-0.6.0-py3-none-win32.whl", hash = "sha256:196faa79366b4a82f95c0f4053191d2013f4714a715780f0ad2a68ff37483cc2", size = 19652251, upload-time = "2025-01-16T21:34:06.812Z" },
{ url = "https://files.pythonhosted.org/packages/2c/c6/fa760e12a2483469e2bf5058c5faff664acf66cadb4df2ad6205b016a73d/imageio_ffmpeg-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02fa47c83703c37df6bfe4896aab339013f62bf02c5ebf2dce6da56af04ffc0a", size = 31246824, upload-time = "2025-01-16T21:34:28.6Z" },
]
[[package]] [[package]]
name = "iniconfig" name = "iniconfig"
version = "2.3.0" version = "2.3.0"
@@ -772,6 +800,8 @@ name = "mlx-video"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "huggingface-hub" }, { name = "huggingface-hub" },
{ name = "imageio" },
{ name = "imageio-ffmpeg" },
{ name = "mlx" }, { name = "mlx" },
{ name = "mlx-vlm" }, { name = "mlx-vlm" },
{ name = "numpy" }, { name = "numpy" },
@@ -790,6 +820,8 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "huggingface-hub" }, { name = "huggingface-hub" },
{ name = "imageio", specifier = ">=2.37.2" },
{ name = "imageio-ffmpeg", specifier = ">=0.6.0" },
{ name = "mlx", specifier = ">=0.22.0" }, { name = "mlx", specifier = ">=0.22.0" },
{ name = "mlx-vlm" }, { name = "mlx-vlm" },
{ name = "numpy" }, { name = "numpy" },